InitialLocalStateRequest

class penzai.data_effects.local_state.InitialLocalStateRequest

Bases: Generic[_T], EffectRequest, SupportsParameterRenaming

Effect request for local state, with a state initializer.

This can be used to configure the initial state when initializing a model or transforming a model into a stateful configuration.

Typically, if this state should be updated and checkpointed during training, each InitialLocalStateRequest should be created when the model is built and given a name, just as a parameter would be. If the state is only used temporarily (e.g. decoding state while sampling, or state for a per-example rollout), it does not need a name.
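To make the two usage patterns concrete, here is a minimal sketch. It uses a hypothetical stand-in dataclass, `StateRequestSketch`, that only mirrors the three documented fields; it is not penzai's class. The plain-string category and the example names are assumptions for illustration.

```python
import dataclasses
from typing import Any, Callable, Optional


# Hypothetical stand-in mirroring the documented fields of
# InitialLocalStateRequest; this is NOT penzai's implementation.
@dataclasses.dataclass(frozen=True)
class StateRequestSketch:
    state_initializer: Callable[[], Any]
    category: str  # assumption: a plain string stands in for the Category tag
    name: Optional[str] = None


# Persistent state (updated and checkpointed during training): created when
# the model is built and given a parameter-like name.
step_counter = StateRequestSketch(
    state_initializer=lambda: 0,
    category="training_state",
    name="encoder/step_counter",
)

# Temporary state (e.g. a decoding cache while sampling): no name is given;
# penzai would infer one from the PyTree structure when the state is used.
decode_cache = StateRequestSketch(
    state_initializer=lambda: [],
    category="decoding_state",
)
```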

Variables:
  • state_initializer (Callable[[], _T]) – Callable that builds the initial state.

  • category (Category) – Category tag identifying the kind of state.

  • name (str | None) – Optional name for this state. If provided, the name is renamed as if this state were a parameter and is used as the key for this state in the state dictionary. If not provided, a name is inferred from the PyTree structure at the time the state is used. States with the same explicit name share the same value.
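The naming rules above can be sketched in plain Python as follows. `StateRequestSketch` and `build_state_dict` are hypothetical; in particular, the sketch keys unnamed state by a caller-supplied path string, which is a simplification of penzai inferring a name from the PyTree structure.

```python
import dataclasses
from typing import Any, Callable, Dict, Optional


# Hypothetical stand-in for InitialLocalStateRequest (not penzai's class).
@dataclasses.dataclass(frozen=True)
class StateRequestSketch:
    state_initializer: Callable[[], Any]
    category: str
    name: Optional[str] = None

    def with_renamed_parameters(self, rename_fn: Callable[[str], str]):
        # Only explicitly named state is renamed, as a parameter would be;
        # unnamed state is returned unchanged.
        if self.name is None:
            return self
        return dataclasses.replace(self, name=rename_fn(self.name))


def build_state_dict(requests_by_path: Dict[str, StateRequestSketch]) -> Dict[str, Any]:
    """Hypothetical helper that resolves requests into a state dictionary."""
    state: Dict[str, Any] = {}
    for path, request in requests_by_path.items():
        # Explicit names become the key; unnamed state falls back to its path.
        key = request.name if request.name is not None else path
        # Requests with the same explicit name share a single entry, so the
        # initializer runs only for the first one encountered.
        if key not in state:
            state[key] = request.state_initializer()
    return state


state = build_state_dict({
    "layer_0/state": StateRequestSketch(lambda: 0, "counter", name="shared"),
    "layer_1/state": StateRequestSketch(lambda: 0, "counter", name="shared"),
    "layer_2/state": StateRequestSketch(lambda: 1.0, "cache"),
})
print(state)  # {'shared': 0, 'layer_2/state': 1.0}

renamed = StateRequestSketch(lambda: 0, "counter", name="shared").with_renamed_parameters(
    lambda n: "decoder/" + n
)
print(renamed.name)  # decoder/shared
```

Two requests named "shared" collapse into one entry, while the unnamed request keeps a separate, path-derived key.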

Methods

__init__(state_initializer, category[, name])

as_frozen()

Initializes the state, returning an equivalent frozen request.

effect_protocol()

with_renamed_parameters(rename_fn)

Attributes

name

state_initializer

category

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

as_frozen() → FrozenLocalStateRequest[_T]

Initializes the state, returning an equivalent frozen request.
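A sketch of what `as_frozen` is described as doing, assuming it calls `state_initializer` once and carries `category` and `name` over unchanged. Both classes below are hypothetical stand-ins for `InitialLocalStateRequest` and `FrozenLocalStateRequest`, not the library's implementation.

```python
import dataclasses
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


# Hypothetical stand-in for FrozenLocalStateRequest: holds an already
# computed state value instead of an initializer.
@dataclasses.dataclass(frozen=True)
class FrozenStateSketch(Generic[T]):
    state: T
    category: str
    name: Optional[str] = None


# Hypothetical stand-in for InitialLocalStateRequest.
@dataclasses.dataclass(frozen=True)
class StateRequestSketch(Generic[T]):
    state_initializer: Callable[[], T]
    category: str
    name: Optional[str] = None

    def as_frozen(self) -> FrozenStateSketch[T]:
        # Run the initializer now; category and name are copied unchanged.
        return FrozenStateSketch(
            state=self.state_initializer(),
            category=self.category,
            name=self.name,
        )


frozen = StateRequestSketch(lambda: [0.0, 0.0], "kv_cache", name="cache").as_frozen()
print(frozen.state)  # [0.0, 0.0]
```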