WithFunctionalLocalState


class penzai.data_effects.local_state.WithFunctionalLocalState

Bases: EffectHandler

LocalState effect handler that functionalizes local states.

WithFunctionalLocalState transforms the body layer so that it takes a dictionary of states as an argument and returns a dictionary of states as a result.

The standard way to construct a WithFunctionalLocalState handler is to call handle_local_states, which returns both the functional wrapper and the initial state callable. Conversely, freeze_local_states re-embeds the local states into the model.
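The calling convention described above, where state goes in and comes out as a dictionary, can be illustrated without penzai itself. The following is a minimal, hypothetical sketch in plain Python. The names functional_body and initial_states are invented for illustration, and packing the layer input together with the state dictionary as an (input, states) tuple is an assumption of this sketch. The point it shows is that the layer reads its state from the incoming dictionary and returns an updated dictionary instead of mutating anything in place.

```python
from typing import Any

# Hypothetical type alias: local states keyed by name.
StateDict = dict[str, Any]


def initial_states() -> StateDict:
    """Hypothetical initial-state callable: builds a fresh state dictionary."""
    return {"running_total": 0.0}


def functional_body(argument: tuple[float, StateDict]) -> tuple[float, StateDict]:
    """Hypothetical functionalized layer that keeps a running total.

    Assumption: the input and the states arrive together as a tuple,
    and the output and the updated states are returned as a tuple.
    """
    x, states = argument
    # Read the current state from the incoming dictionary.
    total = states["running_total"] + x
    # Build a new dictionary rather than mutating the caller's one.
    new_states = {**states, "running_total": total}
    return total, new_states


if __name__ == "__main__":
    # The caller threads the states explicitly from one call to the next.
    states = initial_states()
    outputs = []
    for x in [1.0, 2.0, 3.0]:
        y, states = functional_body((x, states))
        outputs.append(y)
    print(outputs)  # [1.0, 3.0, 6.0]
    print(states)   # {'running_total': 6.0}
```

Because every call returns a fresh state dictionary and leaves its input untouched, the caller fully controls how state flows between calls, which is what makes functionalized layers easier to combine with pure-function transformations.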

Methods

__init__(handler_id, body)

effect_protocol()

input_structure()

output_structure()

__call__(argument)

Attributes

handler_id

body

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()