NamedArrayBase

class penzai.core.named_axes.NamedArrayBase

Bases: ABC

Base class for named arrays and their transposed views.

Methods

all([axis, out, keepdims, where])

Name-vectorized version of array method all.

any([axis, out, keepdims, where])

Name-vectorized version of array method any.

argmax([axis, out, keepdims])

Name-vectorized version of array method argmax.

argmin([axis, out, keepdims])

Name-vectorized version of array method argmin.

argpartition(kth[, axis])

Name-vectorized version of array method argpartition.

argsort([axis, kind, order, stable, descending])

Name-vectorized version of array method argsort.

as_namedarrayview()

Converts into a NamedArrayView, keeping positional axes.

astype(dtype[, copy, device])

Name-vectorized version of array method astype.

broadcast_like(other)

Broadcasts a named array to be compatible with another.

broadcast_to([positional_shape, named_shape])

Broadcasts a named array to a possibly-larger shape.

canonicalize()

Ensures that the named axes are stored in a canonical order.

check_valid()

Checks that the names in the array are correct.

choose(choices[, out, mode])

Name-vectorized version of array method choose.

clip([min, max])

Name-vectorized version of array method clip.

compress(condition[, axis, out, size, ...])

Name-vectorized version of array method compress.

conj()

Name-vectorized version of array method conj.

conjugate()

Name-vectorized version of array method conjugate.

cumprod([axis, dtype, out])

Name-vectorized version of array method cumprod.

cumsum([axis, dtype, out])

Name-vectorized version of array method cumsum.

diagonal([offset, axis1, axis2])

Name-vectorized version of array method diagonal.

dot(b, *[, precision, preferred_element_type])

Name-vectorized version of array method dot.

flatten([order])

Name-vectorized version of array method flatten.

item(*args)

Name-vectorized version of array method item.

max([axis, out, keepdims, initial, where])

Name-vectorized version of array method max.

mean([axis, dtype, out, keepdims, where])

Name-vectorized version of array method mean.

min([axis, out, keepdims, initial, where])

Name-vectorized version of array method min.

nonzero(*[, size, fill_value])

Name-vectorized version of array method nonzero.

order_as(*axis_order)

Ensures that the named axes are stored in this order, keeping them named.

order_like(other)

Ensures that this array's PyTree structure matches another array's.

prod([axis, dtype, out, keepdims, initial, ...])

Name-vectorized version of array method prod.

ptp([axis, out, keepdims])

Name-vectorized version of array method ptp.

ravel([order])

Name-vectorized version of array method ravel.

repeat(repeats[, axis, total_repeat_length])

Name-vectorized version of array method repeat.

reshape(*args[, order])

Name-vectorized version of array method reshape.

round([decimals, out])

Name-vectorized version of array method round.

searchsorted(v[, side, sorter, method])

Name-vectorized version of array method searchsorted.

sort([axis, kind, order, stable, descending])

Name-vectorized version of array method sort.

squeeze([axis])

Name-vectorized version of array method squeeze.

std([axis, dtype, out, ddof, keepdims, ...])

Name-vectorized version of array method std.

sum([axis, dtype, out, keepdims, initial, ...])

Name-vectorized version of array method sum.

swapaxes(axis1, axis2)

Name-vectorized version of array method swapaxes.

tag(*names)

Attaches names to the positional axes of an array or view.

tag_prefix(*axis_order)

Attaches names to the first positional axes in an array or view.

take(indices[, axis, out, mode, ...])

Name-vectorized version of array method take.

trace([offset, axis1, axis2, dtype, out])

Name-vectorized version of array method trace.

transpose(*args)

Name-vectorized version of array method transpose.

untag(*axis_order)

Produces a positional view of the requested axis names.

untag_prefix(*axis_order)

Adds the requested axes to the front of the array's positional axes.

unwrap(*names)

Unwraps this array, possibly mapping axis names to positional axes.

var([axis, dtype, out, ddof, keepdims, ...])

Name-vectorized version of array method var.

view([dtype, type])

Name-vectorized version of array method view.

with_positional_prefix()

Ensures a view is a NamedArray by moving positional axes.

Attributes

T

Name-vectorized version of array method T.

at

Helper property for index update functionality.

dtype

The dtype of the wrapped array.

imag

Name-vectorized version of array method imag.

mT

Name-vectorized version of array method mT.

named_shape

A mapping of axis names to their sizes.

positional_shape

A tuple of axis sizes for any anonymous axes.

real

Name-vectorized version of array method real.

Inherited Methods

__init__()

T

Name-vectorized version of array method T. Takes similar arguments as T but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__abs__()

Name-vectorized version of jax.Array.__abs__. Takes similar arguments as jax.Array.__abs__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__add__(b, /)

Name-vectorized version of jax.Array.__add__. Takes similar arguments as jax.Array.__add__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__and__(b, /)

Name-vectorized version of jax.Array.__and__. Takes similar arguments as jax.Array.__and__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__divmod__(y, /)

Name-vectorized version of jax.Array.__divmod__. Takes similar arguments as jax.Array.__divmod__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__floordiv__(b, /)

Name-vectorized version of jax.Array.__floordiv__. Takes similar arguments as jax.Array.__floordiv__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__ge__(b, /)

Name-vectorized version of jax.Array.__ge__. Takes similar arguments as jax.Array.__ge__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__getitem__(indexer) → NamedArray | NamedArrayView

Retrieves slices from an indexer.

NamedArray and NamedArrayView can be indexed in two different ways, depending on whether the axes you wish to index are positional or named.

To index positional axes, you can use ordinary Numpy-style indexing. Indexing operations will be automatically vectorized over all of the named axes. For instance, an embedding lookup could look something like:

embedding_table.untag("vocab")[token_ids]

which first untags the “vocab” named axis as positional, then indexes into that axis using another array (which can be a NamedArray or an ordinary array).

The result of positional indexing follows the combination of nmap and Numpy indexing semantics: positional axis ordering is determined by Numpy basic/advanced indexing rules, and any named axes in the input or in the slices will be jointly vectorized over.

To index named axes, you can use a dictionary mapping axis names to indices or slices. For instance, you can use

my_array[{"position": 1, "feature": pz.slice[2:5]}]

Here pz.slice[2:5] is syntactic sugar for slice(2, 5, None).
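As a rough illustration of how this kind of sugar can work, here is a pure-Python sketch. The names `SliceHelper` and `make_slice` are hypothetical and this is not Penzai's implementation; it only shows that an object whose `__getitem__` returns its argument hands back the `slice` object that Python builds for the bracket syntax.

```python
class SliceHelper:
    """Hypothetical stand-in for a helper such as pz.slice."""

    def __getitem__(self, key):
        # Python turns `obj[2:5]` into `obj.__getitem__(slice(2, 5, None))`,
        # so returning the key yields the slice object itself.
        return key


make_slice = SliceHelper()

assert make_slice[2:5] == slice(2, 5, None)
assert make_slice[::2] == slice(None, None, 2)
```

A helper like this is convenient because slice literals such as `2:5` are only valid inside square brackets, so they cannot be written directly as dictionary values.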

The semantics of dict-style indexing are based on Numpy indexing rules, except that they apply to the named axes instead of positional axes. In general, dict-style indexing will behave like positional indexing, where the requested axes are mapped to positional axes, then indexed, then mapped back to named axes where applicable. For instance,

# Slice "foo" in place, and index into "bar" and "baz".
named_array[{"foo": pz.slice[2:5], "bar": 1, "baz": indexer_array}]

will behave like

# Slice the first axis, and index into the next two. Then restore the
# axis name "foo".
(
    named_array.untag_prefix("foo", "bar", "baz")[2:5, 1, indexer_array, ...]
    .tag_prefix("foo")
)

Specifically:

  • Axis names that map to an integer will be indexed into, and those names will not appear in the output.

  • Axis names that map to a slice object (like pz.slice[2:5]) will be sliced into, and preserved in the output with a smaller size.

  • Axis names that map to None (or np.newaxis) must not appear in the input array. These axis names will be introduced and will have size 1.

  • Axis names that map to a Numpy or JAX array with non-empty (positional) shape and integer dtype will follow Numpy advanced indexing rules. All such arrays will be broadcast together and iterated over as one, and interpreted as a sequence of indices into each named axis. The result will have new positional axes at the front (matching the shapes of the advanced index arrays), followed by the existing positional axes of the input (if any).

  • Axis names that map to a Penzai named array will be vectorized over using nmap rules: those names will be vectorized over jointly with the array if they are present, and will be introduced into the result if they are not present.

The resulting array’s positional shape will first include any new axes introduced by advanced indexing, followed by the existing positional axes of the input array (if any). The array’s named shape will be the input array’s named shape, minus any axes that were indexed into (without using a slice), plus any names that were introduced using None/np.newaxis, plus the union of all names used in named array indices that aren’t already present.
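The named-shape rule above can be sketched as a small pure-Python model. This is an illustration under stated assumptions, not Penzai code: `indexed_named_shape` is a hypothetical helper, shapes are plain dicts, and a named-array index is represented only by its named shape (a dict). Positional advanced-index arrays are not modeled.

```python
def indexed_named_shape(named_shape, index_dict):
    """Models the named shape that results from dict-style indexing.

    Supported index values (a modeling assumption): None, slice, int, or a
    dict standing in for the named shape of a named-array index.
    """
    result = dict(named_shape)
    introduced = {}
    for name, index in index_dict.items():
        if index is None:
            # None / np.newaxis introduces a new size-1 axis.
            if name in named_shape:
                raise ValueError(f"axis {name!r} already exists")
            result[name] = 1
            continue
        if name not in named_shape:
            raise KeyError(f"axis {name!r} is not in the array")
        if isinstance(index, slice):
            # Sliced axes are kept, with a (possibly) smaller size.
            result[name] = len(range(*index.indices(named_shape[name])))
        elif isinstance(index, dict):
            # Named-array index: the axis is consumed, its names are added.
            del result[name]
            introduced.update(index)
        elif isinstance(index, int):
            # Integer index: the axis disappears from the output.
            del result[name]
        else:
            raise TypeError(f"unsupported index for {name!r}: {index!r}")
    for name, size in introduced.items():
        if result.setdefault(name, size) != size:
            raise ValueError(f"conflicting sizes for axis {name!r}")
    return result


shape = {"position": 10, "feature": 8}
print(indexed_named_shape(shape, {"position": 1, "feature": slice(2, 5)}))
```

With the shape above, indexing "position" with an integer removes it, while slicing "feature" with `slice(2, 5)` keeps it with size 3.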

Note that Numpy-style advanced indexing can be difficult to reason about due to the axis ordering of positional axes. We recommend indexing using either integers or NamedArrays with an empty positional shape, which will always introduce named axes instead of positional ones.

Parameters:

indexer – Either a normal Numpy-style indexer into the positional axes, or a mapping from a subset of axis names to the indices or slices to retrieve along those axes.

Returns:

A slice of the array.

__gt__(b, /)

Name-vectorized version of jax.Array.__gt__. Takes similar arguments as jax.Array.__gt__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__invert__()

Name-vectorized version of jax.Array.__invert__. Takes similar arguments as jax.Array.__invert__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__le__(b, /)

Name-vectorized version of jax.Array.__le__. Takes similar arguments as jax.Array.__le__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__lshift__(b, /)

Name-vectorized version of jax.Array.__lshift__. Takes similar arguments as jax.Array.__lshift__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__lt__(b, /)

Name-vectorized version of jax.Array.__lt__. Takes similar arguments as jax.Array.__lt__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__mod__(b, /)

Name-vectorized version of jax.Array.__mod__. Takes similar arguments as jax.Array.__mod__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__mul__(b, /)

Name-vectorized version of jax.Array.__mul__. Takes similar arguments as jax.Array.__mul__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__ne__(b, /)

Name-vectorized version of jax.Array.__ne__. Takes similar arguments as jax.Array.__ne__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__neg__()

Name-vectorized version of jax.Array.__neg__. Takes similar arguments as jax.Array.__neg__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__or__(b, /)

Name-vectorized version of jax.Array.__or__. Takes similar arguments as jax.Array.__or__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__pos__()

Name-vectorized version of jax.Array.__pos__. Takes similar arguments as jax.Array.__pos__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__pow__(b, /)

Name-vectorized version of jax.Array.__pow__. Takes similar arguments as jax.Array.__pow__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__radd__(y)

Name-vectorized version of jax.Array.__radd__. Takes similar arguments as jax.Array.__radd__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rand__(y)

Name-vectorized version of jax.Array.__rand__. Takes similar arguments as jax.Array.__rand__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rdivmod__(y)

Name-vectorized version of jax.Array.__rdivmod__. Takes similar arguments as jax.Array.__rdivmod__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rfloordiv__(y)

Name-vectorized version of jax.Array.__rfloordiv__. Takes similar arguments as jax.Array.__rfloordiv__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rlshift__(y)

Name-vectorized version of jax.Array.__rlshift__. Takes similar arguments as jax.Array.__rlshift__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rmod__(y)

Name-vectorized version of jax.Array.__rmod__. Takes similar arguments as jax.Array.__rmod__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rmul__(y)

Name-vectorized version of jax.Array.__rmul__. Takes similar arguments as jax.Array.__rmul__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__ror__(y)

Name-vectorized version of jax.Array.__ror__. Takes similar arguments as jax.Array.__ror__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rpow__(y)

Name-vectorized version of jax.Array.__rpow__. Takes similar arguments as jax.Array.__rpow__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rrshift__(y)

Name-vectorized version of jax.Array.__rrshift__. Takes similar arguments as jax.Array.__rrshift__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rshift__(b, /)

Name-vectorized version of jax.Array.__rshift__. Takes similar arguments as jax.Array.__rshift__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rsub__(y)

Name-vectorized version of jax.Array.__rsub__. Takes similar arguments as jax.Array.__rsub__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rtruediv__(y)

Name-vectorized version of jax.Array.__rtruediv__. Takes similar arguments as jax.Array.__rtruediv__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__rxor__(y)

Name-vectorized version of jax.Array.__rxor__. Takes similar arguments as jax.Array.__rxor__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__sub__(b, /)

Name-vectorized version of jax.Array.__sub__. Takes similar arguments as jax.Array.__sub__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__treescope_ndarray_adapter__()

Treescope handler for named arrays.

__treescope_repr__(path: str | None, subtree_renderer: Any)

Treescope handler for named arrays.

__truediv__(b, /)

Name-vectorized version of jax.Array.__truediv__. Takes similar arguments as jax.Array.__truediv__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

__xor__(b, /)

Name-vectorized version of jax.Array.__xor__. Takes similar arguments as jax.Array.__xor__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

all(axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) → Array

Name-vectorized version of array method all. Takes similar arguments as all but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

any(axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) → Array

Name-vectorized version of array method any. Takes similar arguments as any but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

argmax(axis: int | None = None, out: None = None, keepdims: bool | None = None) → Array

Name-vectorized version of array method argmax. Takes similar arguments as argmax but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

argmin(axis: int | None = None, out: None = None, keepdims: bool | None = None) → Array

Name-vectorized version of array method argmin. Takes similar arguments as argmin but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

argpartition(kth: int, axis: int = -1) → Array

Name-vectorized version of array method argpartition. Takes similar arguments as argpartition but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

argsort(axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False) → Array

Name-vectorized version of array method argsort. Takes similar arguments as argsort but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

abstract as_namedarrayview() → NamedArrayView

Converts into a NamedArrayView, keeping positional axes.

This function is usually not necessary for ordinary named-array manipulation, since NamedArray and NamedArrayView define the same methods. However, it can be useful for simplifying library code that wishes to access the fields of NamedArrayView directly, or handle arbitrary named array objects without handling each case separately.

Converting a NamedArray to a NamedArrayView never involves any device computations. (The reverse is not true: converting a view back to a NamedArray may require a device computation.)

Returns:

An equivalent NamedArrayView for this array if it isn’t one already.

astype(dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) → Array

Name-vectorized version of array method astype. Takes similar arguments as astype but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

property at: _IndexUpdateHelper

Helper property for index update functionality.

Lifts the jax.Array.at[...] syntax to also work for Penzai NamedArrays.

Similar to direct indexing, NamedArray.at[...] can be indexed in two different ways, depending on whether the axes you wish to index are positional or named.

When indexing positional axes, this operation follows the combination of nmap and jax.Array.at[...] semantics. In particular,

named_array.at[index].set(value)

is equivalent to

nmap(lambda arr, i, v: arr.at[i].set(v))(named_array, index, value)

The resulting array will have the same positional shape as the input array, and will have a named shape that is the union of the named shape of the input array and the names used in the indexer.

The semantics of dict-style indexing are similar, except that they apply to the named axes instead of the positional ones. In general, dict-style indexing will behave like positional indexing, where the requested axes are mapped to positional axes, then indexed, then mapped back to named axes where applicable. In other words,

# Update part of "foo" in place, and index into "bar" and "baz".
named_array.at[{
    "foo": pz.slice[2:5], "bar": 1, "baz": indexer_array
}].set(value)

will behave like

result_structure = jax.eval_shape(lambda: named_array[{
    "foo": pz.slice[2:5], "bar": 1, "baz": indexer_array
}])

# Update positionally, but expect the name "foo" in `value`, and make
# sure it maps correctly to "foo" in the input.
(
    named_array.untag_prefix("foo", "bar", "baz")
    .at[2:5, 1, indexer_array, ...]
    .set(value.broadcast_like(result_structure).untag_prefix("foo"))
    .tag_prefix("foo", "bar", "baz")
)

Specifically:

  • Axis names that are sliced by index_dict (i.e. mapped to a slice object) can appear in the value. These will be used to update the corresponding slices of the array.

  • Axis names that do not appear in index_dict will be broadcast against the corresponding axis names in value if they exist, following the semantics of nmap.

  • The positional axes of value will be broadcast against the result of slicing the input array. This means that the suffix axes of value will correspond to positional axes of the input array, and the prefix axes will correspond to new axes introduced by advanced Numpy indexing. Often, the input named_array will have no positional axes, in which case the positional axes of value will be broadcast against the positional axes of the indexer arrays.

Note that, in order to update multiple positions along the same axis, you will need to use Numpy-style advanced indexing, by indexing with an array with a positional axis. For instance, to update indices 2 and 4 along axis “foo”, you can do

# Option 1: dict-style
named_array.at[{ "foo": jnp.array([2, 4]) }].set( jnp.array([100, 101]) )

# Option 2: positional-style
(
    named_array.untag("foo").at[jnp.array([2, 4])]
    .set(jnp.array([100, 101])).tag("foo")
)

The following will instead create a new axis “bar”, with one update at each index along that axis:

# Option 1: dict-style
named_array.at[{
    "foo": pz.nx.wrap(jnp.array([2, 4])).tag("bar")
}].set(
    pz.nx.wrap(jnp.array([100, 101])).tag("bar")
)

# Option 2: positional-style
named_array.untag("foo").at[
    pz.nx.wrap(jnp.array([2, 4])).tag("bar")
].set(
    pz.nx.wrap(jnp.array([100, 101])).tag("bar")
).tag("foo")

The reason for this difference is that the named axis “bar” is vectorized over, and the behavior needs to be consistent regardless of whether “bar” appears in named_array or not.

broadcast_like(other: NamedArrayBase | jax.typing.ArrayLike) → NamedArrayBase

Broadcasts a named array to be compatible with another.

Parameters:

other – Another named array (per the signature, an ordinary array-like value is also accepted).

Returns:

A named array that has the same positional and named shapes as other (although it may also include extra named axes).

broadcast_to(positional_shape: Sequence[int] = (), named_shape: Mapping[AxisName, int] | None = None) → NamedArrayBase

Broadcasts a named array to a possibly-larger shape.

Parameters:

  • positional_shape – Desired positional shape for the array. Will be broadcast using numpy broadcasting rules.

  • named_shape – Desired named shape for the array. Will be broadcast using nmap-style vectorizing rules (e.g. new named axes will be introduced if missing, but length-1 axes will not be broadcast).

Returns:

A named array that has the given positional and named shapes. Note that if this array has axis names that are not in named_shape, these will be preserved in the answer as well.
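The two broadcasting rules described above can be sketched as a pure-Python shape model. This is only an illustration: `broadcast_positional_shape` and `broadcast_named_shape` are hypothetical helpers operating on plain tuples and dicts, not Penzai functions, and they ignore the actual array data.

```python
def broadcast_positional_shape(shape, target):
    """Numpy-style rule: align trailing axes; size-1 axes may be stretched."""
    shape, target = tuple(shape), tuple(target)
    if len(shape) > len(target):
        raise ValueError(f"cannot broadcast {shape} to {target}")
    padded = (1,) * (len(target) - len(shape)) + shape
    for have, want in zip(padded, target):
        if have != want and have != 1:
            raise ValueError(f"cannot broadcast {shape} to {target}")
    return target


def broadcast_named_shape(named_shape, target_named):
    """nmap-style rule: missing names are introduced; sizes must match exactly."""
    result = dict(named_shape)
    for name, size in target_named.items():
        if name in result and result[name] != size:
            # Unlike positional axes, length-1 named axes are not stretched.
            raise ValueError(f"axis {name!r}: {result[name]} != {size}")
        result[name] = size
    # Names missing from target_named stay in the result unchanged.
    return result


print(broadcast_positional_shape((1, 3), (2, 4, 3)))
print(broadcast_named_shape({"a": 2, "c": 5}, {"a": 2, "b": 7}))
```

In this model the named axis "c" survives even though it is absent from the requested named shape, mirroring the note that extra axis names are preserved in the answer.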

canonicalize() → NamedArray

Ensures that the named axes are stored in a canonical order.

Returns:

Equivalent NamedArray whose data array contains the positional axes followed by the named axes in sorted order.

abstract check_valid() → None

Checks that the names in the array are correct.

choose(choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') → Array

Name-vectorized version of array method choose. Takes similar arguments as choose but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

clip(min: ArrayLike | None = None, max: ArrayLike | None = None) → Array

Name-vectorized version of array method clip. Takes similar arguments as clip but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

compress(condition: ArrayLike, axis: int | None = None, *, out: None = None, size: int | None = None, fill_value: ArrayLike = 0) → Array

Name-vectorized version of array method compress. Takes similar arguments as compress but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

conj() → Array

Name-vectorized version of array method conj. Takes similar arguments as conj but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

conjugate() → Array

Name-vectorized version of array method conjugate. Takes similar arguments as conjugate but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

cumprod(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) → Array

Name-vectorized version of array method cumprod. Takes similar arguments as cumprod but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

cumsum(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) → Array

Name-vectorized version of array method cumsum. Takes similar arguments as cumsum but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

diagonal(offset: int = 0, axis1: int = 0, axis2: int = 1) → Array

Name-vectorized version of array method diagonal. Takes similar arguments as diagonal but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

dot(b: ArrayLike, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) → Array

Name-vectorized version of array method dot. Takes similar arguments as dot but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

abstract property dtype: np.dtype

The dtype of the wrapped array.

flatten(order: str = 'C') → Array

Name-vectorized version of array method flatten. Takes similar arguments as flatten but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

imag

Name-vectorized version of array method imag. Takes similar arguments as imag but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

item(*args) → bool | int | float | complex

Name-vectorized version of array method item. Takes similar arguments as item but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

mT

Name-vectorized version of array method mT. Takes similar arguments as mT but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

max(axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) → Array

Name-vectorized version of array method max. Takes similar arguments as max but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

mean(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) → Array

Name-vectorized version of array method mean. Takes similar arguments as mean but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

min(axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) → Array

Name-vectorized version of array method min. Takes similar arguments as min but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

abstract property named_shape: Mapping[AxisName, int]

A mapping of axis names to their sizes.

nonzero(*, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) → tuple[Array, ...]

Name-vectorized version of array method nonzero. Takes similar arguments as nonzero but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

order_as(*axis_order: AxisName) → NamedArray

Ensures that the named axes are stored in this order, keeping them named.

This function can be used if it is important for the axis names to appear in a consistent order, e.g. to ensure that two NamedArray instances have exactly the same PyTree structure.

If you want a canonical ordering for a named array that doesn’t involve knowing all the axis names in advance, you could do something like array.order_as(*sorted(array.named_shape.keys())).

See also order_like.

Parameters:

*axis_order – Axis names in the order they should appear in the data array. Must be a permutation of all of the axis names in this array.

Returns:

Equivalent NamedArray whose data array contains the positional axes followed by the named axes in the given order.
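The reordering described above amounts to a permutation of the stored named axes. The following pure-Python sketch models only that bookkeeping: `order_as_permutation` is a hypothetical helper that works on tuples of axis names rather than on real arrays, and it is not how Penzai implements `order_as`.

```python
def order_as_permutation(stored_names, *axis_order):
    """Returns, for each requested name, its index in the stored axis order.

    Raises ValueError unless axis_order is a permutation of stored_names.
    """
    stored_names = tuple(stored_names)
    if len(axis_order) != len(stored_names) or set(axis_order) != set(stored_names):
        raise ValueError(
            f"{axis_order} is not a permutation of {stored_names}"
        )
    return tuple(stored_names.index(name) for name in axis_order)


stored = ("features", "batch", "seq")

# Explicit ordering.
print(order_as_permutation(stored, "batch", "seq", "features"))

# Canonical (sorted) ordering without listing the names by hand.
print(order_as_permutation(stored, *sorted(stored)))
```

Applying the returned permutation to the named part of the data array (for example with a transpose) would give the layout that `order_as` promises; the sorted variant mirrors the `array.order_as(*sorted(array.named_shape.keys()))` idiom mentioned above.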

order_like(other: NamedArray | NamedArrayView) → NamedArray | NamedArrayView

Ensures that this array’s PyTree structure matches another array’s.

This can be used to ensure that one named array has the same PyTree structure as another, so that the two can be jointly processed by non-namedarray-aware tree functions (e.g. jax.tree_util functions, jax.lax.cond, jax.jvp, etc).

To ensure compatibility of entire PyTrees, you can use something like:

jax.tree_util.tree_map(
    lambda a, b: a.order_like(b), tree1, tree2,
    is_leaf=pz.nx.is_namedarray,
)

Parameters:

other – Another named array or named array view. Must have the same set of named axes as self. If other is a NamedArrayView, other must also have the same number of positional axes.

Returns:

A new NamedArray or NamedArrayView that has the content of self but is possibly transposed to have the axes appear in the same order as other in the data array. If the arrays have the same named and positional shapes, the result will have the same PyTree structure as other.

abstract property positional_shape: tuple[int, ...]

A tuple of axis sizes for any anonymous axes.

prod(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) → Array

Name-vectorized version of array method prod. Takes similar arguments as prod but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

ptp(axis: Axis = None, out: None = None, keepdims: bool = False) → Array

Name-vectorized version of array method ptp. Takes similar arguments as ptp but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

ravel(order: str = 'C') → Array

Name-vectorized version of array method ravel. Takes similar arguments as ravel but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

real

Name-vectorized version of array method real. Takes similar arguments as real but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

repeat(repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None) → Array

Name-vectorized version of array method repeat. Takes similar arguments as repeat but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

reshape(*args: Any, order: str = 'C') → Array

Name-vectorized version of array method reshape. Takes similar arguments as reshape but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

round(decimals: int = 0, out: None = None) → Array

Name-vectorized version of array method round. Takes similar arguments as round but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

searchsorted(v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') → Array

Name-vectorized version of array method searchsorted. Takes similar arguments as searchsorted but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

sort(axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False) → Array

Name-vectorized version of array method sort. Takes similar arguments as sort but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

squeeze(axis: int | Sequence[int] | None = None) Array[source]#

Name-vectorized version of array method squeeze. Takes similar arguments as squeeze but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

std(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) Array[source]#

Name-vectorized version of array method std. Takes similar arguments as std but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

sum(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) Array[source]#

Name-vectorized version of array method sum. Takes similar arguments as sum but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

swapaxes(axis1: int, axis2: int) Array[source]#

Name-vectorized version of array method swapaxes. Takes similar arguments as swapaxes but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

abstract tag(*names) NamedArray[source]#

Attaches names to the positional axes of an array or view.

Parameters:

*names – Axis names to assign to each positional axis in the array or view. Must be the same length as positional_shape; if you only want to tag a subset of axes, use tag_prefix instead.

Raises:

ValueError – If the names are invalid, or if they aren’t the same length as positional_shape.

Returns:

A NamedArray with the given names assigned to the positional axes, and no remaining positional axes.

tag_prefix(*axis_order: AxisName) NamedArray | NamedArrayView[source]#

Attaches names to the first positional axes in an array or view.

This is a version of tag that allows you to name only a subset of the array’s positional axes.

Parameters:

*axis_order – Axis names to assign to the first positional axes of the array or view, in order.

Returns:

A NamedArray or view with the given names assigned to the first positional axes, and whose positional shape includes only the suffix of axes that have not been given names.

take(indices: ArrayLike, axis: int | None = None, out: None = None, mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, fill_value: StaticScalar | None = None) Array[source]#

Name-vectorized version of array method take. Takes similar arguments as take but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

trace(offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) Array[source]#

Name-vectorized version of array method trace. Takes similar arguments as trace but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

transpose(*args: Any) Array[source]#

Name-vectorized version of array method transpose. Takes similar arguments as transpose but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

abstract untag(*axis_order: AxisName) NamedArray | NamedArrayView[source]#

Produces a positional view of the requested axis names.

untag can only be called on a NamedArray or NamedArrayView that does not have any positional axes. It produces a new NamedArrayView where the axes with the requested names (the arguments to this function) are now treated as positional in the given order.

If you want to use untag on an array that already has positional axes, you can use untag_prefix instead.

Parameters:

*axis_order – Axis names to make positional, in the order they should appear in the positional view.

Raises:

ValueError – If this array already has positional axes, or if the provided axis ordering is not valid.

Returns:

A view with the given axes treated as positional for the purposes of later calls to apply, nmap, or with_positional_prefix. If passed an empty axis order, returns an ordinary NamedArray with no positional axes.

untag_prefix(*axis_order: AxisName) NamedArray | NamedArrayView[source]#

Adds the requested axes to the front of the array’s positional axes.

This is a version of untag that can be called on NamedArrays or NamedArrayViews that already have positional axes.

Parameters:

*axis_order – Axis names to make positional, in the order they should appear in the positional view.

Returns:

A view with the given axes treated as positional, followed by the existing positional axes.

abstract unwrap(*names: AxisName) jax.Array[source]#

Unwraps this array, possibly mapping axis names to positional axes.

unwrap can be called either on arrays with no named axes, or on arrays with no positional axes (in which case names should be a permutation of the array’s axis names).

Parameters:

*names – Sequence of axis names to map to positional axes, if this array has named axes. Shorthand for untag(*names).unwrap().

Returns:

An equivalent ordinary positional array.

Raises:

ValueError – If the array has a mixture of positional and named axes, or if the names do not match the named axes.

var(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) Array[source]#

Name-vectorized version of array method var. Takes similar arguments as var but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

view(dtype: DTypeLike | None = None, type: None = None) Array[source]#

Name-vectorized version of array method view. Takes similar arguments as view but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.

abstract with_positional_prefix() NamedArray[source]#

Ensures a view is a NamedArray by moving positional axes.

The resulting NamedArray has the same named and positional shapes as this object, but the data array may be transposed so that all the positional axes are in the front. This makes it possible to manipulate those prefix axes safely using jax.tree_util or scan/map over them using JAX control flow primitives.

Returns:

An equivalent NamedArray for this view, or the original NamedArray if it already was one.