LayerStack

class penzai.nn.layer_stack.LayerStack

Bases: Layer

A sequence of layers with identical structure, called under jax.lax.scan.

This class represents a sequence of layers that can be called efficiently with the jax.lax.scan control-flow primitive. This reduces compilation time under jax.jit, because the shared layer structure is traced once rather than once per layer. Instead of storing a separate copy of each layer, it stores a single “prototype” layer whose leaves have an additional named axis (the stack_axis) wherever they differ across layers.
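
As a rough illustration of the idea (a plain-JAX sketch, not penzai's implementation), the code below stores per-layer parameters with an extra leading array axis, which stands in for the named stack_axis, and applies one layer per jax.lax.scan step. The layer function layer_apply, the parameter names, and the sizes are assumptions chosen for the example. The unrolled loop_apply is included only to show that both versions compute the same result.

```python
import jax
import jax.numpy as jnp


def layer_apply(params, x):
    # Illustrative stand-in for one sublayer: affine map followed by tanh.
    return jnp.tanh(x @ params["w"] + params["b"])


num_layers, width = 4, 3
keys = jax.random.split(jax.random.PRNGKey(0), num_layers)

# "Prototype" parameters: every leaf has a leading axis of size num_layers.
stacked_params = {
    "w": jax.vmap(lambda k: jax.random.normal(k, (width, width)))(keys),
    "b": jnp.zeros((num_layers, width)),
}


def stack_apply(stacked_params, x):
    # scan hands one layer's slice of every leaf to each step,
    # so layer_apply is traced only once.
    def step(carry, layer_params):
        return layer_apply(layer_params, carry), None

    out, _ = jax.lax.scan(step, x, stacked_params)
    return out


def loop_apply(stacked_params, x):
    # Unrolled equivalent: traces layer_apply once per layer under jit.
    for i in range(num_layers):
        layer_params = jax.tree_util.tree_map(lambda a: a[i], stacked_params)
        x = layer_apply(layer_params, x)
    return x


x = jnp.ones((width,))
y_scan = jax.jit(stack_apply)(stacked_params, x)
y_loop = jax.jit(loop_apply)(stacked_params, x)
```

Both functions produce the same output. The difference is in compilation: the scan body is traced a single time regardless of num_layers, whereas the Python loop grows the traced program with every added layer.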

StateVariable instances inside the stacked sublayer must have a metadata field “layerstack_axes”, which maps each axis name to a LayerStackVarBehavior that determines whether the variable should be shared across layers or split across them. (Parameters do not need this field, because they are not mutated when the layer is called.)
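
The following is a conceptual analogy, written in plain JAX, for what “shared” and “split” mean for state under a scan. It does not use penzai's StateVariable or LayerStackVarBehavior, and the function names are hypothetical. A shared value travels through the scan carry, so each layer sees the update made by the previous layer. A split value has one slice per layer, which is scanned over as input and collected again as output.

```python
import jax
import jax.numpy as jnp


def layer_with_state(x, shared_count, own_count):
    # Hypothetical layer: doubles its input and increments two counters.
    return x * 2.0, shared_count + 1, own_count + 1


def stacked_call(x, shared_count, split_counts):
    def step(carry, own_count):
        x, shared = carry
        x, shared, own = layer_with_state(x, shared, own_count)
        # Shared state stays in the carry; split state is emitted per layer.
        return (x, shared), own

    (x, shared_count), split_counts = jax.lax.scan(
        step, (x, shared_count), split_counts)
    return x, shared_count, split_counts


num_layers = 3
x, shared, split = stacked_call(
    jnp.float32(1.0), jnp.int32(0), jnp.zeros((num_layers,), jnp.int32))
```

After three layers, shared ends at 3 because each layer saw the previous layer's update, while every entry of split is 1 because each layer touched only its own slice.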

Variables:
  • stacked_sublayers (layer_base.Layer) – A collection of sublayers, represented as a single layer whose leaves have an extra named axis (the stack axis) wherever they differ across layers.

  • stack_axis (named_axes.AxisName) – The axis name that layer data is stacked along.

  • stack_axis_size (int) – The size of the stack axis.
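
To make the “prototype” representation concrete, here is a hypothetical helper, stack_if_different (not part of penzai), that merges a list of per-layer pytrees with identical structure. Leaves that differ are stacked along a new leading axis whose length plays the role of stack_axis_size, and identical leaves are stored once. Plain positional arrays carry no axis names, so the sketch also returns a boolean mask recording which leaves were stacked. In penzai, a named stack_axis carries that information implicitly.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _all_equal(*leaves):
    # True when every layer holds the same value for this leaf.
    return all(np.array_equal(leaves[0], leaf) for leaf in leaves[1:])


def stack_if_different(layers):
    """Hypothetical helper: merge per-layer pytrees into one prototype.

    Differing leaves are stacked along a new leading axis (standing in for
    the stack axis); identical leaves are kept once. Also returns a mask
    pytree marking which leaves were stacked.
    """
    prototype = jax.tree_util.tree_map(
        lambda *ls: ls[0] if _all_equal(*ls) else jnp.stack(ls), *layers)
    stacked_mask = jax.tree_util.tree_map(
        lambda *ls: not _all_equal(*ls), *layers)
    return prototype, stacked_mask


# Three layers: "w" differs per layer, "scale" is identical everywhere.
layers = [{"w": jnp.full((2,), float(i)), "scale": jnp.array(0.5)}
          for i in range(3)]
prototype, stacked_mask = stack_if_different(layers)
```

Here prototype["w"] gains a leading axis of length 3 (one entry per layer), while prototype["scale"] keeps its original shape because it is the same in every layer.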

Methods

__init__(stacked_sublayers, stack_axis, ...)

from_sublayer_builder(builder, stack_axis, ...)

Builds a layer stack of layers with non-shared parameters.

key_for_field(field_name)

Returns a custom GetAttrKey with layer stack metadata.

__call__(argument, /, **side_inputs)

Calls the stacked sublayers under a jax.lax.scan.

Attributes

stacked_sublayers

stack_axis

stack_axis_size

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(argument, /, **side_inputs) → Any

Calls the stacked sublayers under a jax.lax.scan.

classmethod from_sublayer_builder(builder: Callable[..., layer_base.Layer], stack_axis: named_axes.AxisName, stack_axis_size: int, init_base_rng: jax.Array | None, builder_kwargs: dict[str, Any]) → LayerStack

Builds a layer stack of layers with non-shared parameters.

This function assumes that all variables returned by the builder are created inside the builder. Returning variables that were defined outside the builder is not supported.

Parameters:
  • builder – A function that builds a single layer. It must accept an init_base_rng keyword argument. All variables, as well as all other leaf values that depend on this RNG, must be NamedArrays.

  • stack_axis – The axis name that layer data is stacked along.

  • stack_axis_size – The size of the stack axis.

  • init_base_rng – The base RNG for initializing the parameters.

  • builder_kwargs – Keyword arguments to pass to the builder.

Returns:

A new layer stack. All arrays and variables will be split across the stack axis.
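
One way to realize a builder-based stack in plain JAX is sketched below; it is an assumption, not a description of how from_sublayer_builder works internally. The sketch splits init_base_rng into stack_axis_size keys and vmaps the builder over them, so every leaf gains a leading axis and no two layers share initial parameters. build_dense and build_stack are hypothetical names. Like the documented builder, build_dense takes init_base_rng as a keyword argument, but it returns a plain dict rather than a penzai layer, and penzai's version must additionally handle NamedArrays and variables.

```python
import jax
import jax.numpy as jnp


def build_dense(init_base_rng, width):
    # Hypothetical builder: one layer's parameters from a single RNG key.
    w_key, b_key = jax.random.split(init_base_rng)
    return {
        "w": jax.random.normal(w_key, (width, width)) / jnp.sqrt(width),
        "b": 0.01 * jax.random.normal(b_key, (width,)),
    }


def build_stack(builder, stack_axis_size, init_base_rng, builder_kwargs):
    # One key per layer, so each layer gets independent initial values.
    keys = jax.random.split(init_base_rng, stack_axis_size)
    # vmap over the keys adds a leading axis of size stack_axis_size
    # to every leaf the builder returns.
    return jax.vmap(
        lambda k: builder(init_base_rng=k, **builder_kwargs))(keys)


params = build_stack(build_dense, 4, jax.random.PRNGKey(0), {"width": 3})
```

Because every leaf depends on the per-layer key, every leaf ends up split across the stack axis, matching the behavior described under Returns.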

key_for_field(field_name: str) → Hashable

Returns a custom GetAttrKey with layer stack metadata.