unbind_variables

penzai.core.variables.unbind_variables(tree: Any, predicate: Callable[[AbstractVariable], bool] | None = None, freeze: Literal[False] = False) -> tuple[Any, tuple[AbstractVariable, ...]]
penzai.core.variables.unbind_variables(tree: Any, predicate: Callable[[AbstractVariable], bool] | None = None, *, freeze: Literal[True]) -> tuple[Any, tuple[AbstractVariableValue, ...]]

Unbinds variables from a pytree, inserting variable slots in their place.

This function can be used to extract variables from a pytree before a JAX transformation or control-flow primitive. The extracted variables can then either be updated directly, or passed through a JAX transformation: call freeze on them, call unfreeze_as_copy and then bind_variables inside the transformation, and call freeze again before returning the updated variable values.
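The freeze / unfreeze_as_copy / bind_variables round trip described above can be modeled in plain Python. This is a simplified sketch under stated assumptions, not Penzai's implementation: ToyVariable, ToyVariableValue, ToySlot, toy_unbind, toy_bind and transformed_step are hypothetical names, trees are limited to nested dicts, lists and tuples, and transformed_step only stands in for the body of a JAX transformation (no JAX is involved). The predicate and conflict checks of the real function are omitted.

```python
from dataclasses import dataclass
from typing import Any, Hashable


class ToyVariable:
    """Hypothetical mutable variable (a stand-in, not Penzai's class)."""

    def __init__(self, value: Any, label: Hashable):
        self.value = value
        self.label = label

    def freeze(self) -> "ToyVariableValue":
        # Snapshot the current value as an immutable record.
        return ToyVariableValue(self.value, self.label)


@dataclass(frozen=True)
class ToyVariableValue:
    """Hypothetical immutable snapshot that is safe to pass into a transformation."""
    value: Any
    label: Hashable

    def unfreeze_as_copy(self) -> ToyVariable:
        # Build a fresh mutable variable; the caller's variable is untouched.
        return ToyVariable(self.value, self.label)


@dataclass(frozen=True)
class ToySlot:
    """Hypothetical placeholder left where a variable was removed."""
    label: Hashable


def _map_tree(fn, tree):
    # Minimal tree map over dicts, lists and tuples; everything else is a leaf.
    if isinstance(tree, dict):
        return {k: _map_tree(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_map_tree(fn, v) for v in tree)
    return fn(tree)


def toy_unbind(tree):
    # Replace every variable with a slot keyed by its label.
    found = {}

    def replace(leaf):
        if isinstance(leaf, ToyVariable):
            found[leaf.label] = leaf
            return ToySlot(leaf.label)
        return leaf

    return _map_tree(replace, tree), tuple(found.values())


def toy_bind(tree, variables):
    # Put variables back into the slots with matching labels.
    by_label = {v.label: v for v in variables}

    def replace(leaf):
        if isinstance(leaf, ToySlot):
            return by_label.get(leaf.label, leaf)
        return leaf

    return _map_tree(replace, tree)


def transformed_step(tree_with_slots, frozen_vars):
    # Stands in for a transformed function: it receives only slots and frozen
    # values, never the caller's mutable variables.
    local_vars = [v.unfreeze_as_copy() for v in frozen_vars]
    tree = toy_bind(tree_with_slots, local_vars)
    tree["layer"]["count"].value += 1  # mutate the local copy only
    return tuple(v.freeze() for v in local_vars)


counter = ToyVariable(0, label="count")
model = {"layer": {"count": counter}, "scale": 2.0}

tree_with_slots, variables = toy_unbind(model)
new_values = transformed_step(tree_with_slots, [v.freeze() for v in variables])

# Outside the "transformation", copy the updated values back by assignment.
for var, new in zip(variables, new_values):
    var.value = new.value
```

In this sketch the transformed function never touches the caller's mutable objects: updates leave it as new frozen values, and the caller decides how to apply them.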

Parameters:
  • tree – A tree containing variables. A variable may appear in the tree more than once, but if two distinct variable objects have the same label (or would otherwise map to the same slot), a VariableConflictError is raised.

  • predicate – A function that returns True for variables that should be extracted. If None, all variables are extracted.

  • freeze – If True, return frozen variable values (AbstractVariableValue) instead of mutable variables (AbstractVariable).
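The rules for repeated variables, label conflicts and predicate filtering can be illustrated with a small, dependency-free model. Everything here is hypothetical: ToyVariable, ToySlot, ToyConflictError (standing in for VariableConflictError) and toy_unbind only mimic the documented behavior, slots are keyed purely by label, and trees are limited to nested dicts, lists and tuples. The model also rejects a tree that already holds a slot for a label being extracted, mirroring the Raises section.

```python
from dataclasses import dataclass
from typing import Any, Callable, Hashable, Optional


class ToyVariable:
    """Hypothetical mutable variable, identified by its label."""

    def __init__(self, value: Any, label: Hashable):
        self.value = value
        self.label = label


@dataclass(frozen=True)
class ToySlot:
    """Hypothetical placeholder left where a variable was removed."""
    label: Hashable


class ToyConflictError(Exception):
    """Stands in for VariableConflictError in this sketch."""


def _leaves(tree):
    # Yield every leaf of a nested dict/list/tuple tree.
    if isinstance(tree, dict):
        for v in tree.values():
            yield from _leaves(v)
    elif isinstance(tree, (list, tuple)):
        for v in tree:
            yield from _leaves(v)
    else:
        yield tree


def _map_tree(fn, tree):
    if isinstance(tree, dict):
        return {k: _map_tree(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_map_tree(fn, v) for v in tree)
    return fn(tree)


def toy_unbind(tree, predicate: Optional[Callable[[ToyVariable], bool]] = None):
    existing_slots = {leaf.label for leaf in _leaves(tree) if isinstance(leaf, ToySlot)}
    found: dict = {}

    def replace(leaf):
        if not isinstance(leaf, ToyVariable):
            return leaf
        if predicate is not None and not predicate(leaf):
            return leaf  # not selected: left in place
        # The same object may repeat; a different object with this label may not.
        if found.setdefault(leaf.label, leaf) is not leaf:
            raise ToyConflictError(f"distinct variables share label {leaf.label!r}")
        if leaf.label in existing_slots:
            raise ToyConflictError(f"tree already has a slot for {leaf.label!r}")
        return ToySlot(leaf.label)

    return _map_tree(replace, tree), tuple(found.values())


shared = ToyVariable(1.0, label="w")
other = ToyVariable(2.0, label="b")

# A repeated object is extracted once; both positions receive the "w" slot.
slots, extracted = toy_unbind({"x": shared, "y": [shared, other]})

# A predicate restricts extraction; unselected variables stay in the tree.
partial, only_w = toy_unbind({"x": shared, "y": other}, predicate=lambda v: v.label == "w")

for bad_tree in (
    [ToyVariable(0, "dup"), ToyVariable(0, "dup")],  # distinct objects, same label
    {"a": ToySlot("w"), "b": ToyVariable(3, "w")},  # slot for "w" already present
):
    try:
        toy_unbind(bad_tree)
    except ToyConflictError as err:
        print("conflict:", err)
```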

Returns:

A tuple (tree_with_slots, variables), where tree_with_slots is a copy of the original tree with the extracted variables replaced by their corresponding slots, and variables is a tuple of the extracted variables. If freeze is True, the returned variables are frozen (AbstractVariableValue instances).

Raises:

VariableConflictError – If two variables map to the same slot but are different Python objects, or if there is already a conflicting slot in the tree.