LayerStackGetAttrKey
class penzai.nn.layer_stack.LayerStackGetAttrKey
Bases: CustomGetAttrKey

GetAttrKey for LayerStack with extra metadata.
This allows us to identify whether a given PyTree leaf is contained inside a LayerStack and, if so, which axis it is stacked along. This in turn can be used to manipulate variables inside a LayerStack in a stack-compatible way.
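The idea of reading stacking metadata off a leaf's key path can be sketched in plain Python. This is a minimal illustration, not penzai's implementation: StackKey, AttrKey, find_stack_axes and is_stack_compatible are hypothetical names, StackKey only mirrors the name, stack_axis and stack_axis_size attributes listed for LayerStackGetAttrKey, and leaf shapes are modeled as a dict from axis name to size (an assumption made to keep the sketch dependency-free).

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass(frozen=True)
class AttrKey:
    """Hypothetical stand-in for an ordinary getattr-style path key."""
    name: str


@dataclass(frozen=True)
class StackKey:
    """Hypothetical stand-in for LayerStackGetAttrKey: a getattr-style key
    that also records how the child subtree is stacked."""
    name: str             # attribute name, as in an ordinary GetAttrKey
    stack_axis: Any       # identifier of the axis the layers are stacked along
    stack_axis_size: int  # number of stacked copies


def find_stack_axes(path: Sequence[Any]) -> list:
    """Returns (stack_axis, stack_axis_size) for every stack enclosing the
    leaf, outermost first. An empty list means the leaf is not stacked."""
    return [(k.stack_axis, k.stack_axis_size) for k in path if isinstance(k, StackKey)]


def is_stack_compatible(path: Sequence[Any], named_shape: dict) -> bool:
    """Checks that a (hypothetical) named shape carries every enclosing
    stack axis with the recorded size, so it could replace the leaf."""
    return all(named_shape.get(axis) == size for axis, size in find_stack_axes(path))


# A leaf reached through a stacked attribute, and one that is not stacked.
inside = (AttrKey("sublayers"), StackKey("stacked", "layer", 4), AttrKey("weights"))
outside = (AttrKey("sublayers"), AttrKey("weights"))
```

Under these assumptions, find_stack_axes(inside) gives [("layer", 4)] while find_stack_axes(outside) is empty, and a replacement value for the stacked leaf is accepted only if it keeps the "layer" axis at size 4.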
Methods
__init__(name, stack_axis, stack_axis_size)

Attributes
name
stack_axis
stack_axis_size