NamedArrayBase
- class penzai.core.named_axes.NamedArrayBase
Bases: ABC
Base class for named arrays and their transposed views.
Methods
all([axis, out, keepdims, where]): Name-vectorized version of array method all.
any([axis, out, keepdims, where]): Name-vectorized version of array method any.
argmax([axis, out, keepdims]): Name-vectorized version of array method argmax.
argmin([axis, out, keepdims]): Name-vectorized version of array method argmin.
argpartition(kth[, axis]): Name-vectorized version of array method argpartition.
argsort([axis, kind, order, stable, descending]): Name-vectorized version of array method argsort.
as_namedarrayview(): Converts into a NamedArrayView, keeping positional axes.
astype(dtype[, copy, device]): Name-vectorized version of array method astype.
broadcast_like(other): Broadcasts a named array to be compatible with another.
broadcast_to([positional_shape, named_shape]): Broadcasts a named array to a possibly-larger shape.
canonicalize(): Ensures that the named axes are stored in a canonical order.
check_valid(): Checks that the names in the array are correct.
choose(choices[, out, mode]): Name-vectorized version of array method choose.
clip([min, max]): Name-vectorized version of array method clip.
compress(condition[, axis, out, size, ...]): Name-vectorized version of array method compress.
conj(): Name-vectorized version of array method conj.
conjugate(): Name-vectorized version of array method conjugate.
cumprod([axis, dtype, out]): Name-vectorized version of array method cumprod.
cumsum([axis, dtype, out]): Name-vectorized version of array method cumsum.
diagonal([offset, axis1, axis2]): Name-vectorized version of array method diagonal.
dot(b, *[, precision, preferred_element_type]): Name-vectorized version of array method dot.
flatten([order]): Name-vectorized version of array method flatten.
item(*args): Name-vectorized version of array method item.
max([axis, out, keepdims, initial, where]): Name-vectorized version of array method max.
mean([axis, dtype, out, keepdims, where]): Name-vectorized version of array method mean.
min([axis, out, keepdims, initial, where]): Name-vectorized version of array method min.
nonzero(*[, size, fill_value]): Name-vectorized version of array method nonzero.
order_as(*axis_order): Ensures that the named axes are stored in this order, keeping them named.
order_like(other): Ensures that this array's PyTree structure matches another array's.
prod([axis, dtype, out, keepdims, initial, ...]): Name-vectorized version of array method prod.
ptp([axis, out, keepdims]): Name-vectorized version of array method ptp.
ravel([order]): Name-vectorized version of array method ravel.
repeat(repeats[, axis, total_repeat_length]): Name-vectorized version of array method repeat.
reshape(*args[, order]): Name-vectorized version of array method reshape.
round([decimals, out]): Name-vectorized version of array method round.
searchsorted(v[, side, sorter, method]): Name-vectorized version of array method searchsorted.
sort([axis, kind, order, stable, descending]): Name-vectorized version of array method sort.
squeeze([axis]): Name-vectorized version of array method squeeze.
std([axis, dtype, out, ddof, keepdims, ...]): Name-vectorized version of array method std.
sum([axis, dtype, out, keepdims, initial, ...]): Name-vectorized version of array method sum.
swapaxes(axis1, axis2): Name-vectorized version of array method swapaxes.
tag(*names): Attaches names to the positional axes of an array or view.
tag_prefix(*axis_order): Attaches names to the first positional axes in an array or view.
take(indices[, axis, out, mode, ...]): Name-vectorized version of array method take.
trace([offset, axis1, axis2, dtype, out]): Name-vectorized version of array method trace.
transpose(*args): Name-vectorized version of array method transpose.
untag(*axis_order): Produces a positional view of the requested axis names.
untag_prefix(*axis_order): Adds the requested axes to the front of the array's positional axes.
unwrap(*names): Unwraps this array, possibly mapping axis names to positional axes.
var([axis, dtype, out, ddof, keepdims, ...]): Name-vectorized version of array method var.
view([dtype, type]): Name-vectorized version of array method view.
with_positional_prefix(): Ensures a view is a NamedArray by moving positional axes.
Attributes
T: Name-vectorized version of array method T.
at: Helper property for index update functionality.
dtype: The dtype of the wrapped array.
imag: Name-vectorized version of array method imag.
mT: Name-vectorized version of array method mT.
named_shape: A mapping of axis names to their sizes.
positional_shape: A tuple of axis sizes for any anonymous axes.
real: Name-vectorized version of array method real.
Inherited Methods
__init__()
- T
Name-vectorized version of array method T. Takes similar arguments as T but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __abs__()
Name-vectorized version of jax.Array.__abs__. Takes similar arguments as jax.Array.__abs__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __add__(b, /)
Name-vectorized version of jax.Array.__add__. Takes similar arguments as jax.Array.__add__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __and__(b, /)
Name-vectorized version of jax.Array.__and__. Takes similar arguments as jax.Array.__and__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __divmod__(y, /)
Name-vectorized version of jax.Array.__divmod__. Takes similar arguments as jax.Array.__divmod__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __floordiv__(b, /)
Name-vectorized version of jax.Array.__floordiv__. Takes similar arguments as jax.Array.__floordiv__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __ge__(b, /)
Name-vectorized version of jax.Array.__ge__. Takes similar arguments as jax.Array.__ge__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __getitem__(indexer) → NamedArray | NamedArrayView
Retrieves slices of the array selected by an indexer.
NamedArray and NamedArrayView can be indexed in two different ways, depending on whether the axes you wish to index are positional or named.
To index positional axes, you can use ordinary Numpy-style indexing. Indexing operations will be automatically vectorized over all of the named axes. For instance, an embedding lookup could look something like:
    embedding_table.untag("vocab")[token_ids]
