variable_jit
- penzai.core.variables.variable_jit(fun, *, donate_variables: bool = False, **jit_kwargs)
Variable-aware version of jax.jit.

This function is like jax.jit, but adds support for Variables as leaves of the input pytree(s).

Limitations:
- Closed-over Variables are not supported. All Variables must be passed as arguments.
- Variables always have an unspecified sharding.
- Variables should not be included in static_argnums or static_argnames of the jitted function.
- The keyword argument "__penzai_variables" is used to track Variables and should not be used directly.
If you run into issues with this wrapper or if you need more control, consider using jax.jit directly and manually unbinding the Variables before the transformation.
- Parameters:
fun – The function to be jitted.
donate_variables – Whether to donate Variables to the jitted function.
**jit_kwargs – Additional arguments to pass to jax.jit. Note: Any donated keyword arguments must be configured using donate_argnames instead of donate_argnums.
- Returns:
A jitted version of fun.