NamedArray

class penzai.core.named_axes.NamedArray

Bases: NamedArrayBase, Struct

A multidimensional array with a combination of positional and named axes.

Conceptually, NamedArrays can have positional axes like an ordinary jax.Array, but can also have explicit named axes. Operations on NamedArrays always act only on the positional axes, and are vectorized (or “lifted”) over the named axes. To apply an operation to a named axis, first untag that named axis to make it positional, then apply the operation as normal. (This design is intentional: it avoids having to re-define separate “named” versions of every JAX or NumPy function.)

Internally, a NamedArray stores its array data in the data_array attribute. The positional axes always appear as a prefix of the data array’s shape; this means stacking NamedArrays along their first axis (i.e. axis=0) or iterating over this axis (e.g. with jax.lax.scan) will work correctly for NamedArrays with a nonempty positional shape. However, to avoid unnecessary transpositions, some operations on a NamedArray will produce a NamedArrayView instead of a NamedArray. (NamedArrayViews define the same methods as NamedArrays but have a more complex data representation.)

Operations on NamedArrays generally involve constructing positional views of the axes you need to operate on:

  • To run primitive operations (like jnp.sum or jax.nn.softmax) on a NamedArray, you can call untag to mark specific axes as positional, then use named_axes.nmap (or wrapped instance methods like .sum) to run the primitive positional operation over that locally-positional view. Any positional axes produced by the operation can be rebound to names using tag.

  • To slice an axis out of a NamedArray (e.g. as input to jax.lax.scan), you can first move the given axis names to the front of the array using .untag(...).with_positional_prefix(), then do a tree_map over the internal data_array and slice the first axis (which is what jax.lax.scan does internally).

  • To stack NamedArrays together along an axis (e.g. for the output of jax.lax.scan), you can just stack them normally, then use tag to give a name to the new axis. (If you want to stack NamedArrayViews, convert them to NamedArrays first using with_positional_prefix.)

Note that it is only safe to directly manipulate the unnamed prefix axes of data_array. Any operation on a named axis should first assign it a position using untag, and then either perform the operation inside nmap or move those positions to the prefix of data_array using with_positional_prefix.

The internal ordering of the named axes is part of the PyTree structure of a NamedArray or NamedArrayView, which means passing them through JAX operations sometimes requires re-ordering the named axes to ensure a consistent structure. To make one named array have the same PyTree structure as another, you can use first.order_like(second) or, for trees,

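import jax
from penzai import pz

# Reorder each named-array leaf of tree1 to match the corresponding leaf of tree2.
tree1_reordered = jax.tree_util.tree_map(
    lambda a, b: a.order_like(b), tree1, tree2,
    is_leaf=pz.nx.is_namedarray,
)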

You can also use canonicalize to reorder the internal axes in a canonical order; any named arrays with equivalent shapes will have the same PyTree structure after calling canonicalize.

Variables:
  • named_axes (collections.OrderedDict[AxisName, int]) – An ordered map from axis names to their lengths. The values must be a suffix of data_array.shape, and the keys give names to each of the suffix dimensions of the array. Usually this has the same length as data_array.shape, i.e. every axis is named and there are no positional axes.

  • data_array (jax.Array) – The underlying positional-indexed array.

Inherited Attributes

T

Name-vectorized version of array method T.

at

Helper property for index update functionality.

imag

Name-vectorized version of array method imag.

mT

Name-vectorized version of array method mT.

real

Name-vectorized version of array method real.

Methods

__init__(named_axes, data_array)

as_namedarrayview()

check_valid()

tag(*names)

untag(*axis_order)

unwrap(*names)

with_positional_prefix()

wrap(array, *names)

Wraps a positional array as a NamedArray.

Attributes

T

Name-vectorized version of array method T.

at

Helper property for index update functionality.

dtype

imag

Name-vectorized version of array method imag.

mT

Name-vectorized version of array method mT.

named_shape

positional_shape

real

Name-vectorized version of array method real.

named_axes

data_array

Inherited Methods

all([axis, out, keepdims, where])

Name-vectorized version of array method all.

any([axis, out, keepdims, where])

Name-vectorized version of array method any.

argmax([axis, out, keepdims])

Name-vectorized version of array method argmax.

argmin([axis, out, keepdims])

Name-vectorized version of array method argmin.

argpartition(kth[, axis])

Name-vectorized version of array method argpartition.

argsort([axis, kind, order, stable, descending])

Name-vectorized version of array method argsort.

astype(dtype[, copy, device])

Name-vectorized version of array method astype.

attributes_dict()

Constructs a dictionary with all of the fields in the class.

broadcast_like(other)

Broadcasts a named array to be compatible with another.

broadcast_to([positional_shape, named_shape])

Broadcasts a named array to a possibly-larger shape.

canonicalize()

Ensures that the named axes are stored in a canonical order.

choose(choices[, out, mode])

Name-vectorized version of array method choose.

clip([min, max])

Name-vectorized version of array method clip.

compress(condition[, axis, out, size, ...])

Name-vectorized version of array method compress.

conj()

Name-vectorized version of array method conj.

conjugate()

Name-vectorized version of array method conjugate.

cumprod([axis, dtype, out])

Name-vectorized version of array method cumprod.

cumsum([axis, dtype, out])

Name-vectorized version of array method cumsum.

diagonal([offset, axis1, axis2])

Name-vectorized version of array method diagonal.

dot(b, *[, precision, preferred_element_type])

Name-vectorized version of array method dot.

flatten([order])

Name-vectorized version of array method flatten.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

item(*args)

Name-vectorized version of array method item.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

max([axis, out, keepdims, initial, where])

Name-vectorized version of array method max.

mean([axis, dtype, out, keepdims, where])

Name-vectorized version of array method mean.

min([axis, out, keepdims, initial, where])

Name-vectorized version of array method min.

nonzero(*[, size, fill_value])

Name-vectorized version of array method nonzero.

order_as(*axis_order)

Ensures that the named axes are stored in this order, keeping them named.

order_like(other)

Ensures that this array's PyTree structure matches another array's.

prod([axis, dtype, out, keepdims, initial, ...])

Name-vectorized version of array method prod.

ptp([axis, out, keepdims])

Name-vectorized version of array method ptp.

ravel([order])

Name-vectorized version of array method ravel.

repeat(repeats[, axis, total_repeat_length])

Name-vectorized version of array method repeat.

reshape(*args[, order])

Name-vectorized version of array method reshape.

round([decimals, out])

Name-vectorized version of array method round.

searchsorted(v[, side, sorter, method])

Name-vectorized version of array method searchsorted.

select()

Wraps this struct in a selection, enabling functional-style mutations.

sort([axis, kind, order, stable, descending])

Name-vectorized version of array method sort.

squeeze([axis])

Name-vectorized version of array method squeeze.

std([axis, dtype, out, ddof, keepdims, ...])

Name-vectorized version of array method std.

sum([axis, dtype, out, keepdims, initial, ...])

Name-vectorized version of array method sum.

swapaxes(axis1, axis2)

Name-vectorized version of array method swapaxes.

tag_prefix(*axis_order)

Attaches names to the first positional axes in an array or view.

take(indices[, axis, out, mode, ...])

Name-vectorized version of array method take.

trace([offset, axis1, axis2, dtype, out])

Name-vectorized version of array method trace.

transpose(*args)

Name-vectorized version of array method transpose.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

untag_prefix(*axis_order)

Adds the requested axes to the front of the array's positional axes.

var([axis, dtype, out, ddof, keepdims, ...])

Name-vectorized version of array method var.

view([dtype, type])

Name-vectorized version of array method view.

classmethod wrap(array: jax.typing.ArrayLike, *names: AxisName) → NamedArray

Wraps a positional array as a NamedArray.

Parameters:
  • array – Array to wrap.

  • *names – Optional names for the axes of the array. If provided, must be the same length as the array’s shape. This is a convenience wrapper so that you can call wrap(array, "foo", "bar") instead of wrap(array).tag("foo", "bar").

Returns:

An equivalent NamedArray for the given array. If names is provided, the resulting array will have those names assigned to the corresponding axes. Otherwise, all of its axes will be positional and its named shape will be empty.