which first untags the “vocab” named axis as positional, then indexes into that axis using another array (which can be a NamedArray or an ordinary array).
The result of positional indexing follows the combination of nmap and Numpy indexing semantics: positional axis ordering is determined by Numpy basic/advanced indexing rules, and any named axes in the input or in the slices will be jointly vectorized over.
To index named axes, you can use a dictionary mapping axis names to indices or slices. For instance, you can use
    my_array[{"position": 1, "feature": pz.slice[2:5]}]
Here pz.slice[2:5] is syntactic sugar for slice(2, 5, None).
The semantics of dict-style indexing are based on Numpy indexing rules, except that they apply to the named axes instead of positional axes. In general, dict-style indexing will behave like positional indexing, where the requested axes are mapped to positional axes, then indexed, then mapped back to named axes where applicable. For instance,
    # Slice "foo" in place, and index into "bar" and "baz".
    named_array[{"foo": pz.slice[2:5], "bar": 1, "baz": indexer_array}]
will behave like
    # Slice the first axis, and index into the next two. Then restore the
    # axis name "foo".
    (
        named_array.untag_prefix("foo", "bar", "baz")[2:5, 1, indexer_array, ...]
        .tag_prefix("foo")
    )
Specifically:
- Axis names that map to an integer will be indexed into, and those names will not appear in the output.
- Axis names that map to a slice object (like pz.slice[2:5]) will be sliced into, and preserved in the output with a smaller size.
- Axis names that map to None (or np.newaxis) must not appear in the input array. These axis names will be introduced and will have size 1.
- Axis names that map to a Numpy or JAX array with non-empty (positional) shape and integer dtype will follow Numpy advanced indexing rules. All such arrays will be broadcast together and iterated over as one, and interpreted as a sequence of indices into each named axis. The result will have new positional axes at the front (matching the shapes of the advanced index arrays), followed by the existing positional axes of the input (if any).
- Axis names that map to a Penzai named array will be vectorized over using nmap rules: the index array's own axis names will be vectorized over jointly with the input if the input has them, and will be introduced into the result if it does not.
The resulting array’s positional shape will first include any new axes introduced by advanced indexing, followed by the existing positional axes of the input array (if any). The array’s named shape will be the input array’s named shape, minus any axes that were indexed into (without using a slice), plus any names that were introduced using None/np.newaxis, plus the union of all names used in named array indices that aren’t already present.
Note that Numpy-style advanced indexing can be difficult to reason about due to the axis ordering of positional axes. We recommend indexing using either integers or NamedArrays with an empty positional shape, which will always introduce named axes instead of positional ones.
- Parameters:
indexer – Either a normal Numpy-style indexer into the positional axes, or a mapping from a subset of axis names to the indices or slices they should select.
- Returns:
A slice of the array.
- __gt__(b, /)
Name-vectorized version of jax.Array.__gt__. Takes similar arguments as jax.Array.__gt__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __invert__()
Name-vectorized version of jax.Array.__invert__. Takes similar arguments as jax.Array.__invert__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __le__(b, /)
Name-vectorized version of jax.Array.__le__. Takes similar arguments as jax.Array.__le__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __lshift__(b, /)
Name-vectorized version of jax.Array.__lshift__. Takes similar arguments as jax.Array.__lshift__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __lt__(b, /)
Name-vectorized version of jax.Array.__lt__. Takes similar arguments as jax.Array.__lt__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __mod__(b, /)
Name-vectorized version of jax.Array.__mod__. Takes similar arguments as jax.Array.__mod__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __mul__(b, /)
Name-vectorized version of jax.Array.__mul__. Takes similar arguments as jax.Array.__mul__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __ne__(b, /)
Name-vectorized version of jax.Array.__ne__. Takes similar arguments as jax.Array.__ne__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __neg__()
Name-vectorized version of jax.Array.__neg__. Takes similar arguments as jax.Array.__neg__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __or__(b, /)
Name-vectorized version of jax.Array.__or__. Takes similar arguments as jax.Array.__or__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __pos__()
Name-vectorized version of jax.Array.__pos__. Takes similar arguments as jax.Array.__pos__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __pow__(b, /)
Name-vectorized version of jax.Array.__pow__. Takes similar arguments as jax.Array.__pow__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __radd__(y)
Name-vectorized version of jax.Array.__radd__. Takes similar arguments as jax.Array.__radd__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rand__(y)
Name-vectorized version of jax.Array.__rand__. Takes similar arguments as jax.Array.__rand__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rdivmod__(y)
Name-vectorized version of jax.Array.__rdivmod__. Takes similar arguments as jax.Array.__rdivmod__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rfloordiv__(y)
Name-vectorized version of jax.Array.__rfloordiv__. Takes similar arguments as jax.Array.__rfloordiv__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rlshift__(y)
Name-vectorized version of jax.Array.__rlshift__. Takes similar arguments as jax.Array.__rlshift__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rmod__(y)
Name-vectorized version of jax.Array.__rmod__. Takes similar arguments as jax.Array.__rmod__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rmul__(y)
Name-vectorized version of jax.Array.__rmul__. Takes similar arguments as jax.Array.__rmul__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __ror__(y)
Name-vectorized version of jax.Array.__ror__. Takes similar arguments as jax.Array.__ror__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rpow__(y)
Name-vectorized version of jax.Array.__rpow__. Takes similar arguments as jax.Array.__rpow__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rrshift__(y)
Name-vectorized version of jax.Array.__rrshift__. Takes similar arguments as jax.Array.__rrshift__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rshift__(b, /)
Name-vectorized version of jax.Array.__rshift__. Takes similar arguments as jax.Array.__rshift__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rsub__(y)
Name-vectorized version of jax.Array.__rsub__. Takes similar arguments as jax.Array.__rsub__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rtruediv__(y)
Name-vectorized version of jax.Array.__rtruediv__. Takes similar arguments as jax.Array.__rtruediv__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __rxor__(y)
Name-vectorized version of jax.Array.__rxor__. Takes similar arguments as jax.Array.__rxor__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __sub__(b, /)
Name-vectorized version of jax.Array.__sub__. Takes similar arguments as jax.Array.__sub__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __treescope_repr__(path: str | None, subtree_renderer: Any)
Treescope handler for named arrays.
- __truediv__(b, /)
Name-vectorized version of jax.Array.__truediv__. Takes similar arguments as jax.Array.__truediv__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- __xor__(b, /)
Name-vectorized version of jax.Array.__xor__. Takes similar arguments as jax.Array.__xor__ but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- all(axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) → Array
Name-vectorized version of array method all. Takes similar arguments as all but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- any(axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) → Array
Name-vectorized version of array method any. Takes similar arguments as any but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- argmax(axis: int | None = None, out: None = None, keepdims: bool | None = None) → Array
Name-vectorized version of array method argmax. Takes similar arguments as argmax but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- argmin(axis: int | None = None, out: None = None, keepdims: bool | None = None) → Array
Name-vectorized version of array method argmin. Takes similar arguments as argmin but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- argpartition(kth: int, axis: int = -1) → Array
Name-vectorized version of array method argpartition. Takes similar arguments as argpartition but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- argsort(axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False) → Array
Name-vectorized version of array method argsort. Takes similar arguments as argsort but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- abstract as_namedarrayview() → NamedArrayView
Converts into a NamedArrayView, keeping positional axes.
This function is usually not necessary for ordinary named-array manipulation, since NamedArray and NamedArrayView define the same methods. However, it can be useful for simplifying library code that wishes to access the fields of NamedArrayView directly, or to handle arbitrary named array objects without handling each case separately.
Converting a NamedArray to a NamedArrayView never involves any device computations. (The reverse is not true.)
- Returns:
An equivalent NamedArrayView for this array if it isn’t one already.
- astype(dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) → Array
Name-vectorized version of array method astype. Takes similar arguments as astype but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- property at: _IndexUpdateHelper
Helper property for index update functionality.
Lifts the jax.Array.at[...] syntax to also work for Penzai NamedArrays.
Similar to direct indexing, NamedArray.at[...] can be indexed in two different ways, depending on whether the axes you wish to index are positional or named.
When indexing positional axes, this operation follows the combination of nmap and jax.Array.at[...] semantics. In particular,
    named_array.at[index].set(value)
