nmap
- penzai.core.named_axes.nmap(fun: Callable[..., Any]) -> Callable[..., Any]
Automatically vectorizes `fun` over named axes of NamedArray inputs.

`nmap` is a “named-axis vectorizing map”. It wraps an ordinary positional-axis-based function so that it accepts NamedArrays as input and produces NamedArrays as output, and vectorizes over all of the named axes, calling the original function with positionally-indexed slices corresponding to each argument’s `positional_shape`.

Unlike `jax.vmap`, the axes to vectorize over are inferred automatically from the named axes in the NamedArray / NamedArrayView, rather than being specified as part of the mapping transformation. Specifically, each axis name that appears in any of the arguments is vectorized over jointly across all arguments that include that axis name, and is then included as an axis name in the output.

To make an axis visible to `fun`, you can call `untag` on the argument and pass the axis name(s) of interest; `fun` will then see those axes as positional axes instead of mapping over them. `untag` and `nmap` are together the primary ways to apply individual operations to axes of a NamedArray. `tag` can then be used on the result to re-bind names to positional axes.

Within `fun`, any mapped-over axes will be accessible using standard JAX collective operations like `psum`, although doing this is usually unnecessary.

- Parameters:
fun – Function to vectorize by name. This can take arbitrary arguments (even non-JAX-arraylike arguments or “static” axis sizes), but must produce a PyTree of JAX ArrayLike outputs.
- Returns:
An automatically-vectorized version of `fun`, which can optionally be called with NamedArrays (or NamedArrayViews) instead of ordinary arrays, and which will always return NamedArrays (or NamedArrayViews) for each of its output leaves. Any argument (or PyTree leaf of an argument) that is a NamedArray(View) will have its named axes vectorized over; `fun` will then be called with batch tracers corresponding to slices of the input array that are shaped like `named_array_arg.positional_shape`. Every axis name that appeared in any input will also appear in every output.