variable_jit
- penzai.core.variables.variable_jit(fun, *, donate_variables: bool = False, **jit_kwargs)
Variable-aware version of jax.jit.
This function is like jax.jit, but adds support for Variables as leaves of the input pytree(s).
Limitations:
- Closed-over Variables are not supported. All Variables must be passed as arguments.
- Variables always have an unspecified sharding.
- Variables should not be included in static_argnums or static_argnames of the jitted function.
The keyword argument "__penzai_variables" is used to track Variables and should not be used directly.
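To make the description above more concrete, here is a simplified, hypothetical sketch of one way a Variable-aware wrapper around jax.jit can work: Variables found among the arguments are swapped for their current values, those values pass through jax.jit as ordinary array leaves, and the updated values are written back to the original objects after the call. Var and var_aware_jit are illustrative names introduced here; this is not penzai's implementation or API.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Var:
    """Hypothetical mutable box standing in for a penzai Variable."""

    value: jax.Array


def _is_var(x):
    return isinstance(x, Var)


def var_aware_jit(fun):
    """Hypothetical sketch: jit `fun` while allowing Var objects as leaves."""

    # treedef and var_positions are static (hashable), so each distinct
    # argument structure gets its own trace; only array leaves are traced.
    @functools.partial(jax.jit, static_argnums=(0, 1))
    def traced(treedef, var_positions, leaves):
        leaves = list(leaves)
        # Wrap the traced values in fresh boxes that `fun` may mutate.
        boxes = {i: Var(leaves[i]) for i in var_positions}
        for i, box in boxes.items():
            leaves[i] = box
        result = fun(*jax.tree_util.tree_unflatten(treedef, leaves))
        # Return the possibly updated values alongside the result.
        return result, [boxes[i].value for i in var_positions]

    @functools.wraps(fun)
    def wrapper(*args):
        leaves, treedef = jax.tree_util.tree_flatten(args, is_leaf=_is_var)
        var_positions = tuple(i for i, x in enumerate(leaves) if _is_var(x))
        # Swap every Var for its current value so jax.jit only sees arrays.
        plain = [x.value if _is_var(x) else x for x in leaves]
        result, new_values = traced(treedef, var_positions, plain)
        # Write the updated values back into the caller's Var objects.
        for i, new_value in zip(var_positions, new_values):
            leaves[i].value = new_value
        return result

    return wrapper


counter = Var(jnp.array(0))


@var_aware_jit
def step(c, x):
    c.value = c.value + 1  # mutation is visible to the caller afterwards
    return x * 2


out = step(counter, jnp.array(3.0))
```

In this sketch only Var objects reachable from the arguments are swapped out; a Var captured by closure would be left alone, and assigning a traced value to it inside the jitted function would leak a tracer. That gives one intuition for why Variables must be passed as arguments. The sketch carries the values as a positional argument, whereas variable_jit uses the reserved "__penzai_variables" keyword argument for this bookkeeping.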
If you run into issues with this wrapper, or if you need more control, consider using jax.jit directly and manually unbinding the Variables before the transformation.
- Parameters:
fun – The function to be jitted.
donate_variables – Whether to donate the Variables' values (their underlying array buffers) to the jitted function.
**jit_kwargs – Additional arguments to pass to jax.jit. Note: Any donated keyword arguments must be configured using donate_argnames instead of donate_argnums.
- Returns:
A jitted version of fun.
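The note on **jit_kwargs concerns jax.jit's donation options: donate_argnums designates positional arguments, while donate_argnames designates arguments by parameter name, which is how arguments passed by keyword are covered. The sketch below uses plain jax.jit with a hypothetical update function whose keyword-only grads argument is donated by name. Because jit_kwargs are forwarded to jax.jit, the equivalent call here would presumably be variable_jit(update, donate_variables=True, donate_argnames=("grads",)).

```python
import jax
import jax.numpy as jnp


def update(params, *, grads, lr=0.1):
    """Hypothetical SGD-style update; `grads` is a keyword-only argument."""
    return params - lr * grads


# `grads` is only ever passed by keyword, so it is marked for donation by
# name with donate_argnames rather than by position with donate_argnums.
update_jit = jax.jit(update, donate_argnames=("grads",))

params = jnp.ones(3)
new_params = update_jit(params, grads=jnp.ones(3))
# Once donated, the array passed as `grads` should not be used again.
```

Donation is a hint that lets XLA reuse an input buffer for an output; on backends that cannot reuse a given buffer, JAX may emit a warning instead.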