layer_stack
Layer stacks.
Classes
A sequence of layers with identical structure, called under jax.lax.scan.
GetAttrKey for LayerStack with extra metadata.
Behavior of a variable in a layer stack.
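The first class listed above runs a sequence of identically structured layers under jax.lax.scan. The following is a minimal sketch of that general technique in plain JAX, not this module's API. Each parameter leaf carries a leading "layer" axis of size depth, and scan peels off one layer's parameters per step while threading the activation through as the carry. The dense tanh layer and the helper names (init_layer, apply_layer, init_stack, apply_stack, apply_loop) are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp


def init_layer(key, width):
    # Hypothetical layer: one dense layer with a weight matrix and a bias.
    return {
        "w": jax.random.normal(key, (width, width)) / jnp.sqrt(width),
        "b": jnp.zeros((width,)),
    }


def apply_layer(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])


def init_stack(key, depth, width):
    # vmap over `depth` keys gives every leaf a leading stacked axis of size depth
    # (the bias does not depend on the key, so vmap broadcasts it along that axis).
    keys = jax.random.split(key, depth)
    return jax.vmap(lambda k: init_layer(k, width))(keys)


def apply_stack(stacked_params, x):
    # scan slices one layer's params off the leading axis per step;
    # the activation is the carry, and no per-step outputs are collected.
    def step(carry, layer_params):
        return apply_layer(layer_params, carry), None

    out, _ = jax.lax.scan(step, x, stacked_params)
    return out


def apply_loop(stacked_params, x):
    # Reference: the same computation as an explicit, unrolled Python loop.
    depth = stacked_params["w"].shape[0]
    for i in range(depth):
        layer = jax.tree_util.tree_map(lambda leaf: leaf[i], stacked_params)
        x = apply_layer(layer, x)
    return x


params = init_stack(jax.random.PRNGKey(0), depth=4, width=8)
x = jnp.ones((2, 8))
print(params["w"].shape)         # (4, 8, 8)
print(apply_stack(params, x).shape)  # (2, 8)
```

The reason to stack identical layers this way is that scan traces and compiles the layer body once instead of unrolling depth copies of it. The loop version is included only to show that both compute the same result.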
Functions
Extracts the stacked axes from a keypath.