jit_wrapper

Utilities for JIT compilation of Penzai models.

Directly transforming a Penzai model with jax.jit is allowed, but it makes the model difficult to manipulate because the resulting function is an opaque closure. This module provides wrappers that preserve the structure of the original model, and re-express JIT compilation using Penzai conventions.
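The idea can be illustrated with a minimal sketch written against plain JAX. It is not Penzai's actual implementation, and the names `Scale`, `JittedSketch`, and `_call_body` are hypothetical. The wrapped layer is passed to `jax.jit` as an explicit argument instead of being closed over, so the wrapper stays a pytree whose `.body` is still the original model.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Scale:
  """Hypothetical toy layer with a single parameter leaf."""

  def __init__(self, weight):
    self.weight = weight

  def __call__(self, x):
    return self.weight * x

  def tree_flatten(self):
    return (self.weight,), None

  @classmethod
  def tree_unflatten(cls, aux, children):
    return cls(*children)


def _call_body(body, x):
  # The layer is an explicit argument, so jax.jit traces its leaves
  # instead of baking them into an opaque closure.
  return body(x)


_jitted_call = jax.jit(_call_body)


@jax.tree_util.register_pytree_node_class
class JittedSketch:
  """Hypothetical wrapper: runs `body` under jit but keeps it inspectable."""

  def __init__(self, body):
    self.body = body

  def __call__(self, x):
    return _jitted_call(self.body, x)

  def tree_flatten(self):
    return (self.body,), None

  @classmethod
  def tree_unflatten(cls, aux, children):
    return cls(*children)


model = JittedSketch(Scale(jnp.array(2.0)))
out = model(jnp.arange(3.0))
# The wrapper is a pytree, so its parameters can still be edited.
scaled = jax.tree_util.tree_map(lambda w: w * 10, model)
```

Because `JittedSketch` is itself a pytree, tools such as `jax.tree_util.tree_map` can reach the wrapped parameters; they cannot see into the function returned by calling `jax.jit` on the model directly.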

The intended use of these wrappers is to enable interactive exploration of models in notebooks such as Colab, while still taking advantage of JIT compilation.

Classes

Jitted

Wraps a pure layer to run under JIT compilation.
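One plausible reason the wrapped layer should be pure is a general property of `jax.jit` (not specific to Penzai): Python side effects run only while a function is being traced, not on every call. The sketch below uses a hypothetical `impure_layer` and `trace_log` to show that a cached compiled function silently skips such effects.

```python
import jax
import jax.numpy as jnp

trace_log = []


def impure_layer(x):
  # Python side effect: executes only while JAX traces the function.
  trace_log.append("traced")
  return x * 2.0


jitted = jax.jit(impure_layer)
jitted(jnp.ones(3))
jitted(jnp.ones(3))  # Same shape and dtype: reuses the cached trace.
calls_after_same_shape = len(trace_log)
jitted(jnp.ones(4))  # New shape: triggers a fresh trace.
calls_after_new_shape = len(trace_log)
```

After two calls with identical input shapes the log holds a single entry, so any behavior that depends on the side effect would silently stop happening under JIT.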