variable_jit

penzai.core.variables.variable_jit(fun, *, donate_variables: bool = False, **jit_kwargs)

Variable-aware version of jax.jit.

This function is like jax.jit, but adds support for Variables as leaves of the input pytree(s).

Limitations:

  • Closed-over Variables are not supported. All Variables must be passed as arguments.

  • Variables always have an unspecified sharding.

  • Variables should not be included in static_argnums or static_argnames of the jitted function.

  • The keyword argument "__penzai_variables" is used to track Variables and should not be used directly.

If you run into issues with this wrapper or if you need more control, consider using jax.jit directly and manually unbinding the Variables before the transformation.

Parameters:
  • fun – The function to be jitted.

  • donate_variables – Whether to donate the buffers of the Variables to the jitted function. Defaults to False.

  • **jit_kwargs – Additional arguments to pass to jax.jit. Note: Any donated keyword arguments must be configured using donate_argnames instead of donate_argnums.

Returns:

A jitted version of fun.