Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


Jitting and Sharding Penzai Models#

Penzai is designed to be compatible with JAX’s standard function transformations, including JIT-compilation and array sharding. If you’re already familiar with JIT compilation and distributed arrays in JAX, you shouldn’t have to learn anything fundamentally new to apply it to Penzai! But Penzai does provide some utilities to make it easier to construct and manipulate shardings for Penzai models.

This notebook walks through some of the common aspects of JIT-compilation and sharding as they apply to Penzai tools and Penzai models. It assumes some basic familiarity with JAX’s JIT compilation and distributed array systems.


Before we can get started in earnest, we need to set up the environment.


To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp
import optax
import penzai
from penzai import pz
from penzai.example_models import gemma
from penzai.example_models import simple_mlp
from penzai.toolshed import basic_training

Setting up Penzai#

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment. We’ll also enable automatic array visualization, which also makes it easy to visualize array shardings.


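We’ll assume this notebook is running on a backend with eight devices. If needed, you can force JAX to treat the CPU backend as eight devices by setting an XLA flag. The flag must be set before JAX initializes its backends, so it is safest to set it before importing jax:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

We can then check that eight devices are available:

assert jax.local_device_count() == 8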

JIT-Compiling Penzai Models#

By convention, Penzai models are always JAX PyTrees with only arraylike leaves. This means you can always JIT-compile a function that takes a Penzai model as input or returns one as output, using ordinary jax.jit.

For instance, suppose we have the following Penzai model definition:

mlp_def = simple_mlp.MLP.from_config([8, 32, 32, 8])

We could JIT-compile the initializer for it:

@jax.jit
def init_my_model(mlp_def):
  return pz.nn.initialize_parameters(mlp_def, jax.random.key(0))

mlp = init_my_model(mlp_def)

And we can just as easily JIT-compile a loss function that uses it:

# Just for demonstration; a real loss function would probably be more complex
# and involve a batch of examples

@jax.jit
def simple_mse_loss(mlp, inputs, target):
  output = mlp(inputs)
  diffs = (output - target).untag("features").unwrap()
  return jnp.sum(jnp.square(diffs))

simple_mse_loss(mlp, pz.nx.ones({"features": 8}), pz.nx.zeros({"features": 8}))

Note that Penzai models store their parameters inside the model, not in a separate parameter dictionary. This means you probably don’t want to do this:

jitted_call = jax.jit(mlp)

The reason is that this will “bake in” the parameters of your MLP as constants in the compiled function. You would then need to create and compile a new jitted function every time you update the parameters of the MLP.

Instead, you can do something like this:

@jax.jit
def jitted_call(mlp, arg):
  return mlp(arg)

jitted_call(mlp, pz.nx.ones({"features": 8}))
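The following plain-JAX sketch makes the difference concrete. No Penzai is involved. It counts how many times each style gets traced while the parameter values change over three steps. The toy `params` dictionary stands in for a model's parameters. `make_closure_fn` and `apply_with_args` are illustrative names, not Penzai APIs.

```python
import jax
import jax.numpy as jnp

# Python side effects only run while JAX traces a function, so these counters
# record how many times each function body was (re)traced.
trace_counts = {"closure": 0, "argument": 0}


def make_closure_fn(params):
  # `params` is captured from the enclosing scope, so it becomes a constant of
  # the compiled program; new parameter values require a new jitted function.
  @jax.jit
  def apply(x):
    trace_counts["closure"] += 1
    return params["w"] * x + params["b"]

  return apply


@jax.jit
def apply_with_args(params, x):
  # `params` is an ordinary (non-static) PyTree argument, so only its
  # structure, shapes and dtypes are part of the compilation cache key.
  trace_counts["argument"] += 1
  return params["w"] * x + params["b"]


x = jnp.ones(4)
for step in range(3):
  params = {"w": jnp.full(4, float(step)), "b": jnp.zeros(4)}
  make_closure_fn(params)(x)
  apply_with_args(params, x)

print(trace_counts)  # {'closure': 3, 'argument': 1}
```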

To save you the trouble of doing this manually when you want to JIT your model’s __call__, Penzai provides a wrapper that does this automatically:

from penzai.toolshed import jit_wrapper
jitted_mlp = jit_wrapper.Jitted(mlp)
jitted_mlp(pz.nx.ones({"features": 8}))

Jitted is actually just an ordinary Penzai layer. It holds your model inside it as an attribute, and jit-compiles __call__:

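@pz.pytree_dataclass
class Jitted(pz.Layer):
  body: pz.LayerLike

  def __call__(self, argument: Any, /) -> Any:
    # `jitted_call` is the jax.jit-compiled helper defined above, so the wrapped
    # model is passed to it as an ordinary (non-static) argument.
    return jitted_call(self.body, argument)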

You can see the model stored inside it as well:

jitted_mlp


Its __call__ is JIT-compiled, with the layer itself (and therefore the wrapped model) passed as a non-static argument. This means that JAX will automatically reuse the cached compiled program if you call multiple Jitted layers with the same structure, even if you update the parameters.
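The same caching behavior can be sketched in plain JAX. `ScaleLayer` below is a hypothetical stand-in for a Penzai layer, not Penzai's Jitted implementation. It is registered as a PyTree, and its `__call__` is wrapped in `jax.jit`, so the layer itself is traced as an ordinary argument. Changing a parameter's value reuses the compiled program, while changing its shape triggers a retrace.

```python
import jax
import jax.numpy as jnp

trace_count = [0]  # incremented only when JAX (re)traces __call__


@jax.tree_util.register_pytree_node_class
class ScaleLayer:
  """Hypothetical toy layer with a single array attribute."""

  def __init__(self, scale):
    self.scale = scale

  def tree_flatten(self):
    # The array is the only PyTree leaf; there is no static metadata.
    return (self.scale,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

  @jax.jit
  def __call__(self, x):
    # `self` is flattened like any other PyTree argument, so its array values
    # are runtime inputs rather than compile-time constants.
    trace_count[0] += 1
    return self.scale * x


x = jnp.arange(3.0)
y1 = ScaleLayer(jnp.array(2.0))(x)  # first call: traces and compiles
y2 = ScaleLayer(jnp.array(5.0))(x)  # new value, same structure: cache hit
y3 = ScaleLayer(jnp.ones(3))(x)     # different parameter shape: retraces
print(trace_count[0])  # 2
```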

It will also automatically recompile if you change the structure of the wrapped model. For instance, we can freely insert new logic into our “jitted MLP”, and those new functions will run under JIT as well:

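@pz.pytree_dataclass
class PrintMyValue(pz.Layer):
  def __call__(self, arg):
    # Under jax.jit this runs while tracing, so it prints a traced value.
    print("Intermediate:", arg)
    return arg

# Wrap the original model so that PrintMyValue runs before it.
patched_jitted_mlp = dataclasses.replace(
    jitted_mlp,
    body=pz.nn.Sequential([PrintMyValue(), jitted_mlp.body]),
)
patched_jitted_mlp(pz.nx.ones({"features": 8}))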

And you can always pull the model back out of the Jitted wrapper:

assert jitted_mlp.body is mlp

Sharding Basics, and Visualizing Shardings with Treescope#

Penzai’s array autovisualizer supports showing shardings and sharded arrays by default. This section explains the basics of JAX’s distributed array shardings and how you can visualize the different components in Treescope. (See this page for the official documentation of JAX’s sharding system.)

Positional shardings#

At a high level, you can think of a “sharding” as a multidimensional array of device objects, which will be matched with your multidimensional array of data to determine which part of the array ends up on each device. You generally build a sharding by starting with a NumPy array of devices:

from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))

A simple type of sharding is PositionalSharding, which essentially just holds onto these devices and tracks some extra JAX-specific information. If you print out a PositionalSharding in Treescope, it color-codes the devices and shows you their arrangement:

pos_sharding = jax.sharding.PositionalSharding(devices)

In this case, the sharding has a single positional axis, of length 8. We can use this to shard arrays whose (first) positional axis is a multiple of 8. For instance:

jax.device_put(jnp.ones(16), pos_sharding)

You can click the “Sharded across 8 TPU devices” message to show a visualization of the sharding for this array. When automatic array visualization is enabled, sharding visualizations are automatically added to any array that is sharded or replicated.
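You can also inspect the same information programmatically. Each sharded `jax.Array` exposes its per-device pieces through `addressable_shards`.

This sketch makes two assumptions:

- It uses a `NamedSharding` over a one-axis `Mesh` instead of `PositionalSharding`, which newer JAX versions have deprecated.
- It sizes the array from `jax.device_count()`, so it runs on any number of devices rather than assuming eight.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

# Shard a vector of length 2 * n along its only axis: each device gets 2 elements.
x = jax.device_put(jnp.arange(2.0 * n), NamedSharding(mesh, PartitionSpec("devices")))

for shard in x.addressable_shards:
  # `shard.index` is the tuple of slices of the global array held on `shard.device`.
  print(shard.device, shard.index, shard.data.shape)

shard_shapes = [shard.data.shape for shard in x.addressable_shards]
```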

We can reshape positional shardings to give them multiple axes:

jax.device_put(jnp.ones([8, 8]), pos_sharding.reshape((4,2)))

If you expand the sharding visualization above, you’ll see how the two axes of the array are matched with the two axes of the sharding.

You can also use shardings to indicate that certain parts of the array should be replicated on multiple devices, using replicate:

pos_sharding.reshape((2, 4)).replicate(axis=0)
jax.device_put(jnp.ones([8, 8]), pos_sharding.reshape((2, 4)).replicate(axis=0))

Each element of an array with a replicated sharding will appear on more than one device. This is visually represented in Treescope using a multicolored pattern.

You can also fully replicate the array over all of the devices:

jax.device_put(jnp.ones([8, 8]), pos_sharding.replicate(axis=0))

Fully replicated arrays are also identified as such in the sharding summary, even before it is expanded.
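You can also check replication programmatically. This sketch uses `NamedSharding` rather than `PositionalSharding`. An empty `PartitionSpec()` leaves every axis unsharded, so each device holds a full copy of the array.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

# PartitionSpec() shards no axes, so the array is replicated on every device.
replicated = jax.device_put(jnp.ones((8, 8)), NamedSharding(mesh, PartitionSpec()))

print(replicated.sharding.is_fully_replicated)  # True
print({shard.data.shape for shard in replicated.addressable_shards})  # {(8, 8)}
```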

Meshes and named shardings#

It is often convenient to refer to different axes of an array of devices by name instead of by position. JAX represents this using the type jax.sharding.Mesh. Conceptually, just as a PositionalSharding is essentially a positional array of devices, a Mesh is essentially a named array of devices, i.e. an array of devices where each axis has a name.

Penzai annotates the device ID arrays of Mesh instances with axis names instead of axis positions:

mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('foo', 'bar'))

To shard a (positionally-indexed) JAX array using a mesh, you can use jax.sharding.NamedSharding to assign particular axis indices to mesh axis names, like this:

jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo', 'bar'))
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, ('bar', 'foo'), None))
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo'))

Note: Each NamedSharding specifies how to shard an input array’s positional axes, since ordinary JAX arrays only have positional axes. The names in the NamedSharding are just a way to match the positional axes in the array with the corresponding names in the Mesh. For this reason, visualizations of NamedSharding instances are annotated with positional axes, not axis names.

(Penzai already has its own mechanism for binding names to an array’s positional axes: pz.nx.NamedArray. We’ll discuss how to shard Penzai’s NamedArray next.)
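You can see how the names in a `PartitionSpec` line up with a `Mesh` by asking a sharding for the shape of each shard with `shard_shape`. So that the sketch runs anywhere, it builds a two-axis mesh of shape `(n, 1)` from however many devices are available, instead of the `(4, 2)` mesh above.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()).reshape((n, 1)), axis_names=("foo", "bar"))
mesh_sizes = dict(mesh.shape)  # maps each mesh axis name to its size

# Positional axis 0 is split over "foo" and positional axis 1 over "bar".
s1 = NamedSharding(mesh, PartitionSpec("foo", "bar"))
shape1 = s1.shard_shape((2 * n, 4))

# Positional axis 0 is unsharded; axis 1 is split over both mesh axes at once.
s2 = NamedSharding(mesh, PartitionSpec(None, ("bar", "foo")))
shape2 = s2.shard_shape((2 * n, 4 * n))

print(mesh_sizes, shape1, shape2)
```

Note that the names only select mesh axes. Which array axis gets sharded is decided purely by position in the `PartitionSpec`.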

Sharding Penzai’s NamedArrays#

Manually sharding NamedArrays#

Fundamentally, there are no changes when applying JAX shardings to Penzai’s NamedArrays. Internally, a NamedArray is just a dataclass PyTree node that contains a JAX array and some axis name annotations, which we can see if we disable automatic array visualization temporarily:

arr = pz.nx.arange("foo", 1, 4) + pz.nx.arange("bar", 0, 4)
# With automatic array visualization enabled:
arr

%%autovisualize None
# ^ With automatic array visualization disabled (and expanding it to show detail)
arr

JAX’s sharding system allows you to specify the sharding for a PyTree of arrays by using a matching PyTree of shardings. So, we can build a sharding for this named array by inserting a positional sharding into it:

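data_array_sharding = jax.sharding.PositionalSharding(devices).reshape((2,4)).replicate(axis=0)
# Build a tree with the same structure as `arr`, with the sharding in place of
# the data array.
sharding_for_arr = dataclasses.replace(arr, data_array=data_array_sharding)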

Applying this sharding to the NamedArray shards the data_array attribute (try expanding below):

%%autovisualize lambda a,p: pz.ts.ArrayAutovisualizer()(a, p) if isinstance(a, jax.Array) else None
# (^ this line overrides the autovisualizer to show the sharding of the data array when expanded)

sharded_arr = jax.device_put(arr, sharding_for_arr)
sharded_arr

But with normal automatic array visualization, Treescope will show you how the named axes are sharded, since that’s usually what you care about when using Penzai models in practice:

sharded_arr
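This relies only on JAX's ordinary PyTree matching, so the same pattern works for any PyTree node, not just NamedArray. In the plain-JAX sketch below, `ToyNamedArray` is a hypothetical stand-in, not Penzai's class. It stores its axis names as static metadata and its array as the only leaf. Replacing that leaf with a sharding yields a matching tree of shardings for `jax.device_put`.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ToyNamedArray:
  """Hypothetical stand-in for pz.nx.NamedArray (not Penzai's implementation)."""

  axis_names: tuple  # static metadata: one name per positional axis
  data_array: Any  # a jax.Array, or a Sharding when used in a sharding tree

  def tree_flatten(self):
    # The array is the only leaf; the axis names become part of the tree structure.
    return (self.data_array,), self.axis_names

  @classmethod
  def tree_unflatten(cls, axis_names, children):
    return cls(axis_names, *children)


n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
toy_arr = ToyNamedArray(("foo", "bar"), jnp.zeros((3, 2 * n)))

# Same structure, with the array leaf swapped for a sharding that leaves
# positional axis 0 ("foo") unsharded and splits positional axis 1 ("bar").
toy_sharding = dataclasses.replace(
    toy_arr, data_array=NamedSharding(mesh, PartitionSpec(None, "devices"))
)
toy_sharded = jax.device_put(toy_arr, toy_sharding)
```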


Automatically building shardings for NamedArrays#

To simplify this process, Penzai provides some optional utilities for constructing shardings for NamedArray instances. These utilities take a Mesh, and allow you to map from NamedArray axis names to Mesh axis names across a tree of arrays.

For instance, consider this tree of arrays:

some_array_tree = {
    "one": pz.nx.ones({"a": 4, "b": 8, "c": 6}),
    "two": pz.nx.ones({"a": 8}),
    "three": pz.nx.ones({"b": 4, "d": 12}),
}

And this mesh:

mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('foo', 'bar'))

We can assign each named axis in some_array_tree to an axis in the mesh using the name_to_name_sharding utility, which builds a tree of shardings that is compatible with the tree of arrays:

from penzai.toolshed import sharding_util
shardings = sharding_util.name_to_name_sharding(
    some_array_tree,
    mesh,
    axis_name_to_mesh_name={
        "a": "bar",
        "b": "foo",
    },
)

We can then apply those shardings to the original array tree to shard the corresponding axes:

jax.device_put(some_array_tree, shardings)

Even simpler, if you just want to call device_put, you can bundle both steps into one call:

sharding_util.name_to_name_device_put(
    some_array_tree,
    mesh,
    axis_name_to_mesh_name={
        "a": "bar",
        "b": "foo",
    },
)

If your mesh happens to use the exact same axis names as your arrays, you don’t need the axis_name_to_mesh_name argument:

already_matching_mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('b', 'a'))
sharding_util.name_to_name_device_put(
    some_array_tree,
    already_matching_mesh,
    # axis_name_to_mesh_name inferred as {"a":"a", "b":"b"}
)
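Conceptually, mapping array axis names to mesh axis names only requires looking up each name and building a `PartitionSpec` for the array's positional axes. The sketch below is a much-simplified helper for a plain tuple of axis names. `spec_from_axis_names` is hypothetical: it is not a Penzai function, and this is not how `sharding_util` is implemented. Unmapped axes stay unsharded, and an identity mapping over the mesh axes is used when no mapping is given.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def spec_from_axis_names(axis_names, mesh, axis_name_to_mesh_name=None):
  """Hypothetical helper: one PartitionSpec entry per positional axis."""
  if axis_name_to_mesh_name is None:
    # Assume each mesh axis corresponds to the array axis with the same name.
    axis_name_to_mesh_name = {name: name for name in mesh.axis_names}
  # Axes missing from the mapping become None, i.e. unsharded.
  return PartitionSpec(*(axis_name_to_mesh_name.get(name) for name in axis_names))


n = jax.device_count()
mesh = Mesh(np.array(jax.devices()).reshape((n, 1)), axis_names=("foo", "bar"))

# Axes named like the "one" array above, with "b" sized to divide evenly by n.
spec = spec_from_axis_names(("a", "b", "c"), mesh, {"a": "bar", "b": "foo"})
sharding = NamedSharding(mesh, spec)  # spec is PartitionSpec("bar", "foo", None)
print(spec, sharding.shard_shape((4, 2 * n, 6)))

# With a mesh whose axis names already match, no mapping is needed.
matching_mesh = Mesh(np.array(jax.devices()).reshape((n, 1)), axis_names=("b", "a"))
inferred_spec = spec_from_axis_names(("a", "b", "c"), matching_mesh)
```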

Sharding Penzai Models and Training Loops#

Penzai also provides some utilities that are specific to training and using Penzai neural network models. These are simple self-contained utilities that can be a good starting point, but you are free to customize them to get lower-level control when needed.

Sharding Parameter Initializers#

Uninitialized Penzai models directly expose all of the parameter initializers to you as attributes inside your model. If you want to customize the sharding of your parameters, you can JIT-compile the initializer with the appropriate sharding, e.g.

sharded_initializer = jax.jit(
    pz.nn.initialize_parameters,
    out_shardings=...,  # <- insert your desired sharding specification here
)
params = sharded_initializer(mlp_def, jax.random.key(42))
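Here is a plain-JAX illustration of the `out_shardings` argument. The hypothetical toy initializer `init_params` stands in for `pz.nn.initialize_parameters`. The output shardings form a PyTree matching the initializer's output: the weight matrix is split along its second axis, and the bias is replicated. The mesh is built from however many devices are available.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))


def init_params(key):
  # Hypothetical toy initializer returning a dict of parameters.
  w_key, _ = jax.random.split(key)
  return {
      "w": jax.random.normal(w_key, (8, 4 * n)),
      "b": jnp.zeros((4 * n,)),
  }


# One sharding per output leaf, in a PyTree with the same structure.
out_shardings = {
    "w": NamedSharding(mesh, PartitionSpec(None, "devices")),  # split columns
    "b": NamedSharding(mesh, PartitionSpec()),  # replicate everywhere
}

sharded_init = jax.jit(init_params, out_shardings=out_shardings)
toy_params = sharded_init(jax.random.key(42))
```

The arrays come out of the compiled function already laid out as requested, so no separate `device_put` is needed afterwards.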

If you want to infer out_shardings using the axis names of your parameters, you can do that using the helper function initialize_parameters_sharded. This function just traces the initializer to figure out the parameter shapes, infers the right sharding to use, and then runs your initializer accordingly.
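The trace, infer, then run strategy can be sketched in plain JAX with `jax.eval_shape`, which computes output shapes and dtypes without running the computation. Plain arrays carry no axis names, so this sketch substitutes a hypothetical shape-based rule (`infer_sharding`): shard the last axis when it divides evenly, otherwise replicate. It illustrates the strategy only and is not `initialize_parameters_sharded` itself. `init_params` is again a hypothetical toy initializer.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))


def init_params(key):
  # Hypothetical toy initializer with one matrix and one scalar parameter.
  w_key, _ = jax.random.split(key)
  return {
      "w": jax.random.normal(w_key, (8, 4 * n)),
      "scale": jnp.ones(()),
  }


def infer_sharding(shape_struct):
  # Hypothetical rule: shard the last axis if it divides evenly, else replicate.
  shape = shape_struct.shape
  if shape and shape[-1] % n == 0:
    spec = PartitionSpec(*([None] * (len(shape) - 1)), "devices")
  else:
    spec = PartitionSpec()
  return NamedSharding(mesh, spec)


key = jax.random.key(0)
# Step 1: trace only, producing jax.ShapeDtypeStruct leaves (no real computation).
abstract_params = jax.eval_shape(init_params, key)
# Step 2: infer a sharding for each leaf.
out_shardings = jax.tree_util.tree_map(infer_sharding, abstract_params)
# Step 3: run the initializer, producing arrays with those shardings.
toy_params = jax.jit(init_params, out_shardings=out_shardings)(key)
```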

For instance, here’s how you could initialize the parameters of a small transformer in a sharded way:

# Using the Gemma model architecture, but very small for demonstration purposes.
tiny_transformer_def = gemma.model_core.GemmaTransformer.from_config(
tiny_transformer = sharding_util.initialize_parameters_sharded(
    mesh=jax.sharding.Mesh(devices, axis_names=('devices',)),
        # Shard the embedding dimension across devices.
        "embedding": "devices",