Copyright 2024 The Penzai Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Jitting and Sharding Penzai Models (V2 API)
Penzai is designed to be compatible with JAX’s standard function transformations, including JIT-compilation and array sharding. If you’re already familiar with JIT compilation and distributed arrays in JAX, you shouldn’t have to learn anything fundamentally new to apply it to Penzai! But Penzai does provide some utilities to make it easier to construct and manipulate shardings for Penzai models.
This notebook walks through some of the common aspects of JIT-compilation and sharding as they apply to Penzai tools and Penzai models. It assumes some basic familiarity with JAX’s JIT compilation and distributed array systems.
Note
This tutorial uses the V2 neural network API, defined in penzai.experimental.v2.
Setup
Before we can get started in earnest, we need to set up the environment.
Imports
To run this notebook, you need a Python environment with penzai and its dependencies installed.
In Colab or Kaggle, you can install it using the following command:
try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import dataclasses
import jax
import jax.numpy as jnp
import optax
import penzai
from penzai.experimental.v2 import pz
from penzai.experimental.v2.models import transformer
from penzai.experimental.v2.models import simple_mlp
from penzai.experimental.v2.toolshed import basic_training
Setting up Penzai
For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment. We’ll also enable automatic array visualization, which makes it easy to visualize array shardings as well.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())
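We’ll assume this notebook is running on a backend with eight devices. If needed, you can force JAX to treat the CPU backend as multiple devices by setting the following flag before JAX initializes its backends (for instance, at the top of the notebook, before importing jax):
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"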
pz.show(jax.local_devices())
assert jax.local_device_count() == 8
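If XLA_FLAGS might already hold other options, a slightly more careful variant appends the device-count flag instead of overwriting the variable, then checks how many CPU devices JAX sees. This is a minimal sketch that assumes it runs in a fresh Python process, before anything has initialized JAX; FLAG and existing_flags are illustrative names.

```python
import os

# Append the flag instead of overwriting XLA_FLAGS, so that any flags that
# were already set are kept. This must run before JAX initializes a backend.
FLAG = "--xla_force_host_platform_device_count=8"
existing_flags = os.environ.get("XLA_FLAGS", "")
if FLAG not in existing_flags:
  os.environ["XLA_FLAGS"] = f"{existing_flags} {FLAG}".strip()

import jax  # Imported only after the flag is in place.

# The flag only affects the host (CPU) platform, so ask for CPU devices
# explicitly; on a GPU/TPU machine the default devices would differ.
cpu_devices = jax.devices("cpu")
print(len(cpu_devices))
```

On a machine with accelerators, the notebook’s own check of jax.local_device_count() refers to the default backend instead, so the two counts can legitimately differ.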
JIT-Compiling Penzai Models
Penzai model objects themselves are always JAX PyTrees. However, in addition to arrays and array-like leaves, Penzai models can also include two types of “variable” leaves: pz.Parameter and pz.StateVariable. These are not currently supported directly by jax.jit.
For example, consider the following (somewhat contrived) model, which has a learnable parameter and an incrementing counter:
@pz.pytree_dataclass
class CounterLayer(pz.nn.Layer):
  counter: pz.StateVariable[int]

  def __call__(self, x, **_side_inputs):
    self.counter.value += 1
    return (x, self.counter.value)
model = pz.nn.Sequential([
    pz.nn.Linear.from_config(
        name="linear",
        init_base_rng=jax.random.PRNGKey(0),
        input_axes={"features": 8},
        output_axes={"features_out": 8},
    ),
    CounterLayer(counter=pz.StateVariable(value=0, label="counter")),
])
model
model(pz.nx.ones({"features": 8}))
model(pz.nx.ones({"features": 8}))
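As mentioned above, pz.Parameter and pz.StateVariable are not directly supported by jax.jit. The underlying limitation is easy to see in plain JAX: jax.jit only accepts arrays and PyTrees of arrays as (non-static) arguments, so passing an arbitrary mutable Python object raises a TypeError. In this sketch, MutableBox is a hypothetical class standing in for a mutable variable; it is not how Penzai implements its variables.

```python
import jax
import jax.numpy as jnp


class MutableBox:
  """Hypothetical stand-in for a mutable variable; it is not a JAX PyTree."""

  def __init__(self, value):
    self.value = value


@jax.jit
def double_box(box):
  return box.value * 2


error = None
try:
  double_box(MutableBox(jnp.array(1.0)))
except TypeError as e:
  # jax.jit cannot interpret an arbitrary Python object as an array input.
  error = e

print(type(error).__name__)
```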
To JIT-compile a Penzai model, you have three options:
The “functional API”: A set of Penzai tools to help you manipulate variable states using pure functions and JAX PyTrees.
pz.variable_jit: A convenience wrapper around jax.jit that also works for PyTrees containing pz.Parameter and pz.StateVariable.
toolshed.jit_wrapper.Jitted: A model combinator that acts like an ordinary Layer, but always runs under jax.jit (using pz.variable_jit around its __call__ method).
The “Functional API”
Each of Penzai’s variables comes in three forms:
Mutable variables (pz.Parameter and pz.StateVariable), which are Python objects whose .value attribute can be modified freely,
Frozen variable values (pz.ParameterValue and pz.StateVariableValue), which are immutable JAX PyTree objects that are safe to pass through JAX transforms,
Variable slots (pz.ParameterSlot and pz.StateVariableSlot), which are empty placeholders that indicate locations of variables in a larger tree.
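To illustrate why the frozen form matters, here is a toy version of the mutable/frozen split in plain JAX. ToyVariable and ToyVariableValue are hypothetical classes written for this sketch, not Penzai’s classes: the mutable form is an ordinary Python object, while the frozen form is registered as an immutable PyTree and can therefore be passed into and returned from a jax.jit-compiled function.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@dataclasses.dataclass(eq=False)
class ToyVariable:
  """Hypothetical mutable form: an ordinary Python object with a `.value`."""

  label: str
  value: Any

  def freeze(self) -> "ToyVariableValue":
    return ToyVariableValue(self.label, self.value)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ToyVariableValue:
  """Hypothetical frozen form: an immutable PyTree that can cross jax.jit."""

  label: str
  value: Any

  def unfreeze_as_copy(self) -> ToyVariable:
    return ToyVariable(self.label, self.value)

  def tree_flatten(self):
    # The value is a PyTree child; the label is static metadata.
    return (self.value,), self.label

  @classmethod
  def tree_unflatten(cls, label, children):
    return cls(label, children[0])


@jax.jit
def increment(frozen: ToyVariableValue) -> ToyVariableValue:
  # Pure function: builds a new frozen value instead of mutating anything.
  return ToyVariableValue(frozen.label, frozen.value + 1)


counter = ToyVariable("counter", jnp.array(0))
new_frozen = increment(counter.freeze())
counter = new_frozen.unfreeze_as_copy()  # A fresh mutable copy.
```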
For full control over JIT compilation, you can manually convert variables from their mutable form to their immutable form when crossing JAX transform boundaries. The relevant functions:
pz.unbind_variables (and the type-specific variants pz.unbind_params and pz.unbind_state_vars): Extracts and deduplicates variables, returning a tree of variable slots along with the deduplicated variables.
pz.bind_variables: Re-inserts variables into variable slots.
Parameter.freeze() and StateVariable.freeze(): Converts a mutable variable into an immutable value.
ParameterValue.unfreeze_as_copy() and StateVariableValue.unfreeze_as_copy(): Converts an immutable value back into a (new) mutable variable.
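The slot-based unbind/bind round trip can be sketched for nested Python containers as follows. ToyVar, ToySlot, toy_unbind, and toy_bind are hypothetical names, and this is not Penzai’s implementation; it only shows the idea of replacing each variable with a labeled placeholder, deduplicating variables that appear more than once, and re-inserting them later.

```python
import dataclasses
from typing import Any

import jax


@dataclasses.dataclass(eq=False)
class ToyVar:
  """Hypothetical mutable variable, compared by identity."""

  label: str
  value: Any


@dataclasses.dataclass(frozen=True)
class ToySlot:
  """Hypothetical placeholder recording which variable belongs here."""

  label: str


def toy_unbind(tree):
  """Swaps each ToyVar for a ToySlot; returns (tree_with_slots, unique_vars)."""
  unique = {}

  def replace(leaf):
    if isinstance(leaf, ToyVar):
      unique.setdefault(id(leaf), leaf)  # Deduplicate by object identity.
      return ToySlot(leaf.label)
    return leaf

  tree_with_slots = jax.tree_util.tree_map(
      replace, tree, is_leaf=lambda x: isinstance(x, ToyVar)
  )
  return tree_with_slots, list(unique.values())


def toy_bind(tree_with_slots, variables):
  """Inverse of toy_unbind: swaps each ToySlot for the variable with its label."""
  by_label = {var.label: var for var in variables}
  return jax.tree_util.tree_map(
      lambda leaf: by_label[leaf.label] if isinstance(leaf, ToySlot) else leaf,
      tree_with_slots,
  )


shared = ToyVar("shared_weight", 1.0)
toy_model = {"encoder": [shared, 3], "decoder": {"w": shared}}
model_with_slots, variables = toy_unbind(toy_model)
rebound = toy_bind(model_with_slots, variables)
```

Because the same ToyVar object appears twice in the example tree, toy_unbind returns it only once, and toy_bind puts that single object back in both places, so shared variables stay shared.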
For instance, for our example model above, we can use pz.unbind_variables and .freeze() to extract the mutable parts:
model_with_slots, all_vars = pz.unbind_variables(model)
pz.show("model_with_slots:", model_with_slots)
pz.show("all_vars:", all_vars)
frozen_vars = [var.freeze() for var in all_vars]
pz.show("frozen_vars:", frozen_vars)
We can then define a pure function that re-binds these variables, and call it under jax.jit:
@jax.jit
def rebinding_call(model_with_slots, frozen_vars, arg):
  # Make temporary mutable copies:
  new_vars = [var.unfreeze_as_copy() for var in frozen_vars]
  # Re-attach them to the model:
  model = pz.bind_variables(model_with_slots, new_vars)
  # Run it:
  result = model(arg)
  # Extract and re-freeze the variables:
  refrozen_vars = [var.freeze() for var in new_vars]
  return result, refrozen_vars
result, new_frozen_vars = rebinding_call(
    model_with_slots, frozen_vars, pz.nx.ones({"features": 8})
)
pz.show("result:", result)
pz.show("new_frozen_vars:", new_frozen_vars)
We can then update the old variables with their new values:
for var, new_value in zip(all_vars, new_frozen_vars):
  var.update(new_value)
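Stripped of the Penzai specifics, rebinding_call follows the usual JAX recipe for stateful code: pass the current state into a pure jitted function, return the new state alongside the result, and write it back into mutable storage outside of the compiled code, as var.update does above. A plain-JAX sketch, where Counter and pure_step are hypothetical names:

```python
import jax
import jax.numpy as jnp


class Counter:
  """Hypothetical mutable holder, playing the role of a state variable."""

  def __init__(self, value):
    self.value = value


@jax.jit
def pure_step(count, x):
  # Pure function: the state comes in as an argument and goes out as a result.
  new_count = count + 1
  return (x, new_count), new_count


counter = Counter(jnp.array(0))
x = jnp.ones(8)
for _ in range(3):
  result, new_value = pure_step(counter.value, x)
  counter.value = new_value  # Write back outside of jit.
```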
To make this a bit less verbose, pz.nn.Layer provides a method .stateless_call(vars, ...) that, like rebinding_call, makes temporary mutable copies of the given frozen variables, runs the layer, and returns the result together with the updated frozen variables. So we could have equivalently written the following:
@jax.jit
def rebinding_call_2(model_with_slots, frozen_vars, arg):
  result, refrozen_vars = model_with_slots.stateless_call(frozen_vars, arg)
  return result, refrozen_vars
result, new_frozen_vars = rebinding_call_2(
    model_with_slots, frozen_vars, pz.nx.ones({"features": 8})
)
pz.show("result:", result)
pz.show("new_frozen_vars:", new_frozen_vars)
If you want to JIT-compile your model initializer, you can do this using the functional API:
@jax.jit
def functional_init(init_base_rng):
  model = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="linear",
          init_base_rng=init_base_rng,
          input_axes={"features": 8},
          output_axes={"features_out": 8},
      ),
      CounterLayer(counter=pz.StateVariable(value=0, label="counter")),
  ])
  # Unbind and also freeze all variables:
  return pz.unbind_variables(model, freeze=True)
model_with_slots, init_var_values = functional_init(jax.random.PRNGKey(0))
# Re-bind variables and also make them mutable again:
model = pz.bind_variables(
    model_with_slots, init_var_values, unfreeze_as_copy=True
)
model
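Roughly speaking, this works because functional_init returns only variable slots and frozen values rather than mutable variables, so its output is something jax.jit can return. The same shape of computation in plain JAX looks like this, with a hypothetical dictionary of arrays standing in for the model’s variables (init_state is an illustrative name):

```python
import jax
import jax.numpy as jnp


@jax.jit
def init_state(rng):
  # Hypothetical initializer: returns a PyTree of arrays (no mutable objects),
  # so jax.jit can return it directly.
  return {
      "linear": {"weights": jax.random.normal(rng, (8, 8)) / jnp.sqrt(8.0)},
      "counter": jnp.array(0),
  }


state = init_state(jax.random.PRNGKey(0))
state_again = init_state(jax.random.PRNGKey(0))  # Same key, same values.
```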
pz.variable_jit
If you don’t want to use the functional API directly, you can instead use pz.variable_jit, which is a wrapper around jax.jit that allows the function arguments to contain pz.Parameter and pz.StateVariable in addition to arrays, and handles updating their values for you. For instance, you could write:
@pz.variable_jit
def jitted_call(model, arg):
  return model(arg)
jitted_call(model, pz.nx.ones({"features": 8}))
jitted_call(model, pz.nx.ones({"features": 8}))
Note that pz.variable_jit does not support returning variables from the jitted computation, so it can’t be used to JIT-compile model initialization. It also does not support “closing over” global references to variable objects defined outside of the function. Every variable used by the function inside pz.variable_jit must have been passed in as an input argument.
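The restriction on closing over variables is in the same spirit as a general property of jax.jit: values captured from the enclosing scope are read once, when the function is traced, and then baked into the compiled computation, so later changes to them are silently ignored. A plain-JAX illustration, not specific to Penzai (scale and scaled are hypothetical names):

```python
import jax
import jax.numpy as jnp

scale = 2.0  # A global Python value that the jitted function closes over.


@jax.jit
def scaled(x):
  return x * scale  # `scale` is read once, at trace time.


first = scaled(jnp.array(1.0))
scale = 10.0  # Changing the global does not trigger a retrace...
second = scaled(jnp.array(1.0))  # ...so the cached computation still uses 2.0.
```

Passing the value in as an argument avoids this problem, which is consistent with the rule that every variable must be an explicit input.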
jit_wrapper.Jitted
pz.variable_jit works for top-level functions, but sometimes you may want to JIT-compile a specific part of a Penzai model, or compile the forward pass without having to use an indirect jitted_call function. For this purpose, Penzai provides a layer wrapper Jitted in penzai.experimental.v2.toolshed.jit_wrapper, which JIT-compiles its forward pass when called.
To use it, you can simply wrap your model in jit_wrapper.Jitted and then call it as normal:
from penzai.experimental.v2.toolshed import jit_wrapper
jit_model = jit_wrapper.Jitted(model)
jit_model
jit_model(pz.nx.ones({"features": 8}))
You can also insert Jitted around any sublayer of the model, for example:
jit_model_2 = (
    pz.select(model)
    .at_instances_of(pz.nn.Linear | CounterLayer)
    .apply(jit_wrapper.Jitted)
)
jit_model_2
jit_model_2(pz.nx.ones({"features": 8}))
Note that the Jitted wrapper is just an ordinary Penzai layer, and you can still pull back out the original model:
jit_model.body == model
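Conceptually, a wrapper like Jitted only needs to hold on to the wrapped layer and route calls through a compiled version of it. The sketch below, using the hypothetical class JittedSketch, shows that idea for plain functions of arrays; unlike Penzai’s Jitted, it does not use pz.variable_jit and knows nothing about pz.Parameter or pz.StateVariable.

```python
import jax
import jax.numpy as jnp


class JittedSketch:
  """Hypothetical wrapper that runs its wrapped callable under jax.jit."""

  def __init__(self, body):
    self.body = body  # The original callable stays accessible.
    self._compiled = jax.jit(body)

  def __call__(self, x):
    return self._compiled(x)


def double(x):
  return 2 * x


wrapped = JittedSketch(double)
out = wrapped(jnp.arange(3.0))
```

Keeping the original object on an attribute (here body) is what makes the wrapper easy to remove again later.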