jit_wrapper
Utilities for JIT compilation of Penzai models.
Directly transforming a Penzai model with jax.jit is allowed, but it makes the model difficult to manipulate because the resulting function is an opaque closure. This module provides wrappers that preserve the structure of the original model and re-express JIT compilation using Penzai conventions.
These wrappers are intended to support interactive exploration of models (e.g., in Colab notebooks) while still taking advantage of JIT compilation.
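The following is a minimal plain-JAX sketch that illustrates the difference between an opaque closure and a structure-preserving wrapper. It does not use Penzai, and it is not Penzai's implementation. `Linear`, `JitWrapper`, and `_call_jitted` are hypothetical names, and the sketch assumes the wrapped model is registered as a JAX pytree. The model is passed to a jitted function as an argument instead of being captured in a closure, so its parameters stay visible as ordinary attributes and pytree leaves.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Linear:
    """Hypothetical toy layer (not a Penzai class): computes x @ weight."""
    weight: jax.Array

    def __call__(self, x):
        return x @ self.weight

    # Pytree protocol: the weight array is the only leaf.
    def tree_flatten(self):
        return (self.weight,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def _call_jitted(body, x):
    # `body` is a traced pytree argument, so its arrays are inputs to the
    # compiled computation rather than constants baked into a closure.
    return body(x)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class JitWrapper:
    """Hypothetical wrapper that keeps the wrapped model inspectable."""
    body: Linear

    def __call__(self, x):
        return _call_jitted(self.body, x)

    # The wrapper is itself a pytree whose only child is the wrapped model.
    def tree_flatten(self):
        return (self.body,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


model = Linear(weight=jnp.eye(2))
wrapped = JitWrapper(body=model)
y = wrapped(jnp.array([1.0, 2.0]))

# The wrapped model is still accessible and can be edited structurally.
wrapped2 = dataclasses.replace(wrapped, body=Linear(weight=2.0 * jnp.eye(2)))
y2 = wrapped2(jnp.array([1.0, 2.0]))
```

By contrast, `jax.jit(model)` would treat the model's arrays as constants of a compiled function, so changing a weight means building and compiling a new closure. In the sketch, `wrapped2` has the same pytree structure and array shapes as `wrapped`, so JAX can reuse the already-compiled `_call_jitted`.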
Classes
Wraps a pure layer to run under JIT compilation.