named_axes


A lightweight and minimal implementation of named axes.

As argued by “Tensors Considered Harmful”, relying on axis indices for complex tensor operations can be brittle and difficult to read. This has led to a number of proposals for indexing axes by name instead of by position. However, due to the large API surface for NDArray manipulation, building a fully-featured named axis implementation requires making named-axis versions of many individual operations.

This module provides a lightweight implementation of named axes using a “locally positional” style. The key idea is to reuse positional-axis operations in their original form, but use local bindings of positional axes to the underlying named axes; positional axes then become aliases for particular named axes within a local context.

To see how this works, suppose we want to implement dot-product attention between queries and keys. We’d start with two named arrays:

queries = pz.nx.wrap(...).tag("batch", "query_pos", "heads", "embed")
keys = pz.nx.wrap(...).tag("batch", "key_pos", "heads", "embed")

We then contract them against each other, discarding the “embed” dimension and broadcasting jointly over the others (e.g. “bqhf, bkhf -> bqkh” in einsum notation). In our “locally positional” style, we could write this as:

dot_prods = pz.nx.nmap(jnp.dot)(queries.untag("embed"), keys.untag("embed"))

Here jnp.dot is called with two one-axis views of queries and keys, respectively. (More specifically, it is called with jax.vmap tracers that have a single logical axis and three implicitly-broadcasted axes.) We could just as easily use our own function:

def my_dot(a, b):
  print("a:", a)
  print("b:", b)
  print("a.shape:", a.shape)
  print("b.shape:", b.shape)
  return jnp.dot(a, b)

dot_prods = pz.nx.nmap(my_dot)(queries.untag("embed"), keys.untag("embed"))
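To make the mechanics concrete, here is a minimal NumPy sketch of the locally-positional idea. It is not how this module is implemented (as noted above, the real nmap relies on jax.vmap rather than Python loops). The helper toy_nmap and the (array, names) pair representation are hypothetical, with None standing in for an untagged positional axis.

```python
import itertools

import numpy as np


def toy_nmap(fun):
  """Hypothetical stand-in for nmap, using explicit loops instead of jax.vmap."""

  def wrapped(*args):
    # Each argument is an (array, names) pair; a name of None marks a
    # local positional axis (the analogue of an untagged axis).
    sizes = {}
    for arr, names in args:
      for size, name in zip(arr.shape, names):
        if name is not None and sizes.setdefault(name, size) != size:
          raise ValueError(f"Inconsistent size for axis {name!r}")
    order = list(sizes)

    # Call `fun` once per combination of named-axis indices, passing only
    # the positional view of each argument.
    results = {}
    for idx in itertools.product(*(range(sizes[n]) for n in order)):
      binding = dict(zip(order, idx))
      views = [
          arr[tuple(slice(None) if n is None else binding[n] for n in names)]
          for arr, names in args
      ]
      results[idx] = np.asarray(fun(*views))

    # Named axes come first in the result, followed by whatever
    # positional axes `fun` returned.
    sample = next(iter(results.values()))
    out = np.empty(tuple(sizes[n] for n in order) + sample.shape, sample.dtype)
    for idx, value in results.items():
      out[idx] = value
    return out, tuple(order) + (None,) * sample.ndim

  return wrapped


rng = np.random.default_rng(0)
queries = rng.normal(size=(2, 3, 4, 5))  # batch, query_pos, heads, embed
keys = rng.normal(size=(2, 6, 4, 5))     # batch, key_pos, heads, embed

# "embed" is left positional (None) in both arguments, as with untag("embed").
dot_prods, out_names = toy_nmap(np.dot)(
    (queries, ("batch", "query_pos", "heads", None)),
    (keys, ("batch", "key_pos", "heads", None)),
)
```

In this sketch np.dot only ever receives two one-dimensional vectors of length 5. Every name that appears in either argument ("batch", "query_pos", "heads", "key_pos") is looped over and reappears in the output, matching the einsum contraction described earlier.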

We can similarly apply softmax over one of the axes:

attn_weights = pz.nx.nmap(jax.nn.softmax)(
    dot_prods.untag("key_pos")).tag("key_pos")

In this case, we need to “tag” the positional axis produced by softmax with a name, and we choose to give it the same name as the original axis.
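For comparison, a rough positional-axis analogue of this step can be written in plain NumPy. This is only a sketch under assumptions: the axis order (batch, query_pos, key_pos, heads), the helper softmax_1d, and the use of np.apply_along_axis in place of untag/nmap/tag are all illustrative choices, not part of this module.

```python
import numpy as np


def softmax_1d(x):
  # Ordinary positional softmax over a single axis.
  shifted = np.exp(x - np.max(x))
  return shifted / np.sum(shifted)


rng = np.random.default_rng(0)
dot_prods = rng.normal(size=(2, 3, 6, 4))  # batch, query_pos, key_pos, heads
key_pos_axis = 2  # In positional code we must track this index by hand.

# Apply the 1-D function along the key_pos axis; the result of each call
# is placed back at the same axis, analogous to re-tagging it "key_pos".
attn_weights = np.apply_along_axis(softmax_1d, key_pos_axis, dot_prods)
```

The positional version has to keep track of which integer index holds key_pos, whereas the named version above refers to the axis by name regardless of how the underlying array is laid out.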

One advantage of the locally-positional style is that it does not require wrapping/modifying any of the functions in the numpy/JAX API to take axis names; instead, the primitives are written in terms of ordinary positional-axis logic. This means that the full API surface for named axes can be very small. It also means that it’s easy to “drop down” into positional-axis code and do more complex modifications (e.g. slicing, updating) without losing the readability or flexibility of named-axis code.

The locally-positional style is fairly similar to the notation used in the paper “Named Tensor Notation” (Chiang, Rush, and Barak, 2022), in which ordinary mathematical notation is extended with subscripts that identify which axis or axes each operation acts over. In both cases, any names that do NOT appear as part of the operation are implicitly vectorized over. The primary difference is that, in the locally-positional style, the axes to operate over are specified separately for each argument (via untag) rather than being shared across all arguments; this simplifies operations that act over a different name in each argument or that produce new axis names as outputs.

For more information, see the named axis tutorial in penzai/notebooks.

Classes

NamedArray

A multidimensional array with a combination of positional and named axes.

NamedArrayBase

Base class for named arrays and their transposed views.

NamedArrayView

A possibly-transposed view of an array with positional and named axes.

TmpPosAxisMarker

A marker object used to temporarily assign names to positional axes.

Functions

arange(name, start[, stop, step, dtype])

Convenience function to create a range along a named axis.

concatenate(arrays, axis_name)

Concatenates a sequence of named arrays along a named axis.

full(named_shape, fill_value[, dtype])

Constructs a named array with a given shape, filled with fill_value.

is_namedarray(value)

Returns True if value is a NamedArray or NamedArrayView.

nmap(fun)

Automatically vectorizes fun over named axes of NamedArray inputs.

ones(named_shape[, dtype])

Constructs a named array of ones with a given shape.

random_split(key, named_shape)

Splits a PRNG key into a NamedArray of PRNG keys with the given names.

stack(arrays, axis_name)

Stacks a sequence of named arrays along a named axis.

unstack(array, axis_name)

Splits a named array across a given named axis.

wrap(array, *names)

Wraps a positional array as a NamedArray.

zeros(named_shape[, dtype])

Constructs a named array of zeros with a given shape.