Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


Named Axes in Penzai#

As argued by “Tensors Considered Harmful”, relying on axis indices for complex tensor operations can be brittle and difficult to read. Penzai provides a lightweight implementation of named axes to make it easier to understand the operations performed by neural networks.

Penzai’s named axis system is based on a “locally positional” programming style, which avoids the need to make named-axis versions of the full JAX API. The key idea is to reuse positional-axis operations in their original form, and then allow named axes to be transformed into positional axes and vice versa. Penzai’s named axis system is also closely integrated into the treescope pretty-printer and array visualizer.

This notebook walks through how named axes work in Penzai and how to use them in Penzai models.

Setup#

Before we can get started in earnest, we need to set up the environment.

Imports#

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import traceback

import jax
import jax.numpy as jnp
import numpy as np
import penzai
from penzai import pz

Setting up Penzai#

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment.

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()

We’ll also enable automatic array visualization, which makes it easy to visualize named arrays.

pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

The Locally-Positional Style#

Penzai’s named axis library is defined in penzai.core.named_axes, which is aliased to pz.nx for easier access. The primary object in Penzai’s named axis system is the NamedArray. A NamedArray wraps an ordinary jax.Array, but assigns names to (a subset of) its axes. These names are local to each array.

You can convert a regular array to a NamedArray by calling wrap. This just wraps the array, but doesn’t actually assign any names.

array = pz.nx.wrap(jnp.arange(3*4).reshape((3, 4)).astype(jnp.float32))
print("Positional shape:", array.positional_shape)
print("Named shape:", array.named_shape)

Penzai’s array autovisualizer will automatically show the values inside a NamedArray:

array

To bind names to the axes of a NamedArray, you can call tag, with one name per positional axis. This returns a new NamedArray, with names bound to those axes.

array2 = array.tag("foo", "bar")
print("Positional shape:", array2.positional_shape)
print("Named shape:", array2.named_shape)
%%autovisualize None
pz.select(array2).at_pytree_leaves().show_value()

In this case, automatic array visualization shows you the named axes:

array2

Operations on NamedArrays always act only on the positional axes, and are vectorized (or “lifted”) over the named axes. If you want to apply an operation to a named axis, you can turn it back into a positional axis using untag:

array3 = array2.untag("bar")
print("Positional shape:", array3.positional_shape)
print("Named shape:", array3.named_shape)
array3

Internally, a NamedArray is just a PyTree dataclass object that stores an ordinary JAX array and some metadata. You can poke around at it by pretty printing it directly:

%%autovisualize None
pz.select(array2).at_pytree_leaves().show_value()
%%autovisualize None

# (array3 is actually a NamedArrayView, which is discussed later)
pz.select(array3).at_pytree_leaves().show_value()

This means that JAX transformations like jax.jit work directly with NamedArrays. On the other hand, most JAX primitive operations don’t directly accept NamedArrays:

try:
  jnp.sum(array3)
except Exception:
  traceback.print_exc(1)

Instead, you can use pz.nx.nmap to transform any JAX function so that it handles NamedArray inputs. Within an nmap-ed function, each NamedArray is replaced with a JAX tracer whose shape matches the original named array’s positional_shape.

Conceptually, nmap acts very similarly to JAX’s vmap or xmap. However, instead of specifying which axes you want to map over explicitly, these axes are inferred from the arguments themselves. This means you can call the function the same way you would without nmap.

def print_and_sum(value):
  jax.debug.print("print_and_sum called with value of shape {x.shape}:\n{x}", x=value)
  return jnp.sum(value)
print("Positional shape:", array.positional_shape, "Named shape:", array.named_shape)
pz.nx.nmap(print_and_sum)(array)
print("Positional shape:", array2.positional_shape, "Named shape:", array2.named_shape)
pz.nx.nmap(print_and_sum)(array2)
print("Positional shape:", array3.positional_shape, "Named shape:", array3.named_shape)
pz.nx.nmap(print_and_sum)(array3)

This means that it’s possible to run any ordinary JAX function over an arbitrary set of axes of a NamedArray, by first using untag to expose those axes as positional, then using nmap to map over the other names. For instance, we can sum over the “foo” axis of array2:

pz.nx.nmap(jnp.sum)(array2.untag("foo"))

Or the “bar” axis:

pz.nx.nmap(jnp.sum)(array2.untag("bar"))

More complex transformations are possible too. For instance, here’s how we might compute dot-product attention:

queries = pz.nx.wrap(
    jax.random.normal(jax.random.key(1), (10, 4, 16)),
).tag("query_seq", "heads", "embed")

keys = pz.nx.wrap(
    jax.random.normal(jax.random.key(2), (10, 4, 16)),
).tag("key_seq", "heads", "embed")

{
    "queries": queries,
    "keys": keys,
}
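# Contract the two "embed" dimensions, scaling by the square root of the
# embedding size (16), as is conventional for dot-product attention:
attention_logits = pz.nx.nmap(jnp.dot)(queries.untag("embed"), keys.untag("embed")) / np.sqrt(16)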

# Mask out cells where query comes before key:
attention_logits_masked = pz.nx.nmap(jnp.where)(
    pz.nx.wrap(jnp.arange(10)).tag("query_seq")
        < pz.nx.wrap(jnp.arange(10)).tag("key_seq"),
    -1e9,
    attention_logits,
)

# Take a softmax over "key_seq", then name the output axis "key_seq" again:
attention_scores = pz.nx.nmap(jax.nn.softmax)(
    attention_logits_masked.untag("key_seq")
).tag("key_seq")

{
    "attention_logits": attention_logits,
    "attention_logits_masked": attention_logits_masked,
    "attention_scores": attention_scores,
}

You can turn a NamedArray back into an ordinary array using unwrap, as long as it doesn’t have any more named axes:

attention_scores.untag("query_seq", "key_seq", "heads").unwrap()

For convenience, you can also write this as attention_scores.unwrap("query_seq", "key_seq", "heads"), but the meaning is the same.

Array methods and infix operators#

NamedArrays support most of the same instance methods as ordinary JAX arrays. Just like nmap-ed functions, these wrapped instance methods only act on the positional axes, and are vectorized over the named axes. This means you don’t have to learn any new API options; they always have exactly the same signature that the jax.Array methods do.

For instance, you can use infix operators:

array_a = pz.nx.wrap(
    jax.random.normal(jax.random.key(1), (3, 4)),
).tag("foo", "bar")
array_b = pz.nx.wrap(
    jax.random.normal(jax.random.key(2), (4, 5)),
).tag("bar", "baz")
array_a + array_b

You can also use reduction methods (which reduce over positional axes):

array_a.untag("foo").sum()
array_a.untag("foo").std()

Or slice along positional axes:

array_a.untag("foo")[jnp.array([0, 1, 0, 1, 2, 1, 1])]

One place where the NamedArray API extends the jax.Array API is that it also allows indexing/slicing with dictionaries. Slicing a NamedArray with a dictionary applies the given operations to the named axes instead of the positional ones:

array_a[{"foo": 2, "bar": pz.slice[1:3]}]

The name-based automatic vectorization makes it easy to perform complex indexing operations. For instance, to index an array of log-probabilities with an array of tokens, you can do something like this:

tokens = pz.nx.wrap(jnp.arange(100).reshape(5,20)).tag("batch", "seq")
log_probabilities = pz.nx.wrap(
    jax.random.uniform(jax.random.key(1), (5, 200,))
).tag("batch", "vocabulary")

# Index the vocabulary by the tokens for each batch element:
log_probs_for_each_token = log_probabilities.untag("vocabulary")[tokens]
# or, equivalently: log_probabilities[{"vocabulary": tokens}]
log_probs_for_each_token
log_probs_for_each_token.named_shape

Advanced: PyTrees, batches of NamedArrays, and NamedArrayViews#

Many of Penzai’s named axis operations produce NamedArray objects. As discussed above, these are just immutable PyTree dataclasses that wrap an internal jax.Array and add metadata to it:

%%autovisualize None
array_fully_positional = pz.nx.wrap(jnp.arange(3*4*5).reshape((3, 4, 5)).astype(jnp.float32))
pz.select(array_fully_positional).at_pytree_leaves().show_value()
%%autovisualize None
array_fully_named = array_fully_positional.tag("foo", "bar", "baz")
pz.select(array_fully_named).at_pytree_leaves().show_value()

Arrays that mix positional and named axes are often represented as a NamedArrayView instead, which keeps a bit more bookkeeping to avoid unnecessary transpositions in device memory. Both NamedArray and NamedArrayView are subclasses of NamedArrayBase and support all the same methods.

%%autovisualize None
array_partially_untagged = array_fully_named.untag("bar")
pz.select(array_partially_untagged).at_pytree_leaves().show_value()

Some higher-order JAX transformations, such as jax.lax.scan, apply over prefix axes of PyTree leaves. It’s possible to combine this with NamedArrays, but this requires a bit of care:

  • Positional axes of NamedArrays always appear at the front of the internal data_array. It’s allowed to add new axes to the front of data_array, or to remove existing positional axes, as long as you don’t remove an axis that already has a name.

  • Positional axes of NamedArrayViews can appear anywhere. In general, it’s NOT allowed to directly manipulate the shape of the data_array of a NamedArrayView; these should be used only as temporary objects.

For instance, it’s fine to stack or slice NamedArrays using tree_map:

stacked = jax.tree_util.tree_map(lambda a: jnp.stack([a, -a]), array_fully_positional)
print("Positional shape:", stacked.positional_shape)
print("Named shape:", stacked.named_shape)
stacked
sliced = jax.tree_util.tree_map(lambda a: a[0, 1], array_fully_positional)
print("Positional shape:", sliced.positional_shape)
print("Named shape:", sliced.named_shape)
sliced

But it’s not fine to stack NamedArrayViews:

bad = jax.tree_util.tree_map(lambda a: jnp.stack([a, -a]), array_partially_untagged)
try:
  bad.check_valid()
except Exception:
  traceback.print_exc(1)

If you have a NamedArrayView and you need to access its positional axes using PyTree manipulation (e.g. for tree_map or scan), you should call with_positional_prefix to transform it into a NamedArray (possibly transposing its internal data array):

array_partially_untagged.with_positional_prefix()

Another thing to watch out for when using control flow like scan is that NamedArray named axes can sometimes appear in different orders along different control flow paths, which can lead to incompatible PyTree structures. You can enforce a specific order using order_as. This converts NamedArrayViews into NamedArrays if necessary and also guarantees the named axes appear in this specific sequence, making it easier to ensure outputs have the same PyTree structure.

array_partially_untagged.order_as("baz", "foo")

You can also easily transpose a named array’s data array to make it match another named array, which is useful if you want to pass them through JAX transformations that require the same PyTree structure (e.g. jax.jvp):

other = array_fully_named.untag("bar").with_positional_prefix()
other
array_partially_untagged.order_like(other)

Other utility methods#

Most NamedArray manipulation can be done directly using pz.nx.nmap, .tag, and .untag. However, there are also a few additional convenience methods to make it easier to work with named arrays.

Construction#

You can build simple NamedArrays using pz.nx.ones, pz.nx.zeros, pz.nx.full, and pz.nx.arange, which are named wrappers around the corresponding JAX functions:

pz.nx.ones({"a": 3, "b": 4})
pz.nx.zeros({"a": 3, "b": 4})
pz.nx.full({"a": 3, "b": 4}, 7)
pz.nx.arange("foo", 10)

This can be especially useful in combination with automatically-vectorized elementwise operators:

# Creates a two-dimensional mask indexed by "foo" and "bar":
pz.nx.arange("foo", 10) > pz.nx.arange("bar", 10)

Broadcasting#

You can broadcast an array using .broadcast_to or .broadcast_like:

# Broadcasts the positional axes:
pz.nx.arange("foo", 10).broadcast_to((3,))
# Adds a named axis:
pz.nx.arange("foo", 10).broadcast_to(named_shape={"bar": 4})
# Can also include existing axes:
pz.nx.arange("foo", 10).broadcast_to(named_shape={"foo": 10, "bar": 4})
# Can also broadcast like another array:
pz.nx.arange("foo", 10).broadcast_like(pz.nx.arange("bar", 10))

Stacking / Concatenation#

You can concatenate and stack named arrays together along named axes:

pz.nx.stack([
    pz.nx.zeros({"foo": 10}),
    pz.nx.arange("foo", 10),
    pz.nx.full({"foo": 10}, 9),
], "bar")
pz.nx.concatenate([
    pz.nx.zeros({"foo": 10, "bar": 3}),
    pz.nx.ones({"foo": 10, "bar": 7}),
], "bar")

Tagging / untagging prefixes#

To make it easier to manipulate prefix axes, there are utilities that allow you to tag or untag subsets of axes at a time:

arr = pz.nx.wrap(jnp.ones([10,11,12]))
print(arr.positional_shape, arr.named_shape)
arr
# Tag the first two positional axes
arr2 = arr.tag_prefix("foo", "bar")
print(arr2.positional_shape, arr2.named_shape)
arr2
# Untag the named axis "foo", making it the leading positional axis
arr3 = arr2.untag_prefix("foo")
print(arr3.positional_shape, arr3.named_shape)
arr3

Random keys#

The random_split utility allows you to split a named array of PRNG keys along new named axes:

keys = pz.nx.random_split(
    pz.nx.wrap(jax.random.key(10)),
    {"batch": 16}
)
print(keys.positional_shape, keys.named_shape)

pz.nx.nmap(jax.random.normal)(keys, shape=(4,))

Comparison with other named axis systems#

penzai.named_axes vs JAX’s axis names (vmap/pmap/xmap)#

JAX already includes a form of named axes through vmap/pmap with the axis_name argument, and also has a (deprecated) named axis system jax.xmap as described in “Named axes and easy-to-revise parallelism with xmap”. However, although penzai’s implementation of named axes uses jax.vmap under the hood, from an API perspective our approach can be viewed as the opposite (or perhaps the dual) of JAX’s current named axis style:

  • Top-level arrays:

    • In vmap (or xmap), the array objects you interact with at the top level are indexed positionally only. Names are only bound within an inner function, and the inner function uses names to access those axes.

    • With penzai.named_axes, the array objects you interact with at the top level are NamedArray objects, with explicit names. However, internal operations can use ordinary positional syntax for the axes they care about. (One advantage of this approach is that it makes it super easy to visualize arrays with named axes. This is also similar to the approach taken by the xarray library.)

  • Mapping behavior:

    • In vmap (or xmap), you specify which axes to vectorize over while transforming the function. If you want to map over more axes, you either wrap your function in more layers of vmap or modify the args to xmap.

    • With penzai.named_axes, the axis names determine which axes get vectorized over. You can use the same nmap-wrapped function regardless of how many axes you want to vectorize over, or even call it with ordinary jax.Arrays, without having to worry about how many named axes it has.

  • Overall style:

    • In vmap (or xmap), most of the data flow occurs within a single transformed context. Individual operations (collectives) break out of this context to retrieve named axes where necessary.

    • With penzai.named_axes, most of the data flow occurs outside of a transformed context. Instead, individual operations are transformed, and tag and untag are used to manipulate named and positional axes.

penzai.core.named_axes vs Haliax#

Penzai’s named axis system was partially inspired by a similar system in the JAX library Haliax. Haliax also defines a NamedArray PyTree, which wraps a positional array and gives it named-axis semantics, but there are a few design differences.

  • API wrapping vs user transformations: Haliax takes the approach of defining named-axis variants of common numpy/JAX API functions, such as dot, mean, argsort, etc, under the haliax namespace. These wrapped functions take axis name arguments instead of axis index arguments. This is convenient but also requires separately defining a haliax wrapper for each type of operation you want to run.

    In contrast, Penzai intentionally avoids defining named-axis variants of ordinary numpy and JAX functions (with a few exceptions like infix operators and pz.nx.arange). Instead, the user is responsible for transforming the ordinary positional versions into named-axis versions at the call site. This reduces the complexity of penzai itself and also makes it possible to lift any existing JAX function to operate over named arrays, without having to explicitly add it to the penzai library.

    This also leads to a mental model for array axis manipulation that is closer to the ordinary numpy positional style. For instance, Haliax has a handwritten utility for splitting one axis into two, which looks something like:

    Foo = haliax.Axis("foo", 3)
    Bar = haliax.Axis("bar", 4)
    FooAndBar = haliax.Axis("foo_and_bar", 12)
    haliax.split(my_array, FooAndBar, (Foo, Bar))
    

    Penzai doesn’t provide a utility like this, but it’s straightforward to do this operation by temporarily dropping into positional mode and using the ordinary numpy reshape function:

    my_array.untag("foo_and_bar").reshape((3, 4)).tag("foo", "bar")
    
  • Sized axes vs strings: Haliax named arrays and operations use Axis objects to associate axis names with sizes, e.g. haliax.Axis("batch", 32). On the other hand, Penzai named arrays and pz.nx.nmap just operate on string axis names.

    Haliax’s approach is useful when code is written according to Haliax’s conventions:

    • Core API functions (like haliax.zeros((FooAxis, BarAxis))) can directly create arrays with the correct shape and correct axis names, without needing the sizes to be specified separately.

    • User-defined functions can take Axis arguments as inputs, forward them to the functions they call, and also inspect their size without having to cross-reference them with a specific input array.

    • NamedArrays can check their runtime input shapes and make sure they match the expectations of the user.

    However, for penzai, coupling axis names with their sizes comes with a few disadvantages:

    • Most core JAX API functions are called in the locally-positional style using nmap instead of directly taking an axis as an argument, so we don’t benefit from storing the array size as part of the axis name.

    • Neural networks in penzai store their configuration (e.g. the set of axes they act on) as dataclass attributes, which can lead to redundancy if every axis name also includes a size. This redundancy can make it difficult to inspect and modify existing models, since axis sizes have to be kept in sync across the entire network architecture.

    As such, penzai uses the simpler system. Layers and operations that need axis sizes typically take as an argument a dictionary mapping axis names to their sizes instead, or infer the axis sizes at runtime using named_array.named_shape on their parameters or their inputs.