SharedLocalStateRequest

class penzai.data_effects.local_state.SharedLocalStateRequest

Bases: Generic[_T], EffectRequest

Effect request for local state that is shared.

A SharedLocalStateRequest can be used to share an explicitly named state variable between multiple layers. Each shared state should have exactly one version that carries a value (a FrozenLocalStateRequest or an InitialLocalStateRequest), and may have any number of additional versions that are SharedLocalStateRequest instances. Every SharedLocalStateRequest must appear after the valued version in PyTree flattening order.
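The sharing rule above can be checked mechanically over a flattened list of requests. The sketch below is a minimal pure-Python model, not penzai code: `ValuedState`, `SharedState`, and `find_sharing_errors` are hypothetical names, and each request is reduced to just its state name, assuming the list is already in PyTree flattening order.

```python
from dataclasses import dataclass


# Hypothetical stand-ins, not penzai classes: they keep only the state name,
# which is all the ordering rule depends on.
@dataclass(frozen=True)
class ValuedState:
    """Models a FrozenLocalStateRequest or InitialLocalStateRequest."""
    name: str


@dataclass(frozen=True)
class SharedState:
    """Models a SharedLocalStateRequest: a reference that carries no value."""
    name: str


def find_sharing_errors(flat_requests):
    """Returns rule violations for requests given in PyTree flattening order.

    An empty list means each name has at most one valued version and every
    shared reference comes after the valued version with the same name.
    """
    errors = []
    valued_names = set()
    for req in flat_requests:
        if isinstance(req, ValuedState):
            if req.name in valued_names:
                errors.append(f"{req.name!r} has more than one valued version")
            valued_names.add(req.name)
        elif isinstance(req, SharedState) and req.name not in valued_names:
            errors.append(f"{req.name!r} is shared before its valued version")
    return errors


ok = [ValuedState("cache"), SharedState("cache"), SharedState("cache")]
bad = [SharedState("cache"), ValuedState("cache")]
print(find_sharing_errors(ok))   # []
print(find_sharing_errors(bad))  # ["'cache' is shared before its valued version"]
```

A single left-to-right pass suffices because the rule is stated in terms of flattening order: by the time a shared reference is visited, its valued version must already have been seen.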

Variables:
  • name (str) – Name for this state. It is renamed as if this state were a parameter, and it is used as the key for this state in the state dictionary.

  • category (Category) – Category tag identifying the kind of state.

Methods

__init__(name, category)

effect_protocol()

with_renamed_parameters(rename_fn)
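Because the name is renamed as if the state were a parameter, with_renamed_parameters can be pictured as returning a copy with `rename_fn` applied to `name`. The sketch below assumes that behavior; `ToySharedStateRef` and `add_prefix` are hypothetical, and `category` is modeled as a plain string rather than penzai's Category tag.

```python
import dataclasses
from typing import Callable


@dataclasses.dataclass(frozen=True)
class ToySharedStateRef:
    """Hypothetical stand-in with the two documented fields of SharedLocalStateRequest."""
    name: str
    category: str  # a plain string here; the real field is a Category tag

    def with_renamed_parameters(self, rename_fn: Callable[[str], str]):
        # Assumed behavior: return a copy whose name has been passed through
        # rename_fn, the way a parameter name would be; category is unchanged.
        return dataclasses.replace(self, name=rename_fn(self.name))


def add_prefix(name: str) -> str:
    # Hypothetical rename function that scopes a name under a block prefix.
    return "block_0." + name


original = ToySharedStateRef(name="kv_cache", category="cache")
renamed = original.with_renamed_parameters(add_prefix)
print(renamed.name)      # block_0.kv_cache
print(renamed.category)  # cache
```

In this model, if the valued version and its shared references were renamed by different functions, their names would stop matching; applying one `rename_fn` to all of them keeps them pointing at the same state-dictionary key.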

Attributes

name

category

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()
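The descriptions of attributes_dict and from_attributes suggest a round trip: one collects every field into a dictionary, the other builds an instance from a complete set of fields. The sketch below mimics that pairing with a plain frozen dataclass; `ToyStruct` is hypothetical and does not reproduce penzai's Struct base class or its PyTree registration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyStruct:
    """Hypothetical dataclass mimicking two inherited Struct helpers."""
    name: str
    category: str

    def attributes_dict(self) -> dict:
        # One entry per dataclass field, keyed by field name.
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

    @classmethod
    def from_attributes(cls, **field_values):
        # Instantiate directly from a complete set of field values.
        return cls(**field_values)


req = ToyStruct(name="kv_cache", category="cache")
fields = req.attributes_dict()
print(fields)  # {'name': 'kv_cache', 'category': 'cache'}
print(ToyStruct.from_attributes(**fields) == req)  # True
```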