LayerStack
- class penzai.nn.layer_stack.LayerStack
Bases: Layer
A sequence of layers with identical structure, called under jax.lax.scan.
This class identifies a sequence of layers that can be efficiently called using a jax.lax.scan control flow primitive, speeding up compilation times under jax.jit. Instead of storing separate copies of each layer, this class stores a single “prototype” layer, whose leaves have an additional named axis (the stack_axis) whenever they differ across layers.
StateVariable instances inside the stacked sublayer are required to have a metadata field “layerstack_axes”, which maps each axis name to a LayerStackVarBehavior determining whether it should be shared or split across layers. (This is not necessary for Parameters, which are not mutable when the layer is called.)
- Variables:
stacked_sublayers (layer_base.Layer) – A collection of sublayers, each of which has an extra named axis.
stack_axis (named_axes.AxisName) – The axis name that layer data is stacked along.
stack_axis_size (int) – The size of the stack axis.
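The prototype-layer idea described above can be illustrated with a minimal plain-Python sketch. It uses neither penzai nor JAX: the `scan` helper is a hypothetical stand-in that only mimics the carry-threading pattern of jax.lax.scan, and the dictionaries, leaf names, and functions are invented for illustration. The point it shows is that a leaf that differs across layers gains a stack axis, while a leaf that is identical everywhere is stored once.

```python
# Plain-Python stand-in for the stacking idea; this is not penzai or JAX code.

def scan(step, carry, xs):
    """Thread `carry` through `step` once per slice of `xs` (scan-like loop)."""
    ys = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys

# Three separate "layers", each an affine map y = scale * x + bias.
separate_layers = [
    {"scale": 2.0, "bias": 1.0},
    {"scale": 0.5, "bias": 1.0},
    {"scale": 3.0, "bias": 1.0},
]

# One "prototype" layer. `scale` differs per layer, so it gains a stack axis
# (here a list of length 3). `bias` is the same everywhere, so it is stored once.
prototype = {"scale": [2.0, 0.5, 3.0], "bias": 1.0}

def apply_layer(x, scale):
    # The step sees one slice of the stacked leaf plus the shared leaf.
    return scale * x + prototype["bias"], None

def run_unrolled(x):
    # Baseline: call each separately stored layer in turn.
    for layer in separate_layers:
        x = layer["scale"] * x + layer["bias"]
    return x

def run_stacked(x):
    # Same computation, driven by a single step function over the stack axis.
    out, _ = scan(apply_layer, x, prototype["scale"])
    return out
```

Both functions compute the same result. The stacked form describes the layer body only once, which is the property that lets a real scan primitive trace and compile the body a single time rather than once per layer.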
Methods
__init__(stacked_sublayers, stack_axis, ...)
from_sublayer_builder(builder, stack_axis, ...) – Builds a layer stack of layers with non-shared parameters.
key_for_field(field_name) – Returns a custom GetAttrKey with layer stack metadata.
__call__(argument, /, **side_inputs) – Calls the stacked sublayers under a jax.lax.scan.
Attributes
stacked_sublayers
stack_axis
stack_axis_size
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(argument, /, **side_inputs) → Any
Calls the stacked sublayers under a jax.lax.scan.
- classmethod from_sublayer_builder(builder: Callable[..., layer_base.Layer], stack_axis: named_axes.AxisName, stack_axis_size: int, init_base_rng: jax.Array | None, builder_kwargs: dict[str, Any]) → LayerStack
Builds a layer stack of layers with non-shared parameters.
This function assumes that all variables returned by this builder are defined inside the builder. Returning variables that were already defined outside the builder is not supported.
- Parameters:
builder – A function that builds a single layer, which must take a keyword argument init_base_rng. All variables, as well as all other leaf values that depend on this RNG, must be NamedArrays.
stack_axis – The axis name that layer data is stacked along.
stack_axis_size – The size of the stack axis.
init_base_rng – The base RNG for initializing the parameters.
builder_kwargs – Keyword arguments to pass to the builder.
- Returns:
A new layer stack. All arrays and variables will be split across the stack axis.
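As a rough illustration of the builder contract described above, the following plain-Python sketch builds each layer from its own seed and then stacks the resulting leaves. It is a toy analogue under stated assumptions, not penzai's implementation: `build_dense`, `stack_from_builder`, and the `"weights"` leaf are hypothetical names, an integer seed with the standard `random` module stands in for a JAX RNG key, and nested lists stand in for NamedArrays.

```python
import random

def build_dense(init_base_rng, features):
    """Hypothetical builder: creates one layer's leaves from its own seed."""
    rng = random.Random(init_base_rng)
    return {"weights": [rng.gauss(0.0, 1.0) for _ in range(features)]}

def stack_from_builder(builder, stack_axis, stack_axis_size,
                       init_base_rng, builder_kwargs):
    """Toy analogue: build each layer with a distinct seed, then stack leaves."""
    # Derive one seed per layer from the base seed (stand-in for key splitting).
    splitter = random.Random(init_base_rng)
    seeds = [splitter.getrandbits(32) for _ in range(stack_axis_size)]
    # The builder receives its seed through the `init_base_rng` keyword.
    layers = [builder(init_base_rng=seed, **builder_kwargs) for seed in seeds]
    # Stack every leaf along a new leading axis of length `stack_axis_size`.
    stacked = {name: [layer[name] for layer in layers] for name in layers[0]}
    return {
        "stacked_sublayers": stacked,
        "stack_axis": stack_axis,
        "stack_axis_size": stack_axis_size,
    }

stack = stack_from_builder(build_dense, "blocks", 4, 42, {"features": 3})
```

Here `stack["stacked_sublayers"]["weights"]` holds four rows of three values, one row per layer. Because each layer is built from a different derived seed, the rows differ, which corresponds to the non-shared parameters mentioned above.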