layerstack_axes_from_keypath
- penzai.nn.layer_stack.layerstack_axes_from_keypath(keypath: tuple[Any, ...]) -> dict[named_axes.AxisName, int]
Extracts the stacked axes from a keypath.
This is useful when initializing new variables in transformations that modify layers inside a LayerStack. In general, if this function returns a non-empty dict for a given keypath, any new variable added at that location should include these axis names in its “layerstack_axes” metadata. If the variable is PER_LAYER, its value should also carry these named axes, with the sizes given in the dict.
- Parameters:
keypath – A JAX keypath to a subtree of a PyTree.
- Returns:
A mapping containing the names and sizes of any axes mapped over by LayerStack layers.
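The following is a minimal pure-Python sketch of the idea. It does not use penzai or JAX. It assumes, hypothetically, that stepping into each LayerStack adds a key-path entry that records that stack's axis names and sizes. The names `StackedAttrKey`, `PlainAttrKey`, `stacked_axes_from_keypath`, and `per_layer_named_shape` are illustrative stand-ins, not penzai APIs. The sketch merges the axes of nested stacks along a path and derives the named shape a new PER_LAYER variable at that path would need.

```python
import dataclasses
from typing import Any, Hashable


@dataclasses.dataclass(frozen=True)
class PlainAttrKey:
  """Hypothetical ordinary key-path entry (plain attribute access)."""
  name: str


@dataclasses.dataclass(frozen=True)
class StackedAttrKey:
  """Hypothetical key-path entry for stepping into a LayerStack.

  Records (axis name, size) pairs for the axes that this stack maps over.
  """
  name: str
  stack_axes: tuple[tuple[Hashable, int], ...]


def stacked_axes_from_keypath(keypath: tuple[Any, ...]) -> dict[Hashable, int]:
  """Merges the stacked axes of every LayerStack entry along `keypath`."""
  axes: dict[Hashable, int] = {}
  for key in keypath:
    if isinstance(key, StackedAttrKey):
      axes.update(key.stack_axes)  # Outer stacks first, then inner ones.
  return axes


def per_layer_named_shape(
    keypath: tuple[Any, ...], own_axes: dict[Hashable, int]
) -> dict[Hashable, int]:
  """Hypothetical helper: named shape for a new PER_LAYER variable.

  The stacked axes come first, followed by the variable's own axes.
  """
  stacked = stacked_axes_from_keypath(keypath)
  overlap = stacked.keys() & own_axes.keys()
  if overlap:
    raise ValueError(f"Axis names already used by a LayerStack: {overlap}")
  return {**stacked, **own_axes}


# Two nested stacks: an outer one over "blocks", an inner one over "sub_blocks".
path = (
    StackedAttrKey("sublayers", (("blocks", 12),)),
    PlainAttrKey("mlp"),
    StackedAttrKey("sublayers", (("sub_blocks", 2),)),
    PlainAttrKey("weights"),
)
print(stacked_axes_from_keypath(path))  # {'blocks': 12, 'sub_blocks': 2}
print(per_layer_named_shape(path, {"embedding": 64}))
# {'blocks': 12, 'sub_blocks': 2, 'embedding': 64}
```

In real code the keypaths would come from JAX (for example, `jax.tree_util.tree_flatten_with_path`), and you would call `layerstack_axes_from_keypath` itself rather than a reimplementation like this one. The exact format penzai expects for the “layerstack_axes” metadata is not modeled here.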