hoist_shared_state_requests


penzai.data_effects.local_state.hoist_shared_state_requests(tree: Any, unsafe: bool = False) -> tuple[Any, dict[str, InitialLocalStateRequest | FrozenLocalStateRequest]]

Hoists out the value of shared states in a pytree.

This function is a helper for manipulating Penzai models that contain shared states. Ordinarily, shared states in a Penzai model are represented as some combination of:

where all such requests have the same explicit name. This is convenient for manipulating the model as a whole, but can make it somewhat annoying to extract a small part of a model that uses a shared state defined elsewhere.

This function takes a tree of this form and returns two things: a new tree in which every state used in multiple places is represented only by a SharedLocalStateRequest, and a dictionary mapping each such state name to the concrete definition it uses. The new tree can be freely manipulated, and a single copy of each state can then be re-embedded using embed_shared_state_requests.
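To make the hoisting idea concrete, here is a minimal sketch in plain Python. It does not use Penzai: InitialRequest, FrozenRequest, SharedRequest and hoist_shared are hypothetical stand-ins, trees are nested lists, tuples and dicts rather than JAX pytrees, and a name counts as "shared" when it appears in more than one request. Penzai's real implementation may apply different rules and checks.

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins for Penzai's request types (the real classes differ).
@dataclass(frozen=True)
class InitialRequest:
    name: str
    init_value: Any


@dataclass(frozen=True)
class FrozenRequest:
    name: str
    value: Any


@dataclass(frozen=True)
class SharedRequest:
    name: str


def _leaves(tree):
    """Yields every leaf of a nested list/tuple/dict tree, depth first."""
    if isinstance(tree, (list, tuple)):
        for child in tree:
            yield from _leaves(child)
    elif isinstance(tree, dict):
        for child in tree.values():
            yield from _leaves(child)
    else:
        yield tree


def _map_leaves(fn, tree):
    """Rebuilds the tree with fn applied to every leaf."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(_map_leaves(fn, child) for child in tree)
    if isinstance(tree, dict):
        return {k: _map_leaves(fn, v) for k, v in tree.items()}
    return fn(tree)


def hoist_shared(tree: Any) -> tuple[Any, dict[str, Any]]:
    """Sketch of hoisting: returns (new_tree, state_defs)."""
    request_types = (InitialRequest, FrozenRequest, SharedRequest)
    requests = [leaf for leaf in _leaves(tree) if isinstance(leaf, request_types)]

    # A name is treated as shared if more than one request refers to it.
    counts: dict[str, int] = {}
    for req in requests:
        counts[req.name] = counts.get(req.name, 0) + 1
    shared_names = {name for name, n in counts.items() if n > 1}

    # Collect one concrete definition per shared name; duplicates must agree.
    state_defs: dict[str, Any] = {}
    for req in requests:
        if req.name in shared_names and isinstance(req, (InitialRequest, FrozenRequest)):
            if req.name in state_defs and state_defs[req.name] != req:
                raise ValueError(f"Conflicting definitions for state {req.name!r}")
            state_defs[req.name] = req

    # Replace every definition of a shared state with a placeholder.
    def replace(leaf):
        if isinstance(leaf, (InitialRequest, FrozenRequest)) and leaf.name in shared_names:
            return SharedRequest(leaf.name)
        return leaf

    return _map_leaves(replace, tree), state_defs


model = {
    "encoder": [InitialRequest("cache", 0), "layer0"],
    "decoder": [SharedRequest("cache"), "layer1"],
    "head": [InitialRequest("solo", 1), FrozenRequest("w", 3), FrozenRequest("w", 3)],
}
new_model, defs = hoist_shared(model)
```

In this example, "cache" and "w" are hoisted into `defs`, while "solo" is used only once and stays in place.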

Parameters:
Returns:

A tuple of (new_tree, state_defs), where new_tree is a copy of tree where all shared states have been replaced by a SharedLocalStateRequest, and state_defs is a dictionary mapping each shared state name to the corresponding InitialLocalStateRequest or FrozenLocalStateRequest. These can be passed to embed_shared_state_requests to rebuild the original tree.
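The reverse step can be sketched in the same plain-Python style. This is a hypothetical illustration, not Penzai's embed_shared_state_requests. It assumes that re-embedding puts each definition at the first placeholder for its name, in depth-first order, and leaves later occurrences as placeholders. Under that assumption, the original tree is reproduced exactly only when the definition originally sat at that first position. SharedRequest and embed_shared are illustrative names.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class SharedRequest:
    """Hypothetical stand-in for SharedLocalStateRequest."""
    name: str


def embed_shared(tree: Any, state_defs: dict[str, Any]) -> Any:
    """Sketch: puts one copy of each hoisted definition back into the tree."""
    embedded: set[str] = set()

    def rebuild(node):
        if isinstance(node, (list, tuple)):
            return type(node)(rebuild(child) for child in node)
        if isinstance(node, dict):
            return {k: rebuild(v) for k, v in node.items()}
        # Only the first placeholder for each name receives the definition.
        if (isinstance(node, SharedRequest) and node.name in state_defs
                and node.name not in embedded):
            embedded.add(node.name)
            return state_defs[node.name]
        return node

    new_tree = rebuild(tree)
    missing = set(state_defs) - embedded
    if missing:
        raise ValueError(f"No placeholder found for states: {sorted(missing)}")
    return new_tree


hoisted = {
    "encoder": [SharedRequest("cache"), "layer0"],
    "decoder": [SharedRequest("cache"), "layer1"],
}
cache_def = ("initial_value", 0)  # placeholder for a real state definition
rebuilt = embed_shared(hoisted, {"cache": cache_def})
```

Raising on an unused definition is a design choice of this sketch. It guards against silently dropping a state when the edited tree no longer references it.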