Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Jitting and Sharding Penzai Models#

Penzai works with JAX’s standard function transformations, including JIT-compilation and array sharding. However, because Penzai includes support for mutable variables inside the model, some care must be taken to ensure you apply them in ways that JAX can understand!

This notebook walks through some of the common aspects of JIT-compilation and sharding as they apply to Penzai tools and Penzai models. It assumes some basic familiarity with JAX’s JIT compilation and distributed array systems.

Setup#

Before we can get started in earnest, we need to set up the environment.

Imports#

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import jax
import jax.numpy as jnp
import treescope
import penzai
from penzai import pz
from penzai.models import transformer

Setting up Penzai#

For this tutorial, we’ll enable Treescope (Penzai’s companion pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment. We’ll also enable automatic array visualization, which makes it easy to see how arrays are sharded.

treescope.basic_interactive_setup(autovisualize_arrays=True)

We’ll assume this notebook is running on a backend with eight devices. If needed, you can force JAX to treat the CPU backend as multiple devices using

import os  # Must be set before JAX initializes its backend (e.g. before the first JAX call).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
pz.show(jax.local_devices())
assert jax.local_device_count() == 8

JIT-Compiling Penzai Models#

Penzai model objects themselves are always JAX PyTrees. However, in addition to arrays and arraylike leaves, Penzai models can also include two types of “variable” leaves: pz.Parameter and pz.StateVariable. These are currently not directly supported by jax.jit.

For example, consider the following (somewhat contrived) model, which has a learnable parameter and an incrementing counter:

@pz.pytree_dataclass
class CounterLayer(pz.nn.Layer):
  counter: pz.StateVariable[int]

  def __call__(self, x, **_side_inputs):
    self.counter.value += 1
    return (x, self.counter.value)

model = pz.nn.Sequential([
    pz.nn.Linear.from_config(
        name="linear",
        init_base_rng=jax.random.PRNGKey(0),
        input_axes={"features": 8},
        output_axes={"features_out": 8},
    ),
    CounterLayer(counter=pz.StateVariable(value=0, label="counter")),
])

model
model(pz.nx.ones({"features": 8}))
model(pz.nx.ones({"features": 8}))

To JIT-compile a Penzai model, you have three options:

  • The “functional API”: A set of Penzai tools to help you manipulate variable states using pure functions and JAX PyTrees.

  • pz.variable_jit: A convenience wrapper around jax.jit that also works for PyTrees containing pz.Parameter and pz.StateVariable.

  • toolshed.jit_wrapper.Jitted: A model combinator that acts like an ordinary Layer, but always runs under jax.jit (using pz.variable_jit around its __call__ method).

The “Functional API”#

Each of Penzai’s variables comes in three forms:

  • Mutable variables (pz.Parameter and pz.StateVariable), which are Python objects whose .value attribute can be modified freely,

  • Frozen variable values (pz.ParameterValue and pz.StateVariableValue), which are immutable JAX PyTree objects that are safe to pass through JAX transforms,

  • Variable slots (pz.ParameterSlot and pz.StateVariableSlot), which are empty placeholders that indicate locations of variables in a larger tree.

For full control over JIT compilation, you can manually convert variables from their mutable form to their immutable form when crossing JAX transform boundaries. The relevant functions are listed below (a minimal standalone example follows the list):

  • pz.unbind_variables (and type-specific variants pz.unbind_params and pz.unbind_state_vars): Extracts and deduplicates variables, returning a tree of variable slots along with the deduplicated variables.

  • pz.bind_variables: Re-inserts variables into variable slots.

  • Parameter.freeze() and StateVariable.freeze(): Converts a mutable variable into an immutable value.

  • ParameterValue.unfreeze_as_copy() and StateVariableValue.unfreeze_as_copy(): Converts an immutable value back into a (new) mutable variable.
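
As a minimal standalone sketch of these conversions (standalone_var is just an illustrative name, not part of the model above):

standalone_var = pz.StateVariable(value=0, label="standalone")  # mutable; .value can be reassigned
frozen_value = standalone_var.freeze()  # immutable StateVariableValue, safe to pass through jax.jit
new_var = frozen_value.unfreeze_as_copy()  # a fresh mutable StateVariable holding the same value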

For instance, for our example model above, we can use pz.unbind_variables and .freeze() to extract the mutable parts:

model_with_slots, all_vars = pz.unbind_variables(model)
pz.show("model_with_slots:", model_with_slots)
pz.show("all_vars:", all_vars)
frozen_vars = [var.freeze() for var in all_vars]
pz.show("frozen_vars:", frozen_vars)

We can then define a pure function that re-binds these variables, and call it under jax.jit:

@jax.jit
def rebinding_call(model_with_slots, frozen_vars, arg):
  # Make temporary mutable copies:
  new_vars = [var.unfreeze_as_copy() for var in frozen_vars]
  # Re-attach them to the model:
  model = pz.bind_variables(model_with_slots, new_vars)
  # Run it:
  result = model(arg)
  # Extract and re-freeze the variables:
  refrozen_vars = [var.freeze() for var in new_vars]
  return result, refrozen_vars
result, new_frozen_vars = rebinding_call(
    model_with_slots, frozen_vars, pz.nx.ones({"features": 8})
)
pz.show("result:", result)
pz.show("new_frozen_vars:", new_frozen_vars)

We can then update the old variables with their new values:

for var, new_value in zip(all_vars, new_frozen_vars):
  var.update(new_value)
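
Since all_vars holds the same mutable variable objects that live inside the original model, this update is reflected in model itself; displaying it again should show the incremented counter:

model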

To make this a bit less verbose, pz.nn.Layer has a method .stateless_call(vars, ...) that makes temporary mutable copies of its input variables, like rebinding_call. So, we could have equivalently written the following:

@jax.jit
def rebinding_call_2(model_with_slots, frozen_vars, arg):
  result, refrozen_vars = model_with_slots.stateless_call(frozen_vars, arg)
  return result, refrozen_vars
result, new_frozen_vars = rebinding_call_2(
    model_with_slots, frozen_vars, pz.nx.ones({"features": 8})
)
pz.show("result:", result)
pz.show("new_frozen_vars:", new_frozen_vars)

If you want to JIT-compile your model initializer, you can do this using the functional API:

@jax.jit
def functional_init(init_base_rng):
  model = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="linear",
          init_base_rng=init_base_rng,
          input_axes={"features": 8},
          output_axes={"features_out": 8},
      ),
      CounterLayer(counter=pz.StateVariable(value=0, label="counter")),
  ])
  # Unbind and also freeze all variables:
  return pz.unbind_variables(model, freeze=True)
model_with_slots, init_var_values = functional_init(jax.random.PRNGKey(0))
# Re-bind variables and also make them mutable again:
model = pz.bind_variables(
    model_with_slots, init_var_values, unfreeze_as_copy=True
)
model

pz.variable_jit#

If you don’t want to use the functional API directly, you can instead use pz.variable_jit, which is a wrapper around jax.jit that allows the function arguments to contain pz.Parameter and pz.StateVariable in addition to arrays, and handles updating their values for you. For instance, you could write:

@pz.variable_jit
def jitted_call(model, arg):
  return model(arg)
jitted_call(model, pz.nx.ones({"features": 8}))
jitted_call(model, pz.nx.ones({"features": 8}))

Note that pz.variable_jit does not support returning variables from the jitted computation, so it can’t be used to JIT-compile model initialization. It also does not support “closing over” global references to variable objects defined outside of the function. Every variable used by the function inside pz.variable_jit must have been passed in as an input argument.
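
As a small sketch of the supported pattern (bump and standalone_counter are made-up names for this illustration, not part of the model above), the variable is passed in explicitly rather than captured from the enclosing scope:

standalone_counter = pz.StateVariable(value=0, label="standalone_counter")

@pz.variable_jit
def bump(counter_var, x):
  # The StateVariable is an explicit argument, so pz.variable_jit can track
  # and apply its updates after the compiled call finishes.
  counter_var.value += 1
  return x + counter_var.value

bump(standalone_counter, jnp.ones(4))

# A version of bump that read or wrote `standalone_counter` directly from the
# enclosing scope (instead of taking it as an argument) would not be supported.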

jit_wrapper.Jitted#

pz.variable_jit works for top-level functions, but sometimes you may want to JIT-compile a specific part of a Penzai model, or compile the forward pass without having to use an indirect jitted_call function. For this purpose, Penzai provides a layer wrapper Jitted in penzai.toolshed.jit_wrapper, which JIT-compiles its forward pass when called.

To use it, you can simply wrap your model in jit_wrapper.Jitted and then call it as normal:

from penzai.toolshed import jit_wrapper
jit_model = jit_wrapper.Jitted(model)
jit_model
jit_model(pz.nx.ones({"features": 8}))

You can also insert Jitted around any sublayer of the model, e.g.

jit_model_2 = (
    pz.select(model)
    .at_instances_of(pz.nn.Linear | CounterLayer)
    .apply(jit_wrapper.Jitted)
)
jit_model_2
jit_model_2(pz.nx.ones({"features": 8}))

Note that the Jitted wrapper is just an ordinary Penzai layer, and you can still pull back out the original model:

jit_model.body == model
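
Similarly, if you wrapped individual sublayers as in jit_model_2, you can strip the wrappers back out with a selector (a small sketch using the same selector API shown above):

(
    pz.select(jit_model_2)
    .at_instances_of(jit_wrapper.Jitted)
    .apply(lambda jitted: jitted.body)
)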

Sharding Basics, and Visualizing Shardings with Treescope#

Penzai’s array autovisualizer supports showing shardings and sharded arrays by default. This section explains the basics of JAX’s distributed array shardings and how you can visualize the different components in Treescope. (See this page for the official documentation of JAX’s sharding system.)

Positional shardings#

At a high level, you can think of a “sharding” as a multidimensional array of device objects, which will be matched with your multidimensional array of data to determine which part of the array ends up on each device. You generally build a sharding by starting with a NumPy array of devices:

from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))
devices

A simple type of sharding is PositionalSharding, which essentially just holds onto these devices and tracks some extra JAX-specific information. If you print out a PositionalSharding in Treescope, it color-codes the devices and shows you their arrangement:

pos_sharding = jax.sharding.PositionalSharding(devices)
pos_sharding

In this case, the sharding has a single positional axis of length 8. We can use it to shard arrays whose first axis has a length divisible by 8. For instance:

jax.device_put(jnp.ones(16), pos_sharding)

You can click the “Sharded across 8 TPU devices” message to show a visualization of the sharding for this array. When automatic array visualization is enabled, sharding visualizations are automatically added to any array that is sharded or replicated.
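
If you want to inspect a sharding programmatically rather than through the interactive summary, you can also read it off the array’s .sharding attribute directly:

jax.device_put(jnp.ones(16), pos_sharding).sharding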

We can reshape positional shardings to give them multiple axes:

pos_sharding.reshape((4,2))
jax.device_put(jnp.ones([8, 8]), pos_sharding.reshape((4,2)))

If you expand the sharding visualization above, you’ll see how the two axes of the array are matched with the two axes of the sharding.

You can also use shardings to indicate that certain parts of the array should be replicated on multiple devices, using replicate:

pos_sharding.reshape((2, 4)).replicate(axis=0)
jax.device_put(jnp.ones([8, 8]), pos_sharding.reshape((2, 4)).replicate(axis=0))

Each element of an array with a replicated sharding will appear on more than one device. This is visually represented in Treescope using a multicolored pattern.

You can also fully-replicate the array over all of the devices:

pos_sharding.replicate(axis=0)
jax.device_put(jnp.ones([8, 8]), pos_sharding.replicate(axis=0).reshape((1, 1)))

Fully-replicated arrays are also identified as such in the sharding summary before being expanded.

Meshes and named shardings#

It is often convenient to refer to different axes of an array of devices by name instead of by position. JAX represents this using the type jax.sharding.Mesh. Conceptually, just as a PositionalSharding is essentially a positional array of devices, a Mesh is essentially a named array of devices, i.e. an array of devices where each axis has a name.

Penzai annotates the device ID arrays of Mesh instances with axis names instead of axis positions:

mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('foo', 'bar'))
mesh

To shard a (positionally-indexed) JAX array using a mesh, you can use jax.sharding.NamedSharding to assign particular axis indices to mesh axis names, like this:

jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo', 'bar'))
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, ('bar', 'foo'), None))
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo'))

Note: Each NamedSharding specifies how to shard an input array’s positional axes, since ordinary JAX arrays only have positional axes. The names in the NamedSharding are just a way to match the positional axes in the array with the corresponding names in the Mesh. For this reason, visualizations of NamedSharding instances are annotated with positional axes, not axis names.
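
For example, here is a sketch of applying such a sharding to an ordinary positionally-indexed array: axis 0 is split across the four devices along 'foo' and axis 1 across the two devices along 'bar' (the array shape just needs to be divisible by the corresponding mesh axis sizes):

jax.device_put(
    jnp.ones([8, 4]),
    jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo', 'bar')),
)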

(Penzai already has its own mechanism for binding names to an array’s positional axes: pz.nx.NamedArray. We’ll discuss how to shard Penzai’s NamedArray next.)

Sharding Penzai’s NamedArrays#

Manually sharding NamedArrays#

Fundamentally, applying JAX shardings to Penzai’s NamedArrays works the same way as applying them to ordinary arrays. Internally, a NamedArray is just a dataclass PyTree node that contains a JAX array and some axis name annotations, which we can see if we temporarily disable automatic array visualization:

arr = pz.nx.arange("foo", 1, 4) + pz.nx.arange("bar", 0, 4)
# With automatic array visualization enabled:
arr
%%autovisualize None
# ^ With automatic array visualization disabled (and expanding it to show detail)
pz.select(arr).at_instances_of(jax.Array).show_value()

JAX’s sharding system allows you to specify the sharding for a PyTree of arrays by using a matching PyTree of shardings. So, we can build a sharding for this named array by inserting a positional sharding into it:

data_array_sharding = jax.sharding.PositionalSharding(devices).reshape((2,4)).replicate(axis=0)
sharding_for_arr = pz.nx.NamedArray(
    named_axes=arr.named_axes,
    data_array=data_array_sharding,
)
sharding_for_arr

Applying this sharding to the NamedArray shards the data_array attribute (try expanding below):

%%autovisualize lambda a,p: treescope.ArrayAutovisualizer()(a, p) if isinstance(a, jax.Array) else None
# (^ this line overrides the autovisualizer to show the sharding of the data array when expanded)

sharded_arr = jax.device_put(arr, sharding_for_arr)
pz.select(sharded_arr).at_instances_of(jax.Array).show_value()

But with normal automatic array visualization, treescope will show you how the named axes are sharded, since that’s usually what you care about when using Penzai models in practice:

sharded_arr

Automatically building shardings for NamedArrays#

To simplify this process, Penzai provides some optional utilities for constructing shardings for NamedArray instances. These utilities take a Mesh, and allow you to map from NamedArray axis names to Mesh axis names across a tree of arrays.

For instance, consider this tree of arrays:

some_array_tree = {
    "one": pz.nx.ones({"a": 4, "b": 8, "c": 6}),
    "two": pz.nx.ones({"a": 8}),
    "three": pz.nx.ones({"b": 4, "d": 12}),
}
some_array_tree

And this mesh:

mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('foo', 'bar'))
mesh

We can assign each named axis in some_array_tree to an axis in the mesh using the name_to_name_sharding utility, which builds a tree of shardings that is compatible with the tree of arrays:

from penzai.toolshed import sharding_util
shardings = sharding_util.name_to_name_sharding(
    some_array_tree,
    mesh,
    axis_name_to_mesh_name={
        "a": "bar",
        "b": "foo",
    },
)
shardings

We can then apply those shardings to the original array tree to shard the corresponding axes:

jax.device_put(some_array_tree, shardings)

Even more simply, if you just want to call device_put, you can bundle both steps into a single call:

sharding_util.name_to_name_device_put(
    some_array_tree,
    mesh,
    axis_name_to_mesh_name={
        "a": "bar",
        "b": "foo",
    },
)

If your mesh happens to use the exact same axis names as your arrays, you don’t need the axis_name_to_mesh_name argument:

already_matching_mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('b', 'a'))
sharding_util.name_to_name_device_put(
    some_array_tree,
    already_matching_mesh,
    # axis_name_to_mesh_name inferred as {"a":"a", "b":"b"}
)

Sharding Penzai Models and Training Loops#

Penzai also provides some utilities that are specific to training and using Penzai neural network models. These are simple, self-contained utilities that can be a good starting point, but you are free to customize them when you need lower-level control.

Sharding Parameter Initializers#

If you already know the shardings for your model parameters, you can pass them as out_shardings and JIT-compile parameter initialization using something like

def functional_init(init_base_rng):
  model = ...
  return pz.unbind_variables(model, freeze=True)

sharded_init = jax.jit(
  functional_init,
  out_shardings=..., # <- insert your desired sharding specification here
)

model = pz.bind_variables(*sharded_init(rng))

If you want to infer out_shardings using the axis names of your parameters, you can do that using the helper function sharding_util.sharded_init. This function just traces the initializer to figure out the parameter shapes, infers the right sharding to use, and then runs your initializer accordingly.

For instance, here’s how you could initialize the parameters of a small transformer in a sharded way:

from penzai.toolshed import sharding_util
from penzai.models.transformer.variants import llamalike_common
# Very small transformer config, for demo purposes
config = llamalike_common.LlamalikeTransformerConfig(
    num_kv_heads=2,
    query_head_multiplier=1,
    embedding_dim=64,
    projection_dim=16,
    mlp_hidden_dim=128,
    num_decoder_blocks=2,
    vocab_size=100,
    mlp_variant="geglu_approx",
    rope_wavelength=10_000,
    tie_embedder_and_logits=True,
    use_layer_stack=False,
    parameter_dtype=jnp.float32,
    activation_dtype=jnp.float32,
)

tiny_transformer = sharding_util.sharded_init(
    llamalike_common.build_llamalike_transformer,
    config=config,
    init_base_rng=jax.random.key(42),
    mesh=jax.sharding.Mesh(devices, axis_names=('devices',)),
    axis_name_to_mesh_name={
        # Shard the embedding dimension across devices.
        "embedding": "devices",
    },
)
tiny_transformer
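
If you want to double-check how the parameters ended up sharded, one option is to extract and freeze them, then inspect the shardings of their underlying data arrays. This is only a sketch; it assumes pz.unbind_params accepts the same freeze argument that pz.unbind_variables accepted above:

# Extract frozen parameter values and list the sharding of each underlying array.
_, frozen_params = pz.unbind_params(tiny_transformer, freeze=True)
pz.show([leaf.sharding for leaf in jax.tree_util.tree_leaves(frozen_params)])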