variables
Support for named pockets of mutable/shared state within a JAX pytree.
A variable is a mutable Python object that can be stored in a JAX pytree, and which contains a JAX pytree. It cannot directly be passed through JAX transformations; instead, it must be extracted before the transformation and re-inserted afterwards. (Eventually, variables may be extended to natively support JAX transformations once the necessary JAX APIs are available.)
Variables are used to simplify the handling of parameters and mutable state in Penzai models, giving an eager interface that follows ordinary Python semantics, while still allowing mutable state to be safely managed for JAX.
Operations on variables include:
“Unbinding” them, which extracts each variable and replaces it with its “variable slot” (a placeholder for a variable).
“Freezing” them, which converts each variable into a frozen variable value that is itself a pytree node.
“Unfreezing” them, which converts a frozen variable value back into a new mutable variable.
“Binding” them, which re-inserts either mutable or frozen variables into the pytree in place of their corresponding variable slots.
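To make the four operations concrete, here is a minimal pure-Python model of them. It is not Penzai's implementation or API: `Var`, `FrozenVar`, `Slot`, `unbind`, `bind`, and the toy `tree_map` (which treats dicts, lists, and tuples as tree nodes in place of real JAX pytrees) are all hypothetical names chosen for illustration.

```python
import dataclasses
from typing import Any


class Var:
  """Hypothetical mutable variable: a label plus a value that can change."""

  def __init__(self, label: str, value: Any):
    self.label = label
    self.value = value

  def freeze(self) -> "FrozenVar":
    # Snapshot the current value into an immutable object.
    return FrozenVar(self.label, self.value)


@dataclasses.dataclass(frozen=True)
class FrozenVar:
  """Hypothetical frozen variable value (immutable, safe to pass around)."""
  label: str
  value: Any

  def unfreeze(self) -> Var:
    # Unfreezing always creates a *new* mutable variable.
    return Var(self.label, self.value)


@dataclasses.dataclass(frozen=True)
class Slot:
  """Hypothetical placeholder left behind where a variable was unbound."""
  label: str


def tree_map(fn, tree):
  """Applies `fn` to every leaf of a tree of dicts, lists and tuples."""
  if isinstance(tree, dict):
    return {k: tree_map(fn, v) for k, v in tree.items()}
  if isinstance(tree, (list, tuple)):
    return type(tree)(tree_map(fn, v) for v in tree)
  return fn(tree)


def unbind(tree):
  """Replaces each Var with a Slot; returns (tree_with_slots, variables)."""
  found = {}

  def visit(leaf):
    if isinstance(leaf, Var):
      found[leaf.label] = leaf  # the same object may appear several times
      return Slot(leaf.label)
    return leaf

  return tree_map(visit, tree), list(found.values())


def bind(tree, variables):
  """Replaces each Slot with the mutable or frozen variable of that label."""
  by_label = {v.label: v for v in variables}

  def visit(leaf):
    return by_label[leaf.label] if isinstance(leaf, Slot) else leaf

  return tree_map(visit, tree)


# Usage: one variable shared between two positions in a tree.
w = Var("w", 1.0)
model = {"layer": [w, "relu"], "tied": w}

unbound, variables = unbind(model)        # Slots replace the variable
frozen = [v.freeze() for v in variables]  # immutable snapshots
w.value = 2.0                             # a later mutation of w...
assert frozen[0].value == 1.0             # ...does not affect the snapshot
rebound = bind(unbound, [f.unfreeze() for f in frozen])
```

Because the shared variable is replaced by the same slot in both places, binding a single unfrozen copy restores the sharing: `rebound["layer"][0]` and `rebound["tied"]` are the same new object, distinct from the original `w`.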
Every variable has a “slot” value that must uniquely identify it within a particular JAX pytree. The same variable Python object may appear in multiple places in the same pytree, but two different variable objects with the same slot value will cause an error when the variables are unbound.
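The uniqueness rule can be illustrated with a small pure-Python check. This is a sketch, not Penzai's code: it assumes for simplicity that a variable's slot is just its string label, and `Var`, `LabelConflict`, and `collect_variables` are hypothetical names.

```python
class LabelConflict(Exception):
  """Hypothetical error for two distinct variables sharing one slot."""


class Var:
  """Hypothetical mutable variable whose slot is simply its label."""

  def __init__(self, label, value):
    self.label = label
    self.value = value


def collect_variables(leaves):
  """Collects the distinct variables in a flat list of leaves, by label.

  The same Var object may occur any number of times, but two different
  Var objects with the same label are ambiguous and are rejected.
  """
  by_label = {}
  for leaf in leaves:
    if isinstance(leaf, Var):
      existing = by_label.setdefault(leaf.label, leaf)
      if existing is not leaf:
        raise LabelConflict(f"label {leaf.label!r} is used by two variables")
  return by_label


shared = Var("w", 0.0)
found = collect_variables([shared, 3, shared])  # same object twice: allowed

try:
  collect_variables([Var("w", 0.0), Var("w", 1.0)])  # two objects, one label
  conflict = None
except LabelConflict as err:
  conflict = str(err)
```

Checking identity (`is`) rather than equality is what separates a legitimately shared variable from two independent variables whose labels happen to collide.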
Passing Penzai variables through JAX transformations usually involves a combination of these steps. For instance, to take gradients with respect to parameters, you can unbind and freeze them, then take gradients w.r.t. those frozen values, re-binding them inside the function being differentiated. To “functionalize” a stateful operation, you can bind temporary variables, then unbind them afterward.
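The “functionalize” recipe can be sketched in plain Python. The names `Var`, `Slot`, `run_layer`, and `functional_run` are hypothetical, and dicts stand in for real JAX pytrees; the point is only the bind, run, unbind sequence that turns hidden mutation into explicit inputs and outputs.

```python
import dataclasses
from typing import Any


class Var:
  """Hypothetical mutable variable holding a label and a value."""

  def __init__(self, label: str, value: Any):
    self.label = label
    self.value = value


@dataclasses.dataclass(frozen=True)
class Slot:
  """Hypothetical placeholder for an unbound variable."""
  label: str


def run_layer(layer: dict, x):
  """Stateful code: scales x and counts calls in a mutable variable."""
  layer["calls"].value += 1
  return layer["scale"] * x


def functional_run(layer: dict, state: dict, x):
  """Pure wrapper around run_layer: state in, (output, new state) out.

  1. Bind: create temporary Vars from `state` and put them in the slots.
  2. Run the stateful code, which mutates only the temporaries.
  3. Unbind: read the temporaries' final values back out as plain data.
  """
  temps = {label: Var(label, value) for label, value in state.items()}
  bound = {k: (temps[v.label] if isinstance(v, Slot) else v)
           for k, v in layer.items()}
  y = run_layer(bound, x)
  return y, {label: var.value for label, var in temps.items()}


layer = {"scale": 2.0, "calls": Slot("calls")}
y1, state1 = functional_run(layer, {"calls": 0}, 3.0)
y2, state2 = functional_run(layer, state1, y1)
```

Because `functional_run` never mutates anything it receives, calling it twice with the same arguments gives the same result, which is the kind of pure function that JAX transformations such as `jax.jit` expect.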
Most Penzai models use two particular types of variable:
Parameter: A model parameter variable, which can be modified using gradient descent but isn’t modified while the model runs.
StateVariable: A state variable, which can be modified while the model runs and can be used to store mutable state.
Parameters and state variables are implemented similarly, but are kept separate because they may be treated differently by model combinators and user code. (For instance, when JIT-compiling a particular sublayer, we can often assume that parameters do not change, even if state variables might.)
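Because the two kinds are distinct types, code can act on one kind and leave the other alone. The sketch below is a plain-Python illustration of that idea, not Penzai's API: `Variable`, `Param`, `State`, and `unbind_where` are hypothetical names, and a flat list stands in for a pytree.

```python
from typing import Any, Callable


class Variable:
  """Hypothetical common base: a label plus a mutable value."""

  def __init__(self, label: str, value: Any):
    self.label = label
    self.value = value


class Param(Variable):
  """Hypothetical parameter: changed by an optimizer, not by the model."""


class State(Variable):
  """Hypothetical state variable: may change while the model runs."""


def unbind_where(leaves: list, predicate: Callable[[Variable], bool]):
  """Unbinds only the variables in a flat list that satisfy `predicate`.

  Matching variables are replaced by ("slot", label) placeholders; all
  other leaves, including non-matching variables, are left in place.
  """
  extracted = []
  out = []
  for leaf in leaves:
    if isinstance(leaf, Variable) and predicate(leaf):
      if leaf not in extracted:  # identity check; Variable defines no __eq__
        extracted.append(leaf)
      out.append(("slot", leaf.label))
    else:
      out.append(leaf)
  return out, extracted


weights = Param("w", [1.0, 2.0])
cache = State("cache", [])
model = [weights, cache, "relu"]

# Pull out only the parameters; the state variable stays inside the model.
model_without_params, params = unbind_where(
    model, lambda v: isinstance(v, Param))
```

A caller could then treat `params` as fixed inputs (for instance when compiling a sublayer) while `cache` remains the same mutable object inside `model_without_params`.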
A Penzai variable is somewhat similar to an NNX Variable (from flax.nnx) but provides a more restricted interface; this allows it to integrate with JAX’s pytrees without tracking a full object graph of Python dependencies.
Classes
Base class for all variables.
Base class for all variable slots.
Base class for all frozen variables.
A label for a StateVariable that is unique based on its Python object ID.
Base implementation of a variable with a label, value, and metadata.
The value of a basic labeled variable, as a frozen JAX pytree.
A model parameter variable.
A slot for a parameter in a model.
The value of a Parameter, as a frozen JAX pytree.
A label for a StateVariable that is unique within some scope.
A mutable state variable.
A slot for a state variable in a model.
The value of a StateVariable, as a frozen JAX pytree.
Functions
Binds variables (mutable or frozen) into the variable slots in a pytree.
Version of …
Version of …
Replaces each variable in a pytree with a frozen copy.
Context manager for using scoped auto-generated StateVariable labels.
Version of …
Version of …
Unbinds variables from a pytree, inserting variable slots in their place.
Variable-aware version of …
Exceptions
Raised when attempting to access the value of an unbound variable.
Raised when a Variable label is used by multiple Variables.