LayerStack
- class penzai.nn.layer_stack.LayerStack
Bases: Layer
A sequence of layers with identical structure, called under jax.lax.scan.
This class identifies a sequence of layers that can be efficiently called using a jax.lax.scan control flow primitive, speeding up compilation times under jax.jit. Rather than storing separate copies of each layer, this class stores a single “prototype” layer, whose leaves have an additional named axis (the stack_axis) whenever they differ across layers.
StateVariable instances inside the stacked sublayer are required to have a metadata field “layerstack_axes”, which maps each axis name to a LayerStackVarBehavior determining whether it should be shared or split across layers. (This is not necessary for Parameters, which are not mutable when the layer is called.)
- Variables:
stacked_sublayers (layer_base.Layer) – A collection of sublayers, each of which has an extra named axis.
stack_axis (named_axes.AxisName) – The axis name that layer data is stacked along.
stack_axis_size (int) – The size of the stack axis.
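The stacking idea described above can be sketched in plain JAX. This is a minimal illustration, not penzai's implementation: it uses positional leading axes instead of named axes, and the names layer, stacked, apply_stack, and apply_loop are hypothetical. It shows a single "prototype" parameter tree whose leaves carry an extra stack axis, applied once with jax.lax.scan and once with an unrolled Python loop for comparison.

```python
import jax
import jax.numpy as jnp


def layer(params, x):
    # One "prototype" layer: a dense map followed by tanh.
    return jnp.tanh(x @ params["w"] + params["b"])


num_layers, dim = 4, 3
key = jax.random.PRNGKey(0)

# Every leaf carries a leading axis of size num_layers (the stack axis).
stacked = {
    "w": jax.random.normal(key, (num_layers, dim, dim)) * 0.1,
    "b": jnp.zeros((num_layers, dim)),
}


def scan_body(carry, layer_params):
    # scan slices one layer's parameters off the stack axis per step.
    return layer(layer_params, carry), None


@jax.jit
def apply_stack(stacked_params, x):
    # The layer body is traced once, regardless of num_layers.
    out, _ = jax.lax.scan(scan_body, x, stacked_params)
    return out


def apply_loop(stacked_params, x):
    # Reference version: index each layer out of the stack by hand.
    for i in range(num_layers):
        params_i = jax.tree_util.tree_map(lambda leaf: leaf[i], stacked_params)
        x = layer(params_i, x)
    return x


x = jnp.ones((dim,))
scan_out = apply_stack(stacked, x)
loop_out = apply_loop(stacked, x)
```

Both versions compute the same function; the scan version only traces the layer body once, which is where the compilation-time savings come from.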
Methods
__init__(stacked_sublayers, stack_axis, ...)
from_sublayer_builder(builder, stack_axis, ...) – Builds a layer stack of layers with non-shared parameters.
key_for_field(field_name) – Returns a custom GetAttrKey with layer stack metadata.
__call__(argument, /, **side_inputs) – Calls the stacked sublayers under a jax.lax.scan.
Attributes
stacked_sublayers
stack_axis
stack_axis_size
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(argument, /, **side_inputs) → Any
Calls the stacked sublayers under a jax.lax.scan.
- classmethod from_sublayer_builder(builder: Callable[..., layer_base.Layer], stack_axis: named_axes.AxisName, stack_axis_size: int, init_base_rng: jax.Array | None, builder_kwargs: dict[str, Any]) → LayerStack
Builds a layer stack of layers with non-shared parameters.
This function assumes that all variables returned by this builder are defined inside the builder. Returning variables that were already defined outside the builder is not supported.
- Parameters:
builder – A function that builds a single layer, which must take a keyword argument init_base_rng. All variables, as well as all other leaf values that depend on this RNG, must be NamedArrays.
stack_axis – The axis name that layer data is stacked along.
stack_axis_size – The size of the stack axis.
init_base_rng – The base RNG for initializing the parameters.
builder_kwargs – Keyword arguments to pass to the builder.
- Returns:
A new layer stack. All arrays and variables will be split across the stack axis.
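To make "non-shared parameters" concrete, here is a rough plain-JAX sketch of one way a builder-based stack could be constructed. It is an assumption-laden illustration and does not reproduce penzai's actual implementation: build_layer and build_stack are hypothetical helpers, the builder returns a plain dict rather than a Layer, and the stack axis is a positional leading axis rather than a named one. The only detail taken from the text above is that the builder accepts a keyword argument init_base_rng.

```python
import jax
import jax.numpy as jnp


def build_layer(init_base_rng, dim):
    # Hypothetical builder: returns the parameters of one dense layer.
    w_key, _ = jax.random.split(init_base_rng)
    return {
        "w": jax.random.normal(w_key, (dim, dim)) / jnp.sqrt(dim),
        "b": jnp.zeros((dim,)),
    }


def build_stack(builder, stack_axis_size, init_base_rng, builder_kwargs):
    # One independent key per layer, so no two layers share parameters.
    keys = jax.random.split(init_base_rng, stack_axis_size)
    # vmap adds a leading stack axis to every array the builder returns.
    return jax.vmap(lambda k: builder(init_base_rng=k, **builder_kwargs))(keys)


stack = build_stack(build_layer, 3, jax.random.PRNGKey(42), {"dim": 4})
```

Here every returned array gains a leading axis of size 3, which loosely mirrors the statement that all arrays and variables are split across the stack axis.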