layerstack_axes_from_keypath


penzai.nn.layer_stack.layerstack_axes_from_keypath(keypath: tuple[Any, ...]) → dict[named_axes.AxisName, int]

Extracts the stacked axes from a keypath.

This can be used when initializing new variables in transformations that modify layers inside a LayerStack. If this function returns a non-empty dict for a given keypath, then any new variable added at that location should generally include these axis names in its “layerstack_axes” metadata and, if the variable is PER_LAYER, should also include these axis names and sizes in its value.
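As a rough illustration of the rule above, the sketch below shows how a transformation might use the returned dict when creating a new variable. It does not call penzai. The helper `new_variable_spec`, the string behaviors `"shared"` and `"per_layer"`, and the exact layout of the `"layerstack_axes"` metadata are hypothetical stand-ins, so consult penzai's own variable and LayerStack documentation for the real types.

```python
from typing import Any


def new_variable_spec(
    stacked_axes: dict[str, int],
    per_layer: bool,
) -> tuple[dict[str, int], dict[str, Any]]:
    """Hypothetical helper: returns (named axes of the value, metadata).

    `stacked_axes` plays the role of the dict returned by
    layerstack_axes_from_keypath for the location of the new variable.
    """
    # Stand-in for penzai's PER_LAYER / SHARED behaviors (assumption).
    behavior = "per_layer" if per_layer else "shared"
    # When the variable sits inside a LayerStack, record every stacked axis
    # name in its metadata, regardless of behavior.
    metadata: dict[str, Any] = {}
    if stacked_axes:
        metadata["layerstack_axes"] = {name: behavior for name in stacked_axes}
    # Only a per-layer variable carries the stacked axes (names and sizes)
    # in its value; a shared variable has one value for all layers.
    named_shape = dict(stacked_axes) if per_layer else {}
    return named_shape, metadata


print(new_variable_spec({"layer": 4}, per_layer=True))
print(new_variable_spec({"layer": 4}, per_layer=False))
```

With a single stacked axis `"layer"` of size 4, the per-layer variable gets a `"layer"` axis of size 4 in its value, while the shared variable only records the axis name in its metadata.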

Parameters:

keypath – A JAX keypath to a subtree of a PyTree.

Returns:

A mapping from the name of each axis mapped over by a LayerStack layer along the keypath to that axis's size.
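Conceptually, the function scans the keypath for entries contributed by LayerStack layers and collects the axis that each one maps over. The pure-Python sketch below models that idea only. `StackedLayerKey` and its `axis_name` and `axis_size` fields are hypothetical stand-ins for whatever key type penzai actually records, `AttrKey` stands in for ordinary keypath entries such as `jax.tree_util.GetAttrKey`, and `sketch_layerstack_axes` is not the library function.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class StackedLayerKey:
    """Hypothetical keypath entry marking a step into a LayerStack's sublayers."""
    axis_name: str
    axis_size: int


@dataclass(frozen=True)
class AttrKey:
    """Stand-in for an ordinary attribute step in a keypath."""
    name: str


def sketch_layerstack_axes(keypath: tuple[Any, ...]) -> dict[str, int]:
    # Keep only the entries added by stacked layers; ordinary steps are ignored.
    return {
        key.axis_name: key.axis_size
        for key in keypath
        if isinstance(key, StackedLayerKey)
    }


# A keypath that passes through two nested (hypothetical) stacked layers.
path = (
    AttrKey("body"),
    StackedLayerKey("layer", 12),
    AttrKey("sublayers"),
    StackedLayerKey("inner", 2),
    AttrKey("weights"),
)
print(sketch_layerstack_axes(path))                   # {'layer': 12, 'inner': 2}
print(sketch_layerstack_axes((AttrKey("weights"),)))  # {}
```

A keypath that never enters a LayerStack yields an empty dict, which matches the documented case where no extra axis names or sizes need to be added to a new variable.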