unbind_variables
- penzai.core.variables.unbind_variables(tree: Any, predicate: Callable[[AbstractVariable], bool] | None = None, freeze: Literal[False] = False) → tuple[Any, tuple[AbstractVariable, ...]]
- penzai.core.variables.unbind_variables(tree: Any, predicate: Callable[[AbstractVariable], bool] | None = None, *, freeze: Literal[True]) → tuple[Any, tuple[AbstractVariableValue, ...]]
Unbinds variables from a pytree, inserting variable slots in their place.
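To make "inserting variable slots" concrete, here is a toy model of the documented behavior in plain Python. It is a sketch, not penzai's implementation: Variable, Slot, ConflictError and toy_unbind are hypothetical stand-ins, nested dicts, lists and tuples play the role of a pytree, and slots are matched by label alone. It shows each extracted variable being replaced by a slot, a variable that appears twice being returned once, predicate limiting what is extracted, and the two conflict cases listed under Raises producing an error.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


# Hypothetical stand-ins, NOT penzai's classes: a mutable labeled variable,
# the placeholder slot left behind when it is unbound, and a conflict error.
class Variable:
    def __init__(self, label: str, value: Any):
        self.label = label
        self.value = value


@dataclass(frozen=True)
class Slot:
    label: str


class ConflictError(Exception):
    pass


def toy_unbind(tree: Any, predicate: Optional[Callable[[Variable], bool]] = None):
    """Replaces selected Variables in nested dicts/lists/tuples with Slots.

    Returns (tree_with_slots, variables); slots are matched by label only.
    """
    found: dict = {}  # label -> the one Variable object allowed for that label
    existing_slots: set = set()  # labels of slots already present in the input

    def visit(node: Any) -> Any:
        if isinstance(node, Variable):
            if predicate is not None and not predicate(node):
                return node  # not selected: leave the variable in place
            if found.setdefault(node.label, node) is not node:
                # Two distinct objects would map to the same slot.
                raise ConflictError(f"conflicting variables for {node.label!r}")
            return Slot(node.label)
        if isinstance(node, Slot):
            existing_slots.add(node.label)
            return node
        if isinstance(node, dict):
            return {k: visit(v) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(visit(v) for v in node)
        return node  # any other leaf is kept as-is

    tree_with_slots = visit(tree)
    clash = existing_slots.intersection(found)
    if clash:
        # The input already contained a slot that an extracted variable needs.
        raise ConflictError(f"tree already has slots for {sorted(clash)}")
    # Each variable is returned once, even if it appeared several times.
    return tree_with_slots, tuple(found.values())


w = Variable("w", 1.0)
b = Variable("b", 0.0)
model = {"layer": [w, b], "shared": w}  # w appears twice: same object, allowed

tree_with_slots, variables = toy_unbind(model)
print(tree_with_slots)
# {'layer': [Slot(label='w'), Slot(label='b')], 'shared': Slot(label='w')}
print([v.label for v in variables])  # ['w', 'b']

# With a predicate, only the selected variables are replaced by slots.
partial, only_b = toy_unbind(model, predicate=lambda v: v.label == "b")
```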
This function can be used to extract variables from a pytree before a JAX transformation or control-flow primitive. The extracted variables can then either be updated directly, or passed through a JAX transformation: call freeze on them, call unfreeze_as_copy and then bind_variables inside the transformation, and call freeze again before returning the updated variable values.
- Parameters:
tree – A tree containing variables. Each variable can appear in the tree more than once, but if there are two distinct variable objects that have the same label (or otherwise would map to the same slot), an error will be raised.
predicate – A function that returns True for variables that should be extracted. If None, all variables will be extracted.
freeze – Whether to return frozen variable values (AbstractVariableValue instances) instead of the mutable variables themselves.
- Returns:
A tuple (tree_with_slots, variables), where tree_with_slots is a copy of the original tree with the extracted variables replaced by their corresponding slots, and variables is a tuple of the extracted variables. If freeze is True, the returned variables will be frozen.
- Raises:
VariableConflictError – If two variables map to the same slot but are different Python objects, or if there is already a conflicting slot in the tree.
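The round trip described at the top of this entry (freeze, then unfreeze_as_copy and bind_variables inside the transformation, then freeze again) can be sketched in plain Python. This is a toy model under stated assumptions: Variable, FrozenValue, Slot, toy_bind and step are hypothetical names, not penzai's classes, and no JAX is involved; step only plays the role of the function handed to a JAX transformation. What it illustrates is that the transformed function receives and returns immutable values while working on fresh mutable copies internally, so the caller's variable changes only when the new values are written back.

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins, NOT penzai's classes, for the documented workflow.
class Variable:
    """Mutable, labeled state."""

    def __init__(self, label: str, value: Any):
        self.label = label
        self.value = value

    def freeze(self) -> "FrozenValue":
        return FrozenValue(self.label, self.value)


@dataclass(frozen=True)
class FrozenValue:
    """Immutable snapshot of a variable, safe to pass through pure code."""

    label: str
    value: Any

    def unfreeze_as_copy(self) -> Variable:
        return Variable(self.label, self.value)  # a new, independent object


@dataclass(frozen=True)
class Slot:
    label: str


def toy_bind(tree: Any, variables) -> Any:
    """Replaces each Slot with the variable that has the same label."""
    by_label = {v.label: v for v in variables}

    def visit(node: Any) -> Any:
        if isinstance(node, Slot):
            return by_label[node.label]
        if isinstance(node, dict):
            return {k: visit(v) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(visit(v) for v in node)
        return node

    return visit(tree)


def step(tree_with_slots: Any, frozen_vars: tuple) -> tuple:
    """Stands in for the transformed function: immutable values in,
    new immutable values out."""
    local_vars = [fv.unfreeze_as_copy() for fv in frozen_vars]
    model = toy_bind(tree_with_slots, local_vars)
    model["counter"].value += 1  # mutates only the local copy
    return tuple(v.freeze() for v in local_vars)


counter = Variable("counter", 0)
# Hand-built stand-in for the result of unbinding {"counter": counter}.
tree_with_slots = {"counter": Slot("counter")}

new_frozen = step(tree_with_slots, (counter.freeze(),))
print(counter.value)  # 0: the original has not been touched yet
counter.value = new_frozen[0].value  # write the result back explicitly
print(counter.value)  # 1
```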