LayerStackGetAttrKey


class penzai.nn.layer_stack.LayerStackGetAttrKey

Bases: GetAttrKey

GetAttrKey for LayerStack with extra metadata.

This allows us to identify, from a leaf's key path, whether a given PyTree leaf is contained inside a LayerStack and, if so, which axis it is stacked along. That information can in turn be used to manipulate variables inside a LayerStack in a stack-compatible way.
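As a rough illustration of how this extra metadata might be consumed, the sketch below uses hypothetical pure-Python stand-ins for `GetAttrKey` and `LayerStackGetAttrKey` (plain frozen dataclasses, not the real JAX or penzai classes) that carry the three documented attributes. The hypothetical helper `stacked_axes` scans a key path and reports the `(stack_axis, stack_axis_size)` pair of every enclosing LayerStack, outermost first. The attribute names in the example path and the axis values are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence, Tuple


@dataclass(frozen=True)
class GetAttrKey:
    """Hypothetical stand-in for a plain attribute-access entry in a key path."""
    name: str


@dataclass(frozen=True)
class LayerStackGetAttrKey(GetAttrKey):
    """Hypothetical stand-in: a GetAttrKey that also records stacking metadata.

    Field order mirrors the documented __init__(name, stack_axis, stack_axis_size).
    """
    stack_axis: Any
    stack_axis_size: int


def stacked_axes(key_path: Sequence[Any]) -> List[Tuple[Any, int]]:
    """Return (stack_axis, stack_axis_size) for each LayerStack on the path.

    Entries are ordered from the outermost LayerStack to the innermost one.
    An empty list means the leaf is not inside any LayerStack.
    """
    return [
        (key.stack_axis, key.stack_axis_size)
        for key in key_path
        if isinstance(key, LayerStackGetAttrKey)
    ]


# Illustrative key paths (attribute names are made up for this sketch).
stacked_leaf = (
    GetAttrKey("body"),
    LayerStackGetAttrKey("stacked", "layer", 12),
    GetAttrKey("weight"),
)
plain_leaf = (GetAttrKey("embed"),)

print(stacked_axes(stacked_leaf))  # [('layer', 12)]
print(stacked_axes(plain_leaf))    # []
```

Because the stand-in subclasses `GetAttrKey`, code that only understands ordinary attribute keys can still read `.name` from it, while stack-aware code can additionally check for the subclass and use `stack_axis` and `stack_axis_size`.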

Methods

__init__(name, stack_axis, stack_axis_size)

Attributes

name

stack_axis

stack_axis_size