Copyright 2024 The Penzai Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Jitting and Sharding Penzai Models
Penzai works with JAX’s standard function transformations, including JIT-compilation and array sharding. However, because Penzai includes support for mutable variables inside the model, some care must be taken to ensure you apply them in ways that JAX can understand!
This notebook walks through some of the common aspects of JIT-compilation and sharding as they apply to Penzai tools and Penzai models. It assumes some basic familiarity with JAX’s JIT compilation and distributed array systems.
Setup
Before we can get started in earnest, we need to set up the environment.
Imports
To run this notebook, you need a Python environment with penzai and its dependencies installed.
In Colab or Kaggle, you can install it using the following command:
try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import jax
import jax.numpy as jnp
import treescope
import penzai
from penzai import pz
from penzai.models import transformer
Setting up Penzai
For this tutorial, we’ll enable Treescope (Penzai’s companion pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment. We’ll also enable automatic array visualization, which makes it easy to visualize array shardings.
treescope.basic_interactive_setup(autovisualize_arrays=True)
We’ll assume this notebook is running on a backend with eight devices. If needed, you can force JAX to treat the CPU backend as multiple devices by setting the following before JAX initializes its backends:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
pz.show(jax.local_devices())
assert jax.local_device_count() == 8
JIT-Compiling Penzai Models
Penzai model objects themselves are always JAX PyTrees. However, in addition to arrays and arraylike leaves, Penzai models can also include two types of “variable” leaves: pz.Parameter and pz.StateVariable. These are currently not directly supported by jax.jit.
For example, consider the following (somewhat contrived) model, which has a learnable parameter and an incrementing counter:
@pz.pytree_dataclass
class CounterLayer(pz.nn.Layer):
  counter: pz.StateVariable[int]

  def __call__(self, x, **_side_inputs):
    self.counter.value += 1
    return (x, self.counter.value)

model = pz.nn.Sequential([
    pz.nn.Linear.from_config(
        name="linear",
        init_base_rng=jax.random.PRNGKey(0),
        input_axes={"features": 8},
        output_axes={"features_out": 8},
    ),
    CounterLayer(counter=pz.StateVariable(value=0, label="counter")),
])
model
model(pz.nx.ones({"features": 8}))
model(pz.nx.ones({"features": 8}))
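One general reason mutable state needs care under jax.jit is that Python-level side effects inside a jitted function run only while JAX traces it, not on every call. The following minimal sketch uses plain JAX and a hypothetical list named trace_log (not a Penzai variable) to illustrate that behavior. It is an illustration of the general issue, not a description of how Penzai's variables are implemented.

```python
import jax
import jax.numpy as jnp

# Hypothetical Python-side log, used only for this illustration.
trace_log = []

@jax.jit
def add_one(x):
  # This append is a Python side effect: it runs during tracing only.
  trace_log.append("traced")
  return x + 1

# Two calls with the same shape and dtype reuse the compiled function.
add_one(jnp.ones(3))
add_one(jnp.ones(3))

# The function was called twice, but the side effect happened once.
print(len(trace_log))
```

A counter that is incremented by plain Python mutation would therefore not advance on cached calls, which is why state updates have to be passed into and returned from the jitted computation explicitly.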
To JIT-compile a Penzai model, you have three options:
- The “functional API”: A set of Penzai tools to help you manipulate variable states using pure functions and JAX PyTrees.
- pz.variable_jit: A convenience wrapper around jax.jit that also works for PyTrees containing pz.Parameter and pz.StateVariable.
- toolshed.jit_wrapper.Jitted: A model combinator that acts like an ordinary Layer, but always runs under jax.jit (using pz.variable_jit around its __call__ method).
The “Functional API”
Each of Penzai’s variables comes in three forms:
- Mutable variables (pz.Parameter and pz.StateVariable), which are Python objects whose .value attribute can be modified freely,
- Frozen variable values (pz.ParameterValue and pz.StateVariableValue), which are immutable JAX PyTree objects that are safe to pass through JAX transforms,
- Variable slots (pz.ParameterSlot and pz.StateVariableSlot), which are empty placeholders that indicate locations of variables in a larger tree.
For full control over JIT compilation, you can manually convert variables from their mutable form to their immutable form when crossing JAX transform boundaries. The relevant functions:
- pz.unbind_variables (and type-specific variants pz.unbind_params and pz.unbind_state_vars): Extracts and deduplicates variables, returning a tree of variable slots along with the deduplicated variables.
- pz.bind_variables: Re-inserts variables into variable slots.
- Parameter.freeze() and StateVariable.freeze(): Converts a mutable variable into an immutable value.
- ParameterValue.unfreeze_as_copy() and StateVariableValue.unfreeze_as_copy(): Converts an immutable value back into a (new) mutable variable.
For instance, for our example model above, we can use pz.unbind_variables and .freeze() to extract the mutable parts:
model_with_slots, all_vars = pz.unbind_variables(model)
pz.show("model_with_slots:", model_with_slots)
pz.show("all_vars:", all_vars)
frozen_vars = [var.freeze() for var in all_vars]
pz.show("frozen_vars:", frozen_vars)
We can then define a pure function that re-binds these variables, and call it under jax.jit:
@jax.jit
def rebinding_call(model_with_slots, frozen_vars, arg):
  # Make temporary mutable copies:
  new_vars = [var.unfreeze_as_copy() for var in frozen_vars]
  # Re-attach them to the model:
  model = pz.bind_variables(model_with_slots, new_vars)
  # Run it:
  result = model(arg)
  # Extract and re-freeze the variables:
  refrozen_vars = [var.freeze() for var in new_vars]
  return result, refrozen_vars

result, new_frozen_vars = rebinding_call(
    model_with_slots, frozen_vars, pz.nx.ones({"features": 8})
)
pz.show("result:", result)
pz.show("new_frozen_vars:", new_frozen_vars)
We can then update the old variables with their new values:
for var, new_value in zip(all_vars, new_frozen_vars):
  var.update(new_value)
To make this a bit less verbose, pz.nn.Layer has a method .stateless_call(vars, ...) that makes temporary mutable copies of its input variables, like rebinding_call. So, we could have equivalently written the following:
@jax.jit
def rebinding_call_2(model_with_slots, frozen_vars, arg):
  result, refrozen_vars = model_with_slots.stateless_call(frozen_vars, arg)
  return result, refrozen_vars

result, new_frozen_vars = rebinding_call_2(
    model_with_slots, frozen_vars, pz.nx.ones({"features": 8})
)
pz.show("result:", result)
pz.show("new_frozen_vars:", new_frozen_vars)
If you want to JIT-compile your model initializer, you can do this using the functional API:
@jax.jit
def functional_init(init_base_rng):
  model = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="linear",
          init_base_rng=init_base_rng,
          input_axes={"features": 8},
          output_axes={"features_out": 8},
      ),
      CounterLayer(counter=pz.StateVariable(value=0, label="counter")),
  ])
  # Unbind and also freeze all variables:
  return pz.unbind_variables(model, freeze=True)

model_with_slots, init_var_values = functional_init(jax.random.PRNGKey(0))

# Re-bind variables and also make them mutable again:
model = pz.bind_variables(
    model_with_slots, init_var_values, unfreeze_as_copy=True
)
model
pz.variable_jit
If you don’t want to use the functional API directly, you can instead use pz.variable_jit, which is a wrapper around jax.jit that allows the function arguments to contain pz.Parameter and pz.StateVariable in addition to arrays, and handles updating their values for you. For instance, you could write:
@pz.variable_jit
def jitted_call(model, arg):
  return model(arg)

jitted_call(model, pz.nx.ones({"features": 8}))
jitted_call(model, pz.nx.ones({"features": 8}))
Note that pz.variable_jit does not support returning variables from the jitted computation, so it can’t be used to JIT-compile model initialization. It also does not support “closing over” global references to variable objects defined outside of the function. Every variable used by the function inside pz.variable_jit must have been passed in as an input argument.
jit_wrapper.Jitted
pz.variable_jit works for top-level functions, but sometimes you may want to JIT-compile a specific part of a Penzai model, or compile the forward pass without having to use an indirect jitted_call function. For this purpose, Penzai provides a layer wrapper Jitted in penzai.toolshed.jit_wrapper, which JIT-compiles its forward pass when called.
To use it, you can simply wrap your model in jit_wrapper.Jitted and then call it as normal:
from penzai.toolshed import jit_wrapper
jit_model = jit_wrapper.Jitted(model)
jit_model
jit_model(pz.nx.ones({"features": 8}))
You can also insert Jitted around any sublayer of the model, e.g.
jit_model_2 = (
    pz.select(model)
    .at_instances_of(pz.nn.Linear | CounterLayer)
    .apply(jit_wrapper.Jitted)
)
jit_model_2
jit_model_2(pz.nx.ones({"features": 8}))
Note that the Jitted wrapper is just an ordinary Penzai layer, and you can still pull back out the original model:
jit_model.body == model