is equivalent to
    nmap(lambda arr, i, v: arr.at[i].set(v))(named_array, index, value)
The resulting array will have the same positional shape as the input array, and will have a named shape that is the union of the named shape of the input array and the names used in the indexer.
The semantics of dict-style indexing are similar, except that they apply to the named axes instead of the positional ones. In general, dict-style indexing will behave like positional indexing, where the requested axes are mapped to positional axes, then indexed, then mapped back to named axes where applicable. In other words,
    # Update part of "foo" in place, and index into "bar" and "baz".
    named_array.at[{
        "foo": pz.slice[2:5], "bar": 1, "baz": indexer_array
    }].set(value)
will behave like
    result_structure = jax.eval_shape(lambda: named_array[{
        "foo": pz.slice[2:5], "bar": 1, "baz": indexer_array
    }])
    # Update positionally, but expect the name "foo" in `value`, and make
    # sure it maps correctly to "foo" in the input.
    (
        named_array.untag_prefix("foo", "bar", "baz")
        .at[2:5, 1, indexer_array, ...]
        .set(value.broadcast_like(result_structure).untag_prefix("foo"))
        .tag_prefix("foo", "bar", "baz")
    )
Specifically:
- Axis names that are sliced by the index dictionary (i.e. mapped to a slice object) can appear in value. These will be used to update the corresponding slices of the array.
- Axis names that do not appear in the index dictionary will be broadcast against the corresponding axis names in value if they exist, following the semantics of nmap.
- The positional axes of value will be broadcast against the result of slicing the input array. This means that the suffix axes of value will correspond to positional axes of the input array, and the prefix axes will correspond to new axes introduced by advanced Numpy indexing. Often, the input named_array will have no positional axes, in which case the positional axes of value will be broadcast against the positional axes of the indexer arrays.
Note that to update multiple positions along the same axis, you need to use Numpy-style advanced indexing, i.e. index with an array that has a positional axis. For instance, to update indices 2 and 4 along axis “foo”, you can do
    # Option 1: dict-style
    named_array.at[{"foo": jnp.array([2, 4])}].set(jnp.array([100, 101]))
    # Option 2: positional-style
    named_array.untag("foo").at[jnp.array([2, 4])].set(jnp.array([100, 101])).tag("foo")
