Local layer state effect.

This effect allows layers to maintain and update their own state variables. The local state effect is more complex than most other effects, because different models use state in different ways.

Penzai’s local state effect is designed to support two types of state:

  • explicitly-named state variables, which act like parameters but are updated and checkpointed separately from them, and are not updated by gradient descent. Batch norm statistics are one example.

  • unnamed state variables, used for things like sampling state or decoding state, which are updated locally in some context but are rarely checkpointed or serialized.

Both types of state use the same basic mechanism. The difference is that explicitly-named state variables will be renamed by parameter-renaming transformations (but not be affected by being moved around in the model PyTree), whereas unnamed state variables will have an inferred name based on their position in the PyTree at the moment that the state was handled (but will not be affected by parameter naming).
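The naming rule above can be illustrated with a small toy sketch in plain Python. It does not call Penzai: `ToyStateRequest` and `collect_state_names` are hypothetical names invented for this illustration, and nested dicts stand in for the model PyTree. The sketch only shows the idea that a named state keeps its key when the tree is rearranged, while an unnamed state gets a key derived from its position.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyStateRequest:
    """Hypothetical stand-in for a state request; not a Penzai class."""
    initial: Any
    name: Optional[str] = None  # None means the state is unnamed


def collect_state_names(tree, path=()):
    """Walks nested dicts and assigns a key to every state request found."""
    names = {}
    if isinstance(tree, ToyStateRequest):
        # Named requests keep their explicit name; unnamed requests are
        # keyed by their position in the tree at the time of the walk.
        key = tree.name if tree.name is not None else "/".join(path)
        names[key] = tree.initial
    elif isinstance(tree, dict):
        for child_key, child in tree.items():
            names.update(collect_state_names(child, path + (child_key,)))
    return names


model = {
    "norm": {"stats": ToyStateRequest(0.0, name="bn_mean")},
    "decoder": {"cache": ToyStateRequest([])},
}
# Moving the model under a new parent changes only the inferred key.
moved = {"outer": model}

names_before = collect_state_names(model)
names_after = collect_state_names(moved)
```

In this sketch the explicitly named entry is keyed `bn_mean` in both walks, whereas the unnamed entry is keyed by its path and so changes when the model is nested under `outer`.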

In general, stateful models should introduce state by adding an attribute with type LocalStateEffect and an initial value that is an instance of InitialLocalStateRequest. A stateful model can then be converted into a handled, functionalized form using handle_local_states, which also produces an initial state dictionary. The state dictionary can later be combined with the model again using freeze_local_states, which embeds the current state variables as FrozenLocalStateRequest instances; the model can then be checkpointed if desired.
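The handle-then-freeze round trip described above can be sketched with a toy model in plain Python. This is an illustration of the workflow only, not Penzai's implementation: every `Toy*` class and `toy_*` function below is a hypothetical name, nested dicts stand in for the model PyTree, and the real handle_local_states and freeze_local_states may take different arguments and return different structures.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyInitialRequest:
    """Hypothetical analogue of a request carrying a state initializer."""
    initializer: Callable[[], Any]


@dataclass
class ToyHandledRef:
    """Hypothetical marker left in the tree once a state is handled."""
    key: str


@dataclass
class ToyFrozenRequest:
    """Hypothetical analogue of a request carrying a frozen value."""
    value: Any


def toy_handle(tree, path=()):
    """Replaces requests with markers; returns (handled_tree, states)."""
    if isinstance(tree, ToyInitialRequest):
        key = "/".join(path)
        return ToyHandledRef(key), {key: tree.initializer()}
    if isinstance(tree, ToyFrozenRequest):
        key = "/".join(path)
        return ToyHandledRef(key), {key: tree.value}
    if isinstance(tree, dict):
        handled, states = {}, {}
        for child_key, child in tree.items():
            handled[child_key], sub = toy_handle(child, path + (child_key,))
            states.update(sub)
        return handled, states
    return tree, {}  # Leaves that carry no state pass through unchanged.


def toy_freeze(handled, states):
    """Embeds state values back into the tree as frozen requests."""
    if isinstance(handled, ToyHandledRef):
        return ToyFrozenRequest(states[handled.key])
    if isinstance(handled, dict):
        return {k: toy_freeze(v, states) for k, v in handled.items()}
    return handled


model = {"layer": {"counter": ToyInitialRequest(lambda: 0), "scale": 2.0}}
handled, states = toy_handle(model)                  # functional form + state dict
new_states = {k: v + 1 for k, v in states.items()}   # stand-in for a model step
frozen = toy_freeze(handled, new_states)             # ready to checkpoint
```

Because the frozen tree again contains request objects, it can be passed through `toy_handle` a second time to resume from the saved values, which mirrors the checkpoint-and-restore use described above.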



Effect request for local state with a frozen value.


Marker for a handled local state effect.


Effect request for local state, with a state initializer.


Protocol for a local state effect.


Implementation of the local state effect.


Effect request for local state that is shared.


LocalState effect handler that functionalizes local states.


embed_shared_state_requests(tree, state_requests)

Embeds shared state requests into a tree.

freeze_local_states(handled, states)

Embeds the given states into a handled model, and removes the handler.


Extracts local states from a stateful model.

hoist_shared_state_requests(tree[, unsafe])

Hoists out the value of shared states in a pytree.