nmap



penzai.core.named_axes.nmap(fun: Callable[..., Any]) -> Callable[..., Any]

Automatically vectorizes fun over named axes of NamedArray inputs.

nmap is a “named-axis vectorizing map”. It wraps an ordinary function that operates on positional axes so that it accepts NamedArrays as input and produces NamedArrays as output. It vectorizes over all of the named axes, calling the original function on slices of each argument whose shape is that argument’s positional_shape.
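The core contract can be approximated with plain jax.vmap, which differs only in that the mapped axis is chosen by position rather than inferred from a name. In this sketch the leading axis of `x` stands in for a named axis (called "batch" purely for illustration), and `fun` only ever sees slices shaped like the positional shape `(4,)`. This is an analogy, not penzai's implementation.

```python
import jax
import jax.numpy as jnp


def fun(v):
    # Written for a single positional vector of shape (4,).
    assert v.shape == (4,)
    return jnp.sum(v ** 2)


# Leading axis stands in for a named "batch" axis; trailing axis is positional.
x = jnp.arange(12.0).reshape(3, 4)

# jax.vmap maps over axis 0 because we say so; nmap would infer it from names.
vectorized = jax.vmap(fun, in_axes=0, axis_name="batch")(x)

# The same result computed with an explicit loop over the mapped axis.
looped = jnp.stack([fun(x[i]) for i in range(x.shape[0])])
# vectorized == looped == [14., 126., 366.]
```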

Unlike jax.vmap, the axes to vectorize over are inferred automatically from the named axes in the NamedArray / NamedArrayView, rather than being specified as part of the mapping transformation. Specifically, each axis name that appears in any of the arguments is vectorized over jointly across all arguments that include that axis name, and is then included as an axis name in the output. To make an axis visible to fun, call untag on the argument and pass the axis name(s) of interest; those axes then appear to fun as positional axes instead of being mapped over.
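To make the name-inference rule concrete, here is a toy vectorizer built from nested jax.vmap calls. `ToyNamed` and `toy_nmap` are hypothetical helpers written for this illustration; they are not penzai classes or functions, and penzai's real implementation differs. The toy collects every name seen in any argument, maps over each name jointly for the arguments that carry it, broadcasts arguments that lack it, and attaches every inferred name to the output. Leaving an axis unnamed (positional) plays the role of calling untag on it.

```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass


@dataclass
class ToyNamed:
    """Hypothetical stand-in for a NamedArray (not penzai's class).

    `names` label the leading axes of `data`; the remaining trailing axes are
    positional and are what the wrapped function actually sees.
    """
    data: jax.Array
    names: tuple


def toy_nmap(fun):
    """Toy name-inferring vectorizer; assumes `fun` returns a single array."""
    def wrapped(*args):
        # 1. Infer the mapped axes: every name appearing in any argument.
        all_names = []
        for a in args:
            if isinstance(a, ToyNamed):
                all_names += [n for n in a.names if n not in all_names]
        # 2. Reorder each argument's named axes to follow `all_names`, so that
        #    peeling off one name per vmap level always exposes it at axis 0.
        arrays = []
        for a in args:
            if isinstance(a, ToyNamed):
                order = [a.names.index(n) for n in all_names if n in a.names]
                rest = list(range(len(a.names), a.data.ndim))
                arrays.append(jnp.transpose(a.data, order + rest))
            else:
                arrays.append(a)
        # 3. One vmap per name: arguments carrying the name are mapped jointly
        #    (in_axes=0); arguments without it are broadcast (in_axes=None).
        mapped = fun
        for name in reversed(all_names):
            in_axes = tuple(
                0 if isinstance(a, ToyNamed) and name in a.names else None
                for a in args
            )
            mapped = jax.vmap(mapped, in_axes=in_axes, axis_name=name)
        # 4. Every inferred name appears on the output, in `all_names` order.
        return ToyNamed(mapped(*arrays), tuple(all_names))
    return wrapped


# "batch" appears in both arguments, so it is mapped jointly; "head" appears
# only in `q`, so `k` is broadcast across it. The trailing size-8 axis stays
# positional, so jnp.dot sees a pair of length-8 vectors on each call.
q = ToyNamed(jnp.ones((2, 3, 8)), ("batch", "head"))
k = ToyNamed(jnp.ones((2, 8)), ("batch",))
out = toy_nmap(jnp.dot)(q, k)
# out.names == ("batch", "head"); out.data has shape (2, 3), every entry 8.0
```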

Together, untag and nmap are the primary way to apply individual operations to specific axes of a NamedArray. tag can then be used on the result to rebind names to positional axes.
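In plain-JAX terms, the untag → nmap → tag pattern amounts to choosing which axis stays positional for the wrapped operation and which axes are mapped. The sketch below uses jax.vmap as an analogy only: axis 0 stands in for a hypothetical named axis "seq" that we want softmax to see (what untag("seq") would expose), and axis 1 stands in for a "batch" axis that is mapped over. `out_axes=1` puts the mapped axis back where it was, loosely mirroring re-binding names with tag.

```python
import jax
import jax.numpy as jnp

# Axis 0 ~ "seq" (kept positional for softmax), axis 1 ~ "batch" (mapped).
x = jnp.arange(6.0).reshape(3, 2)

# softmax sees one length-3 "seq" vector per call; "batch" is vectorized over.
out = jax.vmap(jax.nn.softmax, in_axes=1, out_axes=1)(x)
# out.shape == (3, 2), and each column (one per batch element) sums to 1.
```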

Within fun, any mapped-over axes will be accessible using standard JAX collective operations like psum, although doing this is usually unnecessary.
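The collective behavior can be seen with jax.vmap, where the axis name is supplied explicitly via `axis_name`; per the description above, nmap makes its mapped-over axis names available to collectives in the same way. Here `normalize` divides each element by the total over the mapped "batch" axis (an illustrative name). As noted above, such collectives are usually unnecessary: the same result could be obtained by making the axis positional and summing it directly.

```python
import jax
import jax.numpy as jnp


def normalize(v):
    # `v` is a single slice; "batch" is the axis name bound by the enclosing vmap.
    return v / jax.lax.psum(v, axis_name="batch")


x = jnp.array([1.0, 3.0, 4.0])
out = jax.vmap(normalize, axis_name="batch")(x)
# The total over "batch" is 8.0, so out == [0.125, 0.375, 0.5].
```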

Parameters:

fun – Function to vectorize by name. This can take arbitrary arguments (even non-JAX-arraylike arguments or “static” axis sizes), but must produce a PyTree of JAX ArrayLike outputs.

Returns:

An automatically-vectorized version of fun, which can optionally be called with NamedArrays (or NamedArrayViews) instead of ordinary arrays, and which will always return NamedArrays (or NamedArrayViews) for each of its output leaves. Any argument (or PyTree leaf of an argument) that is a NamedArray(View) will have its named axes vectorized over; fun will then be called with batch tracers corresponding to slices of the input array that are shaped like named_array_arg.positional_shape. Every axis name that appeared in any input will also appear in every output.
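Two parts of this return contract have direct counterparts in jax.vmap, shown below as an analogy rather than as penzai's code: the wrapped function is traced with values shaped like the per-slice (positional) shape, and an output that does not depend on any mapped input is still broadcast along the mapped axis. The latter mirrors the guarantee that every input axis name appears in every output. The `seen_shapes` list is a hypothetical instrumentation helper.

```python
import jax
import jax.numpy as jnp

seen_shapes = []


def fun(x, y):
    # Record the per-call shapes the function sees while being traced.
    seen_shapes.append((x.shape, y.shape))
    # The second output does not depend on the mapped argument `x` at all.
    return x + y, y * 2.0


xs = jnp.zeros((3, 4))  # leading axis stands in for a named axis
y = jnp.ones((4,))      # carries no mapped axis

a, b = jax.vmap(fun, in_axes=(0, None), out_axes=0)(xs, y)
# fun saw x.shape == (4,) and y.shape == (4,).
# a.shape == (3, 4), and b.shape == (3, 4): the unmapped output was broadcast.
```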