LayerStackGetAttrKey
class penzai.nn.layer_stack.LayerStackGetAttrKey
Bases: GetAttrKey
A GetAttrKey for LayerStack that carries extra metadata about the stacked axis.
This allows us to identify whether a given PyTree leaf is contained inside a LayerStack and, if so, which axis it is stacked along. This in turn can be used to manipulate variables inside a LayerStack in a stack-compatible way.
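To make the idea concrete, here is a minimal pure-Python sketch that imports neither penzai nor JAX. `AttrKey` and `StackedAttrKey` are hypothetical stand-ins (not the real GetAttrKey or LayerStackGetAttrKey classes) whose constructor mirrors the documented `(name, stack_axis, stack_axis_size)` signature. The helpers `enclosing_stack_axes` and `is_stack_compatible` are also hypothetical, and describing a variable's shape as a dict from axis names to sizes is a simplifying assumption made only for illustration.

```python
from dataclasses import dataclass
from typing import Any, Dict, Sequence


# Hypothetical stand-in for a plain attribute key on a PyTree path.
# This is NOT the real JAX/penzai GetAttrKey class.
@dataclass(frozen=True)
class AttrKey:
    name: str


# Hypothetical stand-in mirroring the documented constructor
# LayerStackGetAttrKey(name, stack_axis, stack_axis_size).
@dataclass(frozen=True)
class StackedAttrKey(AttrKey):
    stack_axis: str
    stack_axis_size: int


def enclosing_stack_axes(path: Sequence[Any]) -> Dict[str, int]:
    """Maps each stack axis found on a leaf's key path to its size.

    An empty dict means the leaf is not inside any stack.
    """
    axes: Dict[str, int] = {}
    for key in path:
        # Only the stacked key type carries stack metadata; plain keys are skipped.
        if isinstance(key, StackedAttrKey):
            axes[key.stack_axis] = key.stack_axis_size
    return axes


def is_stack_compatible(path: Sequence[Any], named_shape: Dict[str, int]) -> bool:
    """Checks that a replacement value's (assumed) named shape includes every
    enclosing stack axis with the matching size."""
    return all(
        named_shape.get(axis) == size
        for axis, size in enclosing_stack_axes(path).items()
    )


# Example key paths, shaped like what a path-aware PyTree traversal might yield.
stacked_path = (StackedAttrKey("stacked", "layer", 4), AttrKey("weights"))
plain_path = (AttrKey("embedding"), AttrKey("table"))

print(enclosing_stack_axes(stacked_path))  # {'layer': 4}
print(enclosing_stack_axes(plain_path))  # {}
print(is_stack_compatible(stacked_path, {"layer": 4, "features": 8}))  # True
print(is_stack_compatible(stacked_path, {"features": 8}))  # False
```

Collecting every stacked key on the path, rather than stopping at the first one, keeps the sketch correct for a leaf nested inside more than one stack.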
Methods
__init__(name, stack_axis, stack_axis_size)
Attributes
name
stack_axis
stack_axis_size