freeze_variables

penzai.core.variables.freeze_variables(tree: Any, predicate: Callable[[AbstractVariable], bool] | None = None) -> Any

Replaces each variable in a pytree with a frozen copy.

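The resulting tree contains frozen variable instances in place of mutable variable instances. Frozen variables are themselves pytree nodes, so the resulting tree is safe to pass through JAX transformations only if every variable in it has been frozen. This is always the case when predicate is None. If predicate leaves some variables unselected, those variables stay mutable in the result.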

Parameters:
  • tree – A tree containing variables.

  • predicate – A function that returns True for variables that should be frozen. If None, all variables will be frozen.

Returns:

A copy of tree but with all variables (or those selected by predicate) replaced by equivalent frozen instances.
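To make this behavior concrete, here is a minimal pure-Python sketch of the same idea. It does not use Penzai: ToyVariable, ToyFrozenVariable, and toy_freeze_variables are hypothetical stand-ins invented for illustration, and the sketch only traverses dicts, lists, and tuples rather than arbitrary JAX pytrees. It is not Penzai's implementation. It illustrates three properties: selected variables become immutable snapshots of their current values, unselected variables and other leaves pass through unchanged, and the input tree itself is not modified.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


# Hypothetical stand-in for a mutable variable (NOT Penzai's API).
class ToyVariable:
    def __init__(self, label: str, value: Any):
        self.label = label
        self.value = value  # can be reassigned in place

    def freeze(self) -> "ToyFrozenVariable":
        # Snapshot the current value into an immutable object.
        return ToyFrozenVariable(self.label, self.value)


# Hypothetical stand-in for a frozen variable: attributes cannot be reassigned.
@dataclass(frozen=True)
class ToyFrozenVariable:
    label: str
    value: Any


def toy_freeze_variables(
    tree: Any, predicate: Optional[Callable[[ToyVariable], bool]] = None
) -> Any:
    """Returns a copy of `tree` with selected ToyVariables replaced by frozen copies."""
    if isinstance(tree, ToyVariable):
        if predicate is None or predicate(tree):
            return tree.freeze()
        return tree  # unselected variables are kept as-is (still mutable)
    if isinstance(tree, dict):
        return {k: toy_freeze_variables(v, predicate) for k, v in tree.items()}
    if isinstance(tree, list):
        return [toy_freeze_variables(v, predicate) for v in tree]
    if isinstance(tree, tuple):
        return tuple(toy_freeze_variables(v, predicate) for v in tree)
    return tree  # any other leaf is returned unchanged


# Usage: freeze everything, then freeze only the variable labeled "w".
model = {"w": ToyVariable("w", 1.0), "layers": [ToyVariable("b", 2.0), 3]}
all_frozen = toy_freeze_variables(model)
only_w = toy_freeze_variables(model, predicate=lambda v: v.label == "w")
print(type(all_frozen["layers"][0]).__name__)  # ToyFrozenVariable
print(type(only_w["layers"][0]).__name__)  # ToyVariable
```

Because each frozen copy is a snapshot, later writes to the original mutable variable (for example `model["w"].value = 5.0`) do not change the value stored in `all_frozen["w"]`. With the real library you would call penzai.core.variables.freeze_variables directly instead of a helper like this.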