The following will instead create a new named axis “bar”, with one update in each slice along “bar”:
    # Option 1: dict-style
    named_array.at[{
        "foo": pz.nx.wrap(jnp.array([2, 4])).tag("bar")
    }].set(pz.nx.wrap(jnp.array([100, 101])).tag("bar"))
    # Option 2: positional-style
    named_array.untag("foo").at[
        pz.nx.wrap(jnp.array([2, 4])).tag("bar")
    ].set(
        pz.nx.wrap(jnp.array([100, 101])).tag("bar")
    ).tag("foo")
The reason for this difference is that the named axis “bar” is vectorized over, and the behavior needs to be consistent regardless of whether “bar” appears in named_array or not.
- broadcast_like(other: NamedArrayBase | jax.typing.ArrayLike) → NamedArrayBase
Broadcasts a named array to be compatible with another.
- Parameters:
other – Another named array (or, per the signature, an array-like value).
- Returns:
A named array that has the same positional and named shapes as other (although it may also include extra named axes).
- broadcast_to(positional_shape: Sequence[int] = (), named_shape: Mapping[AxisName, int] | None = None) → NamedArrayBase
Broadcasts a named array to a possibly-larger shape.
- Parameters:
positional_shape – Desired positional shape for the array. Will be broadcast using Numpy broadcasting rules.
named_shape – Desired named shape for the array. Will be broadcast using nmap-style vectorizing rules (e.g. missing named axes will be introduced, but existing length-1 named axes will not be broadcast to a larger size).
- Returns:
A named array that has the given positional and named shapes. Note that if this array has axis names that are not in named_shape, these will be preserved in the answer as well.
- canonicalize() → NamedArray
Ensures that the named axes are stored in a canonical order.
- Returns:
Equivalent NamedArray whose data array contains the positional axes followed by the named axes in sorted order.
- choose(choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') → Array
Name-vectorized version of array method choose. Takes similar arguments as choose but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- clip(min: ArrayLike | None = None, max: ArrayLike | None = None) → Array
Name-vectorized version of array method clip. Takes similar arguments as clip but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- compress(condition: ArrayLike, axis: int | None = None, *, out: None = None, size: int | None = None, fill_value: ArrayLike = 0) → Array
Name-vectorized version of array method compress. Takes similar arguments as compress but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- conj() → Array
Name-vectorized version of array method conj. Takes similar arguments as conj but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- conjugate() → Array
Name-vectorized version of array method conjugate. Takes similar arguments as conjugate but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- cumprod(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) → Array
Name-vectorized version of array method cumprod. Takes similar arguments as cumprod but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- cumsum(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) → Array
Name-vectorized version of array method cumsum. Takes similar arguments as cumsum but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- diagonal(offset: int = 0, axis1: int = 0, axis2: int = 1) → Array
Name-vectorized version of array method diagonal. Takes similar arguments as diagonal but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- dot(b: ArrayLike, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) → Array
Name-vectorized version of array method dot. Takes similar arguments as dot but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- abstract property dtype: np.dtype
The dtype of the wrapped array.
- flatten(order: str = 'C') → Array
Name-vectorized version of array method flatten. Takes similar arguments as flatten but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- imag
Name-vectorized version of array method imag. Takes similar arguments as imag but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- item(*args) → bool | int | float | complex
Name-vectorized version of array method item. Takes similar arguments as item but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- mT
Name-vectorized version of array method mT. Takes similar arguments as mT but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- max(axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) → Array
Name-vectorized version of array method max. Takes similar arguments as max but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- mean(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) → Array
Name-vectorized version of array method mean. Takes similar arguments as mean but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- min(axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) → Array
Name-vectorized version of array method min. Takes similar arguments as min but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- nonzero(*, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) → tuple[Array, ...]
Name-vectorized version of array method nonzero. Takes similar arguments as nonzero but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- order_as(*axis_order: AxisName) → NamedArray
Ensures that the named axes are stored in this order, keeping them named.
This function can be used if it is important for the axis names to appear in a consistent order, e.g. to ensure that two NamedArray instances have exactly the same PyTree structure.
If you want a canonical ordering for a named array that doesn’t involve knowing all the axis names in advance, you could do something like array.order_as(*sorted(array.named_shape.keys())). See also order_like.
- Parameters:
*axis_order – Axis names in the order they should appear in the data array. Must be a permutation of all of the axis names in this array.
- Returns:
Equivalent NamedArray whose data array contains the positional axes followed by the named axes in the given order.
- order_like(other: NamedArray | NamedArrayView) → NamedArray | NamedArrayView
Ensures that this array’s PyTree structure matches another array’s.
This can be used to ensure that one named array has the same PyTree structure as another, so that the two can be jointly processed by non-namedarray-aware tree functions (e.g. jax.tree_util functions, jax.lax.cond, jax.jvp, etc.).
To ensure compatibility of entire PyTrees, you can use something like:
    jax.tree_util.tree_map(
        lambda a, b: a.order_like(b),
        tree1,
        tree2,
        is_leaf=pz.nx.is_namedarray,
    )
- Parameters:
other – Another named array or named array view. Must have the same set of named axes as self. If other is a NamedArrayView, other must also have the same number of positional axes.
- Returns:
A new NamedArray or NamedArrayView that has the content of self but is possibly transposed to have the axes appear in the same order as other in the data array. If the arrays have the same named and positional shapes, the result will have the same PyTree structure as other.
- prod(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) → Array
Name-vectorized version of array method prod. Takes similar arguments as prod but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- ptp(axis: Axis = None, out: None = None, keepdims: bool = False) → Array
Name-vectorized version of array method ptp. Takes similar arguments as ptp but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- ravel(order: str = 'C') → Array
Name-vectorized version of array method ravel. Takes similar arguments as ravel but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- real
Name-vectorized version of array method real. Takes similar arguments as real but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- repeat(repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None) → Array
Name-vectorized version of array method repeat. Takes similar arguments as repeat but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- reshape(*args: Any, order: str = 'C') → Array
Name-vectorized version of array method reshape. Takes similar arguments as reshape but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- round(decimals: int = 0, out: None = None) → Array
Name-vectorized version of array method round. Takes similar arguments as round but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- searchsorted(v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') → Array
Name-vectorized version of array method searchsorted. Takes similar arguments as searchsorted but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- sort(axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False) → Array
Name-vectorized version of array method sort. Takes similar arguments as sort but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- squeeze(axis: int | Sequence[int] | None = None) → Array
Name-vectorized version of array method squeeze. Takes similar arguments as squeeze but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- std(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) → Array
Name-vectorized version of array method std. Takes similar arguments as std but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- sum(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) → Array
Name-vectorized version of array method sum. Takes similar arguments as sum but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- swapaxes(axis1: int, axis2: int) → Array
Name-vectorized version of array method swapaxes. Takes similar arguments as swapaxes but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- abstract tag(*names) → NamedArray
Attaches names to the positional axes of an array or view.
- Parameters:
*names – Axis names to assign to each positional axis in the array or view. Must be the same length as positional_shape; if you only want to tag a subset of axes, use tag_prefix instead.
- Raises:
ValueError – If the names are invalid, or if they aren’t the same length as positional_shape.
- Returns:
A NamedArray with the given names assigned to the positional axes, and no remaining positional axes.
- tag_prefix(*axis_order: AxisName) → NamedArray | NamedArrayView
Attaches names to the first positional axes in an array or view.
This is a version of tag that allows you to name only a subset of the array’s positional axes.
- Parameters:
*axis_order – Axis names to assign to the leading positional axes, in order.
- Returns:
A NamedArray or view with the given names assigned to the first positional axes, and whose positional shape includes only the suffix of axes that have not been given names.
- take(indices: ArrayLike, axis: int | None = None, out: None = None, mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, fill_value: StaticScalar | None = None) → Array
Name-vectorized version of array method take. Takes similar arguments as take but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- trace(offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) → Array
Name-vectorized version of array method trace. Takes similar arguments as trace but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- transpose(*args: Any) → Array
Name-vectorized version of array method transpose. Takes similar arguments as transpose but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- abstract untag(*axis_order: AxisName) → NamedArray | NamedArrayView
Produces a positional view of the requested axis names.
untag can only be called on a NamedArray or NamedArrayView that does not have any positional axes. It produces a new NamedArrayView where the axes with the requested names (the arguments to this function) are now treated as positional in the given order.
If you want to use untag on an array that already has positional axes, you can use untag_prefix instead.
- Parameters:
*axis_order – Axis names to make positional, in the order they should appear in the positional view.
- Raises:
ValueError – If this array already has positional axes, or if the provided axis ordering is not valid.
- Returns:
A view with the given axes treated as positional for the purposes of later calls to apply, nmap, or with_positional_prefix. If passed an empty axis order, returns an ordinary NamedArray with no positional axes.
- untag_prefix(*axis_order: AxisName) → NamedArray | NamedArrayView
Adds the requested axes to the front of the array’s positional axes.
This is a version of untag that can be called on NamedArrays or NamedArrayViews that already have positional axes.
- Parameters:
*axis_order – Axis names to make positional, in the order they should appear in the positional view.
- Returns:
A view with the given axes treated as positional, followed by the existing positional axes.
- abstract unwrap(*names: AxisName) → jax.Array
Unwraps this array, possibly mapping axis names to positional axes.
unwrap can be called either on arrays with no named axes, or on arrays with no positional axes (in which case names should be a permutation of its axis names).
- Parameters:
*names – Sequence of axis names to map to positional axes, if this array has named axes. Shorthand for untag(*names).unwrap().
- Returns:
An equivalent ordinary positional array.
- Raises:
ValueError – If the array has a mixture of positional and named axes, or if the names do not match the named axes.
- var(axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) → Array
Name-vectorized version of array method var. Takes similar arguments as var but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- view(dtype: DTypeLike | None = None, type: None = None) → Array
Name-vectorized version of array method view. Takes similar arguments as view but accepts and returns NamedArrays (or NamedArrayViews) in place of regular arrays.
- abstract with_positional_prefix() → NamedArray
Ensures a view is a NamedArray by moving positional axes.
The resulting NamedArray has the same named and positional shapes as this object, but the data array may be transposed so that all the positional axes are in the front. This makes it possible to manipulate those prefix axes safely using jax.tree_util, or to scan/map over them using JAX control flow primitives.
- Returns:
An equivalent NamedArray for this view, or the original NamedArray if it already was one.