unbind_variables
- penzai.core.variables.unbind_variables(tree: Any, predicate: Callable[[AbstractVariable], bool] | None = None, freeze: Literal[False] = False) → tuple[Any, tuple[AbstractVariable, ...]]
- penzai.core.variables.unbind_variables(tree: Any, predicate: Callable[[AbstractVariable], bool] | None = None, *, freeze: Literal[True]) → tuple[Any, tuple[AbstractVariableValue, ...]]
Unbinds variables from a pytree, inserting variable slots in their place.
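The following is a simplified, pure-Python model of the slot mechanism, not penzai's implementation. `ToyVariable`, `ToySlot`, `ToyConflictError`, and `toy_unbind` are hypothetical stand-ins for `AbstractVariable`, a variable slot, `VariableConflictError`, and this function; trees are assumed to be plain nested dicts and lists rather than JAX pytrees, and a variable's label is assumed to be its slot key. The sketch illustrates the `predicate` filter, a variable that appears twice being extracted once, and the two conflict cases listed under Raises.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


class ToyVariable:
    """Hypothetical mutable variable with a label (stands in for AbstractVariable)."""

    def __init__(self, label: str, value: Any):
        self.label = label
        self.value = value


@dataclass(frozen=True)
class ToySlot:
    """Hypothetical placeholder that records the label of an unbound variable."""

    label: str


class ToyConflictError(ValueError):
    """Hypothetical stand-in for VariableConflictError."""


def toy_unbind(tree: Any, predicate: Optional[Callable[[ToyVariable], bool]] = None):
    """Return (copy of tree with selected variables replaced by slots, variables)."""
    found = {}  # label -> the single ToyVariable object seen with that label
    existing_slots = set()  # labels of slots that were already in the input tree

    def visit(node):
        # Rebuild containers instead of mutating them, so the input tree is untouched.
        if isinstance(node, dict):
            return {key: visit(child) for key, child in node.items()}
        if isinstance(node, list):
            return [visit(child) for child in node]
        if isinstance(node, ToySlot):
            existing_slots.add(node.label)
            return node
        if isinstance(node, ToyVariable) and (predicate is None or predicate(node)):
            first_seen = found.setdefault(node.label, node)
            if first_seen is not node:
                # Two distinct objects would map to the same slot.
                raise ToyConflictError(f"distinct variables share label {node.label!r}")
            return ToySlot(node.label)
        return node

    tree_with_slots = visit(tree)
    clashes = existing_slots.intersection(found)
    if clashes:
        # A slot for an extracted label was already present in the input tree.
        raise ToyConflictError(f"tree already contains slots for {sorted(clashes)}")
    return tree_with_slots, tuple(found.values())


# Usage: `w` appears twice but is extracted once; `s` is skipped by the predicate.
w = ToyVariable("w", 1.0)
s = ToyVariable("s", 0)
tree = {"layer1": [w, 3.0], "layer2": w, "state": s}
tree_with_slots, variables = toy_unbind(tree, predicate=lambda v: v.label == "w")
assert tree_with_slots == {"layer1": [ToySlot("w"), 3.0], "layer2": ToySlot("w"), "state": s}
assert variables == (w,)
```

Because the model rebuilds the tree rather than mutating it, `tree` itself still holds `w`; only the returned copy contains slots.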
This function can be used to extract variables from a pytree before a JAX transformation or control-flow primitive. The extracted variables can then either be updated directly, or passed through a JAX transformation by calling `freeze` on them, calling `unfreeze_as_copy` and then `bind_variables` inside the transformation, and finally calling `freeze` again before returning the updated variable values.
- Parameters:
tree – A tree containing variables. Each variable can appear in the tree more than once, but if two distinct variable objects have the same label (or would otherwise map to the same slot), a VariableConflictError is raised.
predicate – A function that returns True for variables that should be extracted. If None, all variables will be extracted.
freeze – Whether to return frozen variables (`AbstractVariableValue`) instead of mutable variables (`AbstractVariable`).
- Returns:
A tuple `(tree_with_slots, variables)`, where `tree_with_slots` is a copy of the original tree with the extracted variables replaced by their corresponding slots, and `variables` is a tuple of the extracted variables. If `freeze` is True, the returned variables are frozen.
- Raises:
VariableConflictError – If two variables map to the same slot but are different Python objects, or if there is already a conflicting slot in the tree.
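The freeze/unfreeze pattern from the description can be sketched with hypothetical pure-Python stand-ins (`ToyVariable`, `ToyFrozenValue`, `ToySlot`, `toy_bind`, and `pure_step` are illustrative names, not penzai APIs). `pure_step` plays the role of the function one would wrap in a JAX transformation: it receives frozen values, unfreezes copies, binds them into the slotted tree, mutates the copies, and freezes them again. With penzai itself, the slotted tree and frozen values would come from calling `unbind_variables` with `freeze=True`; here the slotted tree is written by hand.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToySlot:
    """Hypothetical placeholder left where a variable was unbound."""

    label: str


@dataclass(frozen=True)
class ToyFrozenValue:
    """Hypothetical immutable snapshot (stands in for AbstractVariableValue)."""

    label: str
    value: Any

    def unfreeze_as_copy(self) -> "ToyVariable":
        # A fresh mutable variable; the snapshot itself is never modified.
        return ToyVariable(self.label, self.value)


class ToyVariable:
    """Hypothetical mutable variable (stands in for AbstractVariable)."""

    def __init__(self, label: str, value: Any):
        self.label = label
        self.value = value

    def freeze(self) -> ToyFrozenValue:
        return ToyFrozenValue(self.label, self.value)


def toy_bind(tree: Any, variables) -> Any:
    """Replace each ToySlot in a nested dict/list tree with the variable of that label."""
    by_label = {v.label: v for v in variables}

    def visit(node):
        if isinstance(node, dict):
            return {key: visit(child) for key, child in node.items()}
        if isinstance(node, list):
            return [visit(child) for child in node]
        if isinstance(node, ToySlot):
            return by_label[node.label]
        return node

    return visit(tree)


def pure_step(tree_with_slots: Any, frozen_values: tuple) -> tuple:
    """Frozen values in, frozen values out; the caller sees no mutation."""
    variables = [fv.unfreeze_as_copy() for fv in frozen_values]
    model = toy_bind(tree_with_slots, variables)
    model["counter"].value += 1  # mutates only the local copy
    return tuple(v.freeze() for v in variables)


# Outside the transformation: freeze, call the pure function, then write back.
counter = ToyVariable("count", 0)
tree_with_slots = {"counter": ToySlot("count")}
new_values = pure_step(tree_with_slots, (counter.freeze(),))
assert counter.value == 0  # the original variable is untouched so far
assert new_values == (ToyFrozenValue("count", 1),)
counter.value = new_values[0].value  # update the original variable explicitly
```

Keeping the transformed function free of visible side effects is the point of this shape, since JAX transformations expect pure functions; the caller decides when to copy the new values back into the original variables.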