hoist_shared_state_requests
- penzai.deprecated.v1.data_effects.local_state.hoist_shared_state_requests(tree: Any, unsafe: bool = False) → tuple[Any, dict[str, InitialLocalStateRequest | FrozenLocalStateRequest]]
Hoists out the value of shared states in a pytree.
This function is a helper for manipulating Penzai models that contain shared states. Ordinarily, shared states in a Penzai model are represented as some combination of:
- exactly one of:
  - a single FrozenLocalStateRequest with a value
  - multiple InitialLocalStateRequest nodes with identical initializers
- followed by one or more SharedLocalStateRequest nodes
where all such requests have the same explicit name. This is convenient for manipulating the model as a whole, but can make it somewhat annoying to extract a small part of a model that uses a shared state defined elsewhere.
This function takes a tree of this form and returns a new tree that only contains a SharedLocalStateRequest whenever there is a state that is used in multiple places, along with a dictionary mapping each state name to the concrete definition it uses. The new tree can be freely manipulated, and then a single copy of the state can be re-embedded using embed_shared_state_requests.
- Parameters:
  tree – A tree where each shared variable appears either once as a FrozenLocalStateRequest or multiple times as identical InitialLocalStateRequest nodes, followed by some number of SharedLocalStateRequest nodes.
  unsafe – If True, tree can have multiple FrozenLocalStateRequest or InitialLocalStateRequest nodes with different initializers, and one will be picked arbitrarily.
- Returns:
  A tuple of (new_tree, state_defs), where new_tree is a copy of tree in which all shared states have been replaced by a SharedLocalStateRequest, and state_defs is a dictionary mapping each shared state name to the corresponding InitialLocalStateRequest or FrozenLocalStateRequest. These can be passed to embed_shared_state_requests to rebuild the original tree.
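The hoist-then-embed round trip described above can be illustrated with a small, self-contained sketch. It does not use Penzai: FrozenRequest, SharedRequest, hoist and embed are hypothetical stand-ins written for this example, operating on nested dicts and lists rather than real Penzai models. Two simplifications are assumed: every definition node is hoisted regardless of how often its name is used, and conflicting definitions raise a ValueError unless unsafe is True (the exact error behavior of the real function is not stated here).

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins for the Penzai request nodes (not the real classes).
@dataclass(frozen=True)
class FrozenRequest:
    name: str
    value: Any


@dataclass(frozen=True)
class SharedRequest:
    name: str


def hoist(tree, unsafe=False):
    """Toy analogue of hoisting: replace definitions by shared references."""
    state_defs = {}

    def go(node):
        if isinstance(node, FrozenRequest):
            prev = state_defs.get(node.name)
            # Assumption: differing definitions are an error unless unsafe=True.
            if prev is not None and prev != node and not unsafe:
                raise ValueError(f"conflicting definitions for {node.name!r}")
            state_defs.setdefault(node.name, node)  # keep the first one seen
            return SharedRequest(node.name)
        if isinstance(node, dict):
            return {k: go(v) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(go(v) for v in node)
        return node

    return go(tree), state_defs


def embed(tree, state_defs):
    """Toy inverse: re-insert each definition at the first use of its name."""
    seen = set()

    def go(node):
        if (isinstance(node, SharedRequest)
                and node.name in state_defs
                and node.name not in seen):
            seen.add(node.name)
            return state_defs[node.name]
        if isinstance(node, dict):
            return {k: go(v) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(go(v) for v in node)
        return node

    return go(tree)


# A "model" whose decoder reuses a state defined in the encoder.
model = {
    "encoder": {"table": FrozenRequest("embed", [1.0, 2.0])},
    "decoder": {"table": SharedRequest("embed")},
}

hoisted, defs = hoist(model)
decoder_only = hoisted["decoder"]          # can now be extracted on its own
standalone = embed(decoder_only, defs)     # carries its own copy of the state
rebuilt = embed(hoisted, defs)             # recovers the original layout
```

In this sketch the decoder can be pulled out of the hoisted tree without losing track of the state it depends on, because the definition lives in defs rather than in a sibling subtree. That is the convenience the real function is described as providing.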