Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Jitting and Sharding Penzai Models

Penzai is designed to be compatible with JAX’s standard function transformations, including JIT-compilation and array sharding. If you’re already familiar with JIT compilation and distributed arrays in JAX, you shouldn’t have to learn anything fundamentally new to apply it to Penzai! But Penzai does provide some utilities to make it easier to construct and manipulate shardings for Penzai models.

This notebook walks through some of the common aspects of JIT-compilation and sharding as they apply to Penzai tools and Penzai models. It assumes some basic familiarity with JAX’s JIT compilation and distributed array systems.

Setup

Before we can get started in earnest, we need to set up the environment.

Imports

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import dataclasses

import jax
import jax.numpy as jnp
import optax
import penzai
from penzai import pz
from penzai.example_models import gemma
from penzai.example_models import simple_mlp
from penzai.toolshed import basic_training

Setting up Penzai

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment. We’ll also enable automatic array visualization, which makes it easy to inspect array shardings.

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

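We’ll assume this notebook is running on a backend with eight devices. If needed, you can force JAX to treat the CPU backend as multiple devices by setting the following flag before JAX initializes its backends (that is, before running any JAX computation):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"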
pz.show(jax.local_devices())
assert jax.local_device_count() == 8

JIT-Compiling Penzai Models

By convention, Penzai models are always JAX PyTrees with only arraylike leaves. This means you can always JIT-compile a function that takes a Penzai model as input or returns one as output, using ordinary jax.jit.

For instance, suppose we have the following Penzai model definition:

mlp_def = simple_mlp.MLP.from_config([8, 32, 32, 8])
mlp_def

We could JIT-compile the initializer for it:

@jax.jit
def init_my_model(mlp_def):
  return pz.nn.initialize_parameters(mlp_def, jax.random.key(0))
mlp = init_my_model(mlp_def)
mlp

And we can just as easily JIT-compile a loss function that uses it:

# Just for demonstration; a real loss function would probably be more complex
# and involve a batch of examples

@jax.jit
def simple_mse_loss(mlp, inputs, target):
  output = mlp(inputs)
  diffs = (output - target).untag("features").unwrap()
  return jnp.sum(jnp.square(diffs))
simple_mse_loss(mlp, pz.nx.ones({"features": 8}), pz.nx.zeros({"features": 8}))
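The same convention can be illustrated without Penzai at all. In the minimal sketch below, `TinyLinear` is a hypothetical model class (not part of Penzai) registered as a PyTree with `jax.tree_util.register_pytree_node_class`. Because its only leaves are arrays, `jax.jit` can return it from an initializer and also accept it as an ordinary argument.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class TinyLinear:
  """Hypothetical model whose parameters live inside the object."""

  def __init__(self, weight, bias):
    self.weight = weight
    self.bias = bias

  def __call__(self, x):
    return x @ self.weight + self.bias

  # PyTree protocol: both arrays are children; there is no static metadata.
  def tree_flatten(self):
    return (self.weight, self.bias), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


@jax.jit
def init_tiny_linear(key):
  # Returning a PyTree of arrays from a jitted function is fine.
  return TinyLinear(jax.random.normal(key, (8, 4)), jnp.zeros((4,)))


@jax.jit
def mse_loss(model, inputs, target):
  # Taking the model as a regular (non-static) argument is fine too.
  return jnp.mean(jnp.square(model(inputs) - target))


model = init_tiny_linear(jax.random.key(0))
loss = mse_loss(model, jnp.ones((8,)), jnp.zeros((4,)))
```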

Note that Penzai models store their parameters inside the model, not in a separate parameter dictionary. This means you probably don’t want to do this:

# !!!! PROBABLY NOT WHAT YOU WANT TO DO:
jitted_call = jax.jit(mlp)
jitted_call(some_input)

The reason is that this will “bake in” the parameters of your MLP as constants in the compiled function, so JAX will need to recompile it if you update the parameters of the MLP.

Instead, you can do something like this:

@jax.jit
def jitted_call(mlp, arg):
  return mlp(arg)
jitted_call(mlp, pz.nx.ones({"features": 8}))
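You can observe the difference between the two approaches by counting how often JAX traces each function. This is a plain-JAX sketch with hypothetical names (`apply`, `make_closure_fn`, `call_with_params`). A Python side effect inside a jitted function runs only during tracing, so the counter records retraces. Closing over the parameters creates a new compiled function for every update. Passing them as an argument reuses one cached program, as long as shapes and dtypes stay the same.

```python
import jax
import jax.numpy as jnp

trace_count = {"closure": 0, "argument": 0}


def apply(params, x):
  return x * params["scale"] + params["shift"]


def make_closure_fn(params):
  # Anti-pattern: params are captured as constants of the compiled function.
  @jax.jit
  def call(x):
    trace_count["closure"] += 1  # Python side effect: runs only while tracing.
    return apply(params, x)
  return call


@jax.jit
def call_with_params(params, x):
  trace_count["argument"] += 1  # Runs only while tracing.
  return apply(params, x)


x = jnp.arange(4.0)
for step in range(3):
  # Simulate a parameter update: new values, same shapes and dtypes.
  params = {"scale": jnp.float32(step + 1), "shift": jnp.float32(step)}
  make_closure_fn(params)(x)
  call_with_params(params, x)
```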

To save you the trouble of doing this manually when you want to JIT your model’s __call__, Penzai provides a wrapper that does this automatically:

from penzai.toolshed import jit_wrapper
jitted_mlp = jit_wrapper.Jitted(mlp)
jitted_mlp(pz.nx.ones({"features": 8}))

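Jitted is actually just an ordinary Penzai layer. It holds your model inside it as an attribute and JIT-compiles __call__. A simplified sketch of its definition:

from typing import Any

@pz.pytree_dataclass
class Jitted(pz.Layer):
  body: pz.LayerLike

  def __call__(self, argument: Any, /) -> Any:
    # Simplified: delegates to the `jitted_call` function defined above, which
    # takes the wrapped model as an ordinary (non-static) argument.
    return jitted_call(self.body, argument)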

You can see the model stored inside it as well:

jitted_mlp

It defines __call__ to be JIT-compiled, including itself as a non-static argument. This means that JAX will automatically reuse the cached compiled program if you call multiple Jitted layers with the same structure, even if you update the parameters.

It will also recompile if you modify the structure of the wrapped model. For instance, we can freely insert new logic into our “jitted MLP”, and those new layers will run under JIT as well:

@pz.pytree_dataclass
class PrintMyValue(pz.Layer):
  def __call__(self, arg):
    pz.show("Intermediate:", arg)
    return arg
patched_jitted_mlp = (
    pz.select(jitted_mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_before(PrintMyValue())
)
patched_jitted_mlp
patched_jitted_mlp(pz.nx.ones({"features": 8}))
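The caching behavior described above can be reproduced in plain JAX. The hypothetical `JittedSequence` wrapper below is only meant to mimic the idea behind Jitted; it is not Penzai’s implementation. It is a PyTree whose `__call__` passes itself to one module-level jitted function. Changing parameter values reuses the compiled program, while changing the structure (here, the number of layers) triggers a retrace.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.tree_util.register_pytree_node_class
class JittedSequence:
  """Hypothetical wrapper holding a tuple of per-layer scale factors."""

  def __init__(self, scales):
    self.scales = tuple(scales)

  def __call__(self, x):
    # Pass the wrapper itself as a non-static (PyTree) argument.
    return _jitted_apply(self, x)

  def tree_flatten(self):
    return self.scales, None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(children)


@jax.jit
def _jitted_apply(wrapper, x):
  global trace_count
  trace_count += 1  # Only runs while JAX traces a new structure.
  for scale in wrapper.scales:
    x = x * scale
  return x


x = jnp.ones(3)
JittedSequence([jnp.float32(2.0)])(x)                    # Traces (1 layer).
JittedSequence([jnp.float32(5.0)])(x)                    # Cache hit: new values only.
JittedSequence([jnp.float32(2.0), jnp.float32(3.0)])(x)  # Retraces (2 layers).
```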

And you can always pull the model back out of the Jitted wrapper:

jitted_mlp.body
assert jitted_mlp.body is mlp

Sharding Basics, and Visualizing Shardings with Treescope

Penzai’s array autovisualizer supports showing shardings and sharded arrays by default. This section explains the basics of JAX’s distributed array shardings and how you can visualize the different components in Treescope. (See the official JAX documentation for a full description of JAX’s sharding system.)

Positional shardings

At a high level, you can think of a “sharding” as a multidimensional array of device objects, which will be matched with your multidimensional array of data to determine which part of the array ends up on each device. You generally build a sharding by starting with a NumPy array of devices:

from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))
devices

A simple type of sharding is PositionalSharding, which essentially just holds onto these devices and tracks some extra JAX-specific information. If you print out a PositionalSharding in Treescope, it color-codes the devices and shows you their arrangement:

pos_sharding = jax.sharding.PositionalSharding(devices)
pos_sharding

In this case, the sharding has a single positional axis of length 8. We can use it to shard arrays whose first positional axis has a length divisible by 8. For instance:

jax.device_put(jnp.ones(16), pos_sharding)

You can click the “Sharded across 8 TPU devices” message (the device type shown depends on your backend) to show a visualization of the sharding for this array. When automatic array visualization is enabled, sharding visualizations are automatically added to any array that is sharded or replicated.
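To see concretely how a one-axis sharding splits the first axis, the sketch below inspects each device’s shard through `addressable_shards`. It expresses the same layout with `jax.sharding.NamedSharding` over a one-axis `Mesh`, a swapped-in equivalent, rather than with PositionalSharding. The mesh axis name `"devices"` is an arbitrary choice. The array size comes from `jax.device_count()`, so the sketch also runs on a single device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
sharding = NamedSharding(mesh, PartitionSpec("devices"))

# The first axis must be divisible by the number of devices on the mesh axis.
arr = jax.device_put(jnp.arange(2 * num_devices), sharding)

# Each device holds a contiguous block of 2 elements.
shard_shapes = {shard.device: shard.data.shape for shard in arr.addressable_shards}
```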

We can reshape positional shardings to give them multiple axes:

pos_sharding.reshape((4,2))
jax.device_put(jnp.ones([8, 8]), pos_sharding.reshape((4,2)))

If you expand the sharding visualization above, you’ll see how the two axes of the array are matched with the two axes of the sharding.

You can also use shardings to indicate that certain parts of the array should be replicated on multiple devices, using replicate:

pos_sharding.reshape((2, 4)).replicate(axis=0)
jax.device_put(jnp.ones([8, 8]), pos_sharding.reshape((2, 4)).replicate(axis=0))

Each element of an array with a replicated sharding will appear on more than one device. This is visually represented in Treescope using a multicolored pattern.

You can also fully replicate the array over all of the devices:

pos_sharding.replicate(axis=0)
jax.device_put(jnp.ones([8, 8]), pos_sharding.replicate(axis=0))

Fully replicated arrays are also identified as such in the sharding summary, even before it is expanded.
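Replication can also be checked programmatically. This sketch again uses `NamedSharding` instead of PositionalSharding. An empty `PartitionSpec()` partitions no axes, so every device stores the full array and the sharding’s `is_fully_replicated` property is true.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
replicated = NamedSharding(mesh, PartitionSpec())  # No axis is partitioned.

arr = jax.device_put(jnp.ones((8, 8)), replicated)

# Every device stores a full (8, 8) copy of the data.
shard_shapes = [shard.data.shape for shard in arr.addressable_shards]
fully_replicated = arr.sharding.is_fully_replicated
```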

Meshes and named shardings

It is often convenient to refer to different axes of an array of devices by name instead of by position. JAX represents this using the type jax.sharding.Mesh. Conceptually, just as a PositionalSharding is essentially a positional array of devices, a Mesh is essentially a named array of devices, i.e. an array of devices where each axis has a name.

When rendering Mesh instances, Treescope annotates their device ID arrays with axis names instead of axis positions:

mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('foo', 'bar'))
mesh

To shard a (positionally-indexed) JAX array using a mesh, you can use jax.sharding.NamedSharding to assign particular axis indices to mesh axis names, like this:

jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo', 'bar'))
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, ('bar', 'foo'), None))
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo'))

Note: Each NamedSharding specifies how to shard an input array’s positional axes, since ordinary JAX arrays only have positional axes. The names in the NamedSharding are just a way to match the positional axes in the array with the corresponding names in the Mesh. For this reason, visualizations of NamedSharding instances are annotated with positional axes, not axis names.
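The note above can be checked with `Sharding.shard_shape`, which reports the per-device block shape for a given global (positional) shape. This sketch assumes a 2D mesh with axes `'foo'` and `'bar'`, sized to the available devices: all of them along `'foo'` and one along `'bar'`. Each entry of the `PartitionSpec` applies to the positional axis at the same index. A tuple entry splits that positional axis over several mesh axes.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()).reshape((num_devices, 1)), axis_names=("foo", "bar"))
mesh_sizes = dict(mesh.shape)  # Maps each mesh axis name to its size.

# Positional axis 0 is split over 'foo'; positional axis 1 over 'bar'.
by_position = NamedSharding(mesh, PartitionSpec("foo", "bar"))
shape_a = by_position.shard_shape((4 * num_devices, 6))

# Positional axis 0 is not split; axis 1 is split over 'foo' and 'bar' together.
combined = NamedSharding(mesh, PartitionSpec(None, ("foo", "bar")))
shape_b = combined.shard_shape((6, 4 * num_devices))
```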

(Penzai already has its own mechanism for binding names to an array’s positional axes: pz.nx.NamedArray. We’ll discuss how to shard Penzai’s NamedArray next.)

Sharding Penzai’s NamedArrays

Manually sharding NamedArrays

Fundamentally, there are no changes when applying JAX shardings to Penzai’s NamedArrays. Internally, a NamedArray is just a dataclass PyTree node that contains a JAX array and some axis name annotations, which we can see if we disable automatic array visualization temporarily:

arr = pz.nx.arange("foo", 1, 4) + pz.nx.arange("bar", 0, 4)
# With automatic array visualization enabled:
arr
%%autovisualize None
# ^ With automatic array visualization disabled (and expanding it to show detail)
pz.select(arr).at_instances_of(jax.Array).show_value()

JAX’s sharding system allows you to specify the sharding for a PyTree of arrays by using a matching PyTree of shardings. So, we can build a sharding for this named array by inserting a positional sharding into it:

data_array_sharding = jax.sharding.PositionalSharding(devices).reshape((2,4)).replicate(axis=0)
sharding_for_arr = pz.nx.NamedArray(
    named_axes=arr.named_axes,
    data_array=data_array_sharding,
)
sharding_for_arr

Applying this sharding to the NamedArray shards the data_array attribute (try expanding below):

%%autovisualize lambda a,p: pz.ts.ArrayAutovisualizer()(a, p) if isinstance(a, jax.Array) else None
# (^ this line overrides the autovisualizer to show the sharding of the data array when expanded)

sharded_arr = jax.device_put(arr, sharding_for_arr)
pz.select(sharded_arr).at_instances_of(jax.Array).show_value()

But with normal automatic array visualization, Treescope will show you how the named axes are sharded, since that’s usually what you care about when using Penzai models in practice:

sharded_arr
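The same tree-matching mechanism works for any PyTree node. In this plain-JAX sketch, the hypothetical `MiniNamedArray` class is a stand-in for the layout described above and is not Penzai’s class. It stores axis names as static metadata and the data array as its only leaf. To build the sharding tree, we put a `NamedSharding` in the `data_array` slot. `jax.device_put` then matches the two trees leaf by leaf.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


@jax.tree_util.register_pytree_node_class
class MiniNamedArray:
  """Hypothetical stand-in: a data array plus names for its positional axes."""

  def __init__(self, named_axes, data_array):
    self.named_axes = named_axes  # Static metadata, e.g. ("foo", "bar").
    self.data_array = data_array  # The only PyTree leaf.

  def tree_flatten(self):
    return (self.data_array,), self.named_axes

  @classmethod
  def tree_unflatten(cls, named_axes, children):
    return cls(named_axes, children[0])


mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
arr = MiniNamedArray(("foo", "bar"), jnp.ones((3, 4 * jax.device_count())))

# Same tree structure as `arr`, with a sharding where the array leaf was.
sharding_tree = MiniNamedArray(
    arr.named_axes, NamedSharding(mesh, PartitionSpec(None, "devices"))
)
sharded_arr = jax.device_put(arr, sharding_tree)
```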

Automatically building shardings for NamedArrays

To simplify this process, Penzai provides some optional utilities for constructing shardings for NamedArray instances. These utilities take a Mesh, and allow you to map from NamedArray axis names to Mesh axis names across a tree of arrays.

For instance, consider this tree of arrays:

some_array_tree = {
    "one": pz.nx.ones({"a": 4, "b": 8, "c": 6}),
    "two": pz.nx.ones({"a": 8}),
    "three": pz.nx.ones({"b": 4, "d": 12}),
}
some_array_tree

And this mesh:

mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('foo', 'bar'))
mesh

We can assign each named axis in some_array_tree to an axis in the mesh using the name_to_name_sharding utility, which builds a tree of shardings that is compatible with the tree of arrays:

from penzai.toolshed import sharding_util
shardings = sharding_util.name_to_name_sharding(
    some_array_tree,
    mesh,
    axis_name_to_mesh_name={
        "a": "bar",
        "b": "foo",
    },
)
shardings

We can then apply those shardings to the original array tree to shard the corresponding axes:

jax.device_put(some_array_tree, shardings)

If you just want to call device_put, you can combine both steps into a single call:

sharding_util.name_to_name_device_put(
    some_array_tree,
    mesh,
    axis_name_to_mesh_name={
        "a": "bar",
        "b": "foo",
    },
)

If your mesh happens to use the exact same axis names as your arrays, you don’t need the axis_name_to_mesh_name argument:

already_matching_mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('b', 'a'))
sharding_util.name_to_name_device_put(
    some_array_tree,
    already_matching_mesh,
    # axis_name_to_mesh_name inferred as {"a":"a", "b":"b"}
)
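Conceptually, a name-to-name mapping like this reduces to building one `PartitionSpec` per array from the names of its positional axes. The hypothetical helper `names_to_partition_spec` below is a simplified sketch of that idea and not Penzai’s implementation. Array axes with no mapping become `None`, meaning they are not sharded. When no mapping is given, each array axis is paired with the mesh axis of the same name, if there is one.

```python
from jax.sharding import PartitionSpec


def names_to_partition_spec(positional_axis_names, mesh_axis_names, axis_name_to_mesh_name=None):
  """Hypothetical helper: map an array's positional axis names to a PartitionSpec."""
  if axis_name_to_mesh_name is None:
    # Identity mapping for every name that the mesh also uses.
    axis_name_to_mesh_name = {name: name for name in mesh_axis_names}
  return PartitionSpec(
      *(axis_name_to_mesh_name.get(name) for name in positional_axis_names)
  )


# An array with positional axes named ("a", "b", "c") and an explicit mapping.
explicit = names_to_partition_spec(("a", "b", "c"), ("foo", "bar"), {"a": "bar", "b": "foo"})
# A mesh whose axis names already match the array's axis names.
inferred = names_to_partition_spec(("b", "d"), ("b", "a"))
```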

Sharding Penzai Models and Training Loops

Penzai also provides some utilities that are specific to training and using Penzai neural network models. These are simple self-contained utilities that can be a good starting point, but you are free to customize them to get lower-level control when needed.

Sharding Parameter Initializers

Uninitialized Penzai models directly expose all of the parameter initializers to you as attributes inside your model. If you want to customize the sharding of your parameters, you can JIT-compile the initializer with the appropriate sharding, e.g.

sharded_initializer = jax.jit(
  pz.nn.initialize_parameters,
  out_shardings=..., # <- insert your desired sharding specification here
)
params = sharded_initializer(mlp_def, jax.random.key(42))

If you want to infer out_shardings using the axis names of your parameters, you can do that using the helper function initialize_parameters_sharded. This function just traces the initializer to figure out the parameter shapes, infers the right sharding to use, and then runs your initializer accordingly.

For instance, here’s how you could initialize the parameters of a small transformer in a sharded way:

# Using the Gemma model architecture, but very small for demonstration purposes.
tiny_transformer_def = gemma.model_core.GemmaTransformer.from_config(
    gemma.model_core.GemmaTransformerConfig(
        num_heads=2,
        embedding_dim=64,
        projection_dim=16,
        single_kv_head=False,
        mlp_hidden_dim=128,
        num_decoder_blocks=2,
        vocab_size=100,
        parameter_dtype=jnp.float32,
        activation_dtype=jnp.float32,
    )
)
tiny_transformer = sharding_util.initialize_parameters_sharded(
    tiny_transformer_def,
    jax.random.key(42),
    mesh=jax.sharding.Mesh(devices, axis_names=('devices',)),
    axis_name_to_mesh_name={
        # Shard the embedding dimension across devices.
        "embedding": "devices",
    },
)
tiny_transformer
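The trace-then-shard pattern described above can be sketched in plain JAX. First, `jax.eval_shape` computes parameter shapes without allocating any arrays. Next, a sharding is chosen for each leaf. Finally, the initializer is compiled with `out_shardings`. The `init_params` function and the "shard the leading axis when it divides evenly" rule are hypothetical stand-ins. initialize_parameters_sharded infers shardings from axis names instead.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))


def init_params(key):
  # Hypothetical initializer for a single dense layer.
  return {
      "w": jax.random.normal(key, (4 * num_devices, 16)),
      "b": jnp.zeros((3,)),
  }


def choose_sharding(shape_struct):
  # Stand-in rule: shard the leading axis when it divides evenly, else replicate.
  if len(shape_struct.shape) > 0 and shape_struct.shape[0] % num_devices == 0:
    return NamedSharding(mesh, PartitionSpec("devices"))
  return NamedSharding(mesh, PartitionSpec())


key = jax.random.key(42)
# 1. Trace the initializer to get shapes and dtypes (nothing is allocated).
param_shapes = jax.eval_shape(init_params, key)
# 2. Build a matching tree of shardings.
param_shardings = jax.tree_util.tree_map(choose_sharding, param_shapes)
# 3. Run the initializer so each output is created directly with its sharding.
params = jax.jit(init_params, out_shardings=param_shardings)(key)
```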

Sharding Training Steps

You’re encouraged to write your own custom training loop for your use case. However, the basic training step implementation in penzai.toolshed.basic_training does support sharded training.

The easiest way to shard a training loop is to just shard your model parameters and inputs, JIT-compile the training loop, and let JAX figure out how the sharding should propagate. XLA can usually automatically infer a decent sharding for the computation and its outputs.

For instance, here’s how we could write a simple training loop for this tiny transformer:

# Simple loss function for demonstration purposes.
def simplified_xent_loss_fn(model, rng, state, input_examples):
  del rng, state  # Unused.
  # Run the model.
  outputs = model(gemma.model_core.GemmaInputs.from_basic_segments(
      input_examples[{"seq": pz.slice[:-1]}]
  ))
  # Compute log-probabilities along the "vocabulary" axis.
  all_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      outputs.untag("vocabulary")
  ).tag("vocabulary")
  # Index by the correct tokens.
  correct_next_tokens = input_examples[{"seq": pz.slice[1:]}]
  correct_log_probs = all_log_probs[{"vocabulary": correct_next_tokens}]
  # Take averages.
  loss = -correct_log_probs.untag("batch", "seq").unwrap().mean()
  return loss, None, {"loss": loss}
train_step = basic_training.build_train_step_fn(
    simplified_xent_loss_fn,
    jit=True,
    # donate_params_and_state=True,  # <- Uncomment to allow XLA memory optimizations.
)
train_state = basic_training.TrainState.initial_state(
    model=tiny_transformer,
    optimizer_def=optax.adamw(5e-5, weight_decay=0.01),
    root_rng=jax.random.key(42),
)
# Take a training step (with a dummy input in this case).
input_examples = pz.nx.ones({"batch": 8, "seq": 20}, dtype=jnp.int32)
updated_train_state, outs = train_step(train_state, input_examples=input_examples)
# Show the updated parameters.
pz.select(updated_train_state.model).at_instances_of(pz.nn.Parameter).get_sequence()

If you inspect the parameters above, you will likely see that they are still sharded along the embedding axis (because XLA will likely infer that keeping the same sharding is easiest).
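Stripped down to plain JAX, the "shard the inputs and let XLA propagate" approach looks like this. The linear-regression loss and SGD update are hypothetical stand-ins for a real training step. The step itself contains no sharding annotations; the inputs arrive already sharded along their batch axis and the small parameters are replicated. The output shardings are left for XLA to choose, so the sketch checks only that training makes progress.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))


def loss_fn(params, x, y):
  # Hypothetical linear-regression loss.
  return jnp.mean(jnp.square(x @ params["w"] - y))


@jax.jit
def train_step(params, x, y):
  loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
  # Plain SGD update; no sharding annotations anywhere in the step.
  new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
  return new_params, loss


x = jax.random.normal(jax.random.key(0), (8 * num_devices, 2))
y = x @ jnp.array([1.0, -2.0])
# Shard the batch axis of the inputs; replicate the (small) parameters.
batch_sharding = NamedSharding(mesh, PartitionSpec("devices"))
x, y = jax.device_put(x, batch_sharding), jax.device_put(y, batch_sharding)
params = jax.device_put({"w": jnp.zeros((2,))}, NamedSharding(mesh, PartitionSpec()))

new_params, loss_before = train_step(params, x, y)
_, loss_after = train_step(new_params, x, y)
```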

If you prefer, however, you can also manually specify what shardings you want to use, and the training step function will respect them. For instance, we can explicitly request that the model parameters and optimizer states be sharded across the “embedding” axis, and the inputs be sharded across the “batch” axis:

mesh = jax.sharding.Mesh(devices, axis_names=('devices',))
train_step = basic_training.build_train_step_fn(
    simplified_xent_loss_fn,
    jit=True,
    # Shard inputs over "batch" axis.
    input_kwarg_shardings={
        "input_examples": sharding_util.name_to_name_sharding(
            input_examples,
            mesh,
            axis_name_to_mesh_name={"batch": "devices"},
        ),
    },
    # Shard model and optimizer params over "embedding" axis.
    train_state_shardings=sharding_util.name_to_name_sharding(
        train_state,
        mesh,
        axis_name_to_mesh_name={"embedding": "devices"},
        ignore_unnamed_arrays=True,
    ),
    # donate_params_and_state=True,  # <- Uncomment to allow XLA memory optimizations.
)
# Take a training step (with a dummy input in this case).
input_examples = pz.nx.ones({"batch": 8, "seq": 20}, dtype=jnp.int32)
updated_train_state, outs = train_step(train_state, input_examples=input_examples)
# Show the updated parameters.
pz.select(updated_train_state.model).at_instances_of(pz.nn.Parameter).get_sequence()
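In plain JAX, the corresponding controls are the `in_shardings` and `out_shardings` arguments of `jax.jit`. The related `donate_argnums` argument enables the buffer-donation optimization mentioned in the comments above; the sketch leaves it out. It pins a hypothetical weight matrix to be sharded along its first axis and a batch along its leading axis. It only illustrates the idea and is not how basic_training is implemented internally.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
param_sharding = {"w": NamedSharding(mesh, PartitionSpec("devices", None))}
batch_sharding = NamedSharding(mesh, PartitionSpec("devices", None))


def sgd_step(params, batch):
  def loss_fn(p):
    # Hypothetical loss: mean squared activation of a linear layer.
    return jnp.mean(jnp.square(batch @ p["w"]))

  grads = jax.grad(loss_fn)(params)
  return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


train_step = jax.jit(
    sgd_step,
    in_shardings=(param_sharding, batch_sharding),
    out_shardings=param_sharding,  # Keep the updated params sharded the same way.
)

params = jax.device_put({"w": jnp.ones((4 * num_devices, 8))}, param_sharding)
batch = jax.device_put(jnp.ones((2 * num_devices, 4 * num_devices)), batch_sharding)
new_params = train_step(params, batch)
```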

Adding Sharding Constraints to Models

You may want more control over the way that intermediate values are sharded. JAX allows you to control this using jax.lax.with_sharding_constraint, which forces a particular value to have a particular sharding.
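As a minimal plain-JAX sketch, assuming a hypothetical two-layer block, this is what such a constraint looks like inside a jitted function. The constraint changes only how XLA lays out the intermediate `h` across devices, not the values computed.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
batch_sharded = NamedSharding(mesh, PartitionSpec("devices", None))


@jax.jit
def two_layer_block(x, w1, w2):
  h = jnp.tanh(x @ w1)
  # Force the intermediate activation to be split along its batch axis.
  h = jax.lax.with_sharding_constraint(h, batch_sharded)
  return h @ w2


x = jnp.ones((2 * num_devices, 4))
out = two_layer_block(x, jnp.ones((4, 8)), jnp.ones((8, 3)))
```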

In a Penzai model, sharding constraints can be enforced by simply inserting new layers into the model at the points where you want to constrain the shardings. Penzai’s sharding_util module provides two simple classes for this purpose, ConstrainSharding and ConstrainShardingByName, defined as follows:

from dataclasses import field
from typing import Any

@pz.pytree_dataclass
class ConstrainSharding(pz.Layer):
  sharding: PyTreeOfShardings = field(metadata={"pytree_node": False})
  def __call__(self, tree: Any) -> Any:
    return jax.lax.with_sharding_constraint(tree, self.sharding)

@pz.pytree_dataclass
class ConstrainShardingByName(pz.Layer):
  mesh: jax.sharding.Mesh = field(metadata={"pytree_node": False})
  axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = (
      field(default=None, metadata={"pytree_node": False})
  )
  def __call__(self, tree: PyTreeOfNamedArrays) -> PyTreeOfNamedArrays:
    return jax.lax.with_sharding_constraint(
        tree,
        name_to_name_sharding(tree, self.mesh, self.axis_name_to_mesh_name),
    )

You can insert them into the model using logic like this:

# Make sure it's sharded over the batch axis after each block.
tiny_transformer_constrained = (
    pz.select(tiny_transformer)
    .at_instances_of(gemma.model_core.GemmaTransformerBlock)
    .insert_after(sharding_util.ConstrainShardingByName(
        mesh, axis_name_to_mesh_name={"batch": "devices"}
    ))
)
# Visualize the constraints:
pz.select(tiny_transformer_constrained).at_instances_of(sharding_util.ConstrainShardingByName)

This gives you a version of the model whose intermediates will be sharded in the way you specified.

train_state = basic_training.TrainState.initial_state(
    model=tiny_transformer_constrained,
    optimizer_def=optax.adamw(5e-5, weight_decay=0.01),
    root_rng=jax.random.key(42),
)
mesh = jax.sharding.Mesh(devices, axis_names=('devices',))
train_step = basic_training.build_train_step_fn(
    simplified_xent_loss_fn,
    jit=True,
    # Shard inputs over "batch" axis.
    input_kwarg_shardings={
        "input_examples": sharding_util.name_to_name_sharding(
            input_examples,
            mesh,
            axis_name_to_mesh_name={"batch": "devices"},
        ),
    },
    # Shard model and optimizer params over "embedding" axis.
    train_state_shardings=sharding_util.name_to_name_sharding(
        train_state,
        mesh,
        axis_name_to_mesh_name={"embedding": "devices"},
        ignore_unnamed_arrays=True,
    ),
    # donate_params_and_state=True,  # <- Uncomment to allow XLA memory optimizations.
)
# Take a training step (with a dummy input in this case).
input_examples = pz.nx.ones({"batch": 8, "seq": 20}, dtype=jnp.int32)
updated_train_state, outs = train_step(train_state, input_examples=input_examples)

If you later want to change how your model’s intermediates are sharded, you can simply remove these constraints:

tiny_transformer_unconstrained = (
    pz.select(tiny_transformer_constrained)
    .at_instances_of(sharding_util.ConstrainShardingByName)
    .remove_from_parent()
)

# No more constraints:
(
    pz.select(tiny_transformer_unconstrained)
    .at_instances_of(sharding_util.ConstrainShardingByName)
    .assert_count_is(0)
)

Aside: Parameter checkpointing for sharded models

Note: To make it easier to save and restore parameters from checkpoints even if you’ve inserted sharding constraints (or made other modifications), we recommend only checkpointing the dictionary of model parameters, not the full model structure. This is how the basic_training.TrainState stores the parameters internally:

train_state

You can manually extract the parameter dictionary from a model like this:

param_dict = {
    param.name: param.value
    for param in (
        pz.select(tiny_transformer)
        .at_instances_of(pz.nn.Parameter)
        .get_sequence()
    )
}
param_dict

And later restore them using something like this:

restored = (
    pz.select(tiny_transformer)
    .at_instances_of(pz.nn.Parameter)
    .apply(
        lambda param: dataclasses.replace(param, value=param_dict[param.name])
    )
)
restored
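Without Penzai’s selector machinery, this checkpointing pattern amounts to two steps: collect values keyed by a unique parameter name, then write them back by name. The sketch below stands in a hypothetical frozen `Param` dataclass and a plain list of layers for a real model tree. Because only the name-to-value dictionary is saved, the restore step works for any model that uses the same parameter names.

```python
import dataclasses
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Param:
  """Hypothetical stand-in for a named model parameter."""
  name: str
  value: jnp.ndarray


model = [
    Param("block0.w", jnp.ones((2, 2))),
    Param("block1.w", jnp.zeros((2, 2))),
]

# Save: only the name -> value dictionary, independent of the model's structure.
param_dict = {param.name: param.value for param in model}

# Restore into another model instance that uses the same parameter names.
fresh_model = [Param(p.name, jnp.full_like(p.value, -1.0)) for p in model]
restored = [dataclasses.replace(p, value=param_dict[p.name]) for p in fresh_model]
```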

If you haven’t yet initialized your parameters, you can do something similar to initialize the UninitializedParameters directly from saved values:

(
    pz.select(tiny_transformer_def)
    .at_instances_of(pz.nn.UninitializedParameter)
    .apply(
        lambda uninit: uninit.initialize_with_value(param_dict[uninit.name])
    )
)

You can also use this approach to build a PyTree with the same structure as your model’s parameter dictionary without initializing the parameters first:

# Produces a structure containing jax.ShapeDtypeStruct
param_dict_structure = {
    uninit.name: uninit.as_empty_parameter().value
    for uninit in (
        pz.select(tiny_transformer_def)
        .at_instances_of(pz.nn.UninitializedParameter)
        .get_sequence()
    )
}
param_dict_structure

This, in turn, could be used to build a sharding specification:

param_dict_sharding = sharding_util.name_to_name_sharding(
    param_dict_structure,
    mesh,
    axis_name_to_mesh_name={"embedding": "devices"},
)
param_dict_sharding

You can also wrap the NamedSharding leaves in a jax.ShapeDtypeStruct:

param_dict_sharding_structs = sharding_util.name_to_name_sharding(
    param_dict_structure,
    mesh,
    axis_name_to_mesh_name={"embedding": "devices"},
    as_shape_dtype_struct=True,  # <- Wraps shardings in ShapeDtypeStruct
)
param_dict_sharding_structs
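For reference, a `jax.ShapeDtypeStruct` can carry a sharding alongside its shape and dtype. This makes it a convenient abstract description of how restored arrays should be laid out. The plain-JAX sketch below builds such a tree for a hypothetical parameter dictionary using `jax.eval_shape`, so no real parameters are allocated.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))


def init_params():
  # Hypothetical parameters whose leading axis plays the role of "embedding".
  return {
      "embed": jnp.zeros((4 * num_devices, 16)),
      "scale": jnp.ones((4 * num_devices,)),
  }


# Abstract shapes and dtypes only: nothing is allocated.
structure = jax.eval_shape(init_params)
sharded_structs = jax.tree_util.tree_map(
    lambda s: jax.ShapeDtypeStruct(
        s.shape, s.dtype, sharding=NamedSharding(mesh, PartitionSpec("devices"))
    ),
    structure,
)
```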

If you’re using orbax.checkpoint for your parameters, you can configure it so that it restores the parameters directly using this sharding, using something like

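import orbax.checkpoint

# Assumes `checkpointer` is an orbax.checkpoint.Checkpointer (for example, one
# using a StandardCheckpointHandler) and `ckpt_path` points to a saved param dict.
loaded_param_dict = checkpointer.restore(
    ckpt_path,
    args=orbax.checkpoint.args.StandardRestore(param_dict_sharding_structs),
)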