NamedArray

- class penzai.core.named_axes.NamedArray
Bases: NamedArrayBase, Struct
A multidimensional array with a combination of positional and named axes.
Conceptually, NamedArrays can have positional axes like an ordinary jax.Array, but they can also have explicit named axes. Operations on NamedArrays always act only on the positional axes, and are vectorized (or “lifted”) over the named axes. To apply an operation to a named axis, you can first untag that named axis as a positional axis, then apply the operation as normal. (This is intentional: it avoids having to re-define separate “named” versions of every JAX or NumPy function.)

Internally, a NamedArray stores its array data in the data_array attribute. The positional axes always appear as a prefix of the data array’s shape. This means that stacking NamedArrays along their first axis (e.g. axis=0) or iterating over this axis (e.g. with jax.lax.scan) will work correctly for NamedArrays with a nonempty positional shape. However, to avoid unnecessary transpositions, some operations on a NamedArray will produce a NamedArrayView instead of a NamedArray. (NamedArrayViews define the same methods as NamedArrays but have a more complex data representation.)

Operations on NamedArrays generally involve constructing positional views of the axes you need to operate on:

- To run primitive operations (like jnp.sum or jax.nn.softmax) on a NamedArray, you can call untag to mark specific axes as positional, then use named_axes.nmap (or wrapped instance methods like .sum) to run the primitive positional operation over that locally-positional view. Any positional axes produced by the operation can be rebound to names using tag.
- To slice an axis out of a NamedArray (e.g. as input to jax.lax.scan), you can first move the given axis names to the front of the array using .untag(...).with_positional_prefix(), then do a tree_map over the internal data_array and slice the first axis (which is what jax.lax.scan does internally).
- To stack NamedArrays together along an axis (e.g. for the output of jax.lax.scan), you can just stack them normally, then use tag to give a name to the new axis. (If you want to stack NamedArrayViews, convert them to NamedArrays first using with_positional_prefix.)
Note that it’s only safe to manipulate the prefix axes of data_array that do not have names (the positional axes). Any operation on the named axes should first assign positions using untag, and then either perform the operation inside nmap or move those positions to the prefix of data_array using with_positional_prefix.

The internal ordering of the named axes is part of the PyTree structure of a NamedArray or NamedArrayView, which means that passing them through JAX operations sometimes requires re-ordering the named axes to ensure a consistent structure. To make one named array have the same PyTree structure as another, you can use first.order_like(second) or, for trees:

    jax.tree_util.tree_map(
        lambda a, b: a.order_like(b),
        tree1,
        tree2,
        is_leaf=pz.nx.is_namedarray,
    )
You can also use canonicalize to reorder the internal axes in a canonical order; any named arrays with equivalent shapes will have the same PyTree structure after calling canonicalize.

- Variables:
  named_axes (collections.OrderedDict[AxisName, int]) – An ordered map from axis names to their lengths. The values must be a suffix of data_array.shape, and the keys give names to each of the suffix dimensions of the array. Usually, this will be the same length as data_array.shape (that is, the array usually has no positional axes).
  data_array (jax.Array) – The underlying positional-indexed array.
Inherited Attributes
T – Name-vectorized version of array method T.
at – Helper property for index update functionality.
imag – Name-vectorized version of array method imag.
mT – Name-vectorized version of array method mT.
real – Name-vectorized version of array method real.

Methods
__init__(named_axes, data_array)
as_namedarrayview()
check_valid()
tag(*names)
untag(*axis_order)
unwrap(*names)
with_positional_prefix()
wrap(array, *names) – Wraps a positional array as a NamedArray.

Attributes
T – Name-vectorized version of array method T.
at – Helper property for index update functionality.
dtype
imag – Name-vectorized version of array method imag.
mT – Name-vectorized version of array method mT.
named_shape
positional_shape
real – Name-vectorized version of array method real.
named_axes
data_array
Inherited Methods
all([axis, out, keepdims, where]) – Name-vectorized version of array method all.
any([axis, out, keepdims, where]) – Name-vectorized version of array method any.
argmax([axis, out, keepdims]) – Name-vectorized version of array method argmax.
argmin([axis, out, keepdims]) – Name-vectorized version of array method argmin.
argpartition(kth[, axis]) – Name-vectorized version of array method argpartition.
argsort([axis, kind, order, stable, descending]) – Name-vectorized version of array method argsort.
astype(dtype[, copy, device]) – Name-vectorized version of array method astype.
attributes_dict() – Constructs a dictionary with all of the fields in the class.
broadcast_like(other) – Broadcasts a named array to be compatible with another.
broadcast_to([positional_shape, named_shape]) – Broadcasts a named array to a possibly-larger shape.
canonicalize() – Ensures that the named axes are stored in a canonical order.
choose(choices[, out, mode]) – Name-vectorized version of array method choose.
clip([min, max]) – Name-vectorized version of array method clip.
compress(condition[, axis, out, size, ...]) – Name-vectorized version of array method compress.
conj() – Name-vectorized version of array method conj.
conjugate() – Name-vectorized version of array method conjugate.
cumprod([axis, dtype, out]) – Name-vectorized version of array method cumprod.
cumsum([axis, dtype, out]) – Name-vectorized version of array method cumsum.
diagonal([offset, axis1, axis2]) – Name-vectorized version of array method diagonal.
dot(b, *[, precision, preferred_element_type]) – Name-vectorized version of array method dot.
flatten([order]) – Name-vectorized version of array method flatten.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
item(*args) – Name-vectorized version of array method item.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
max([axis, out, keepdims, initial, where]) – Name-vectorized version of array method max.
mean([axis, dtype, out, keepdims, where]) – Name-vectorized version of array method mean.
min([axis, out, keepdims, initial, where]) – Name-vectorized version of array method min.
nonzero(*[, size, fill_value]) – Name-vectorized version of array method nonzero.
order_as(*axis_order) – Ensures that the named axes are stored in this order, keeping them named.
order_like(other) – Ensures that this array's PyTree structure matches another array's.
prod([axis, dtype, out, keepdims, initial, ...]) – Name-vectorized version of array method prod.
ptp([axis, out, keepdims]) – Name-vectorized version of array method ptp.
ravel([order]) – Name-vectorized version of array method ravel.
repeat(repeats[, axis, total_repeat_length]) – Name-vectorized version of array method repeat.
reshape(*args[, order]) – Name-vectorized version of array method reshape.
round([decimals, out]) – Name-vectorized version of array method round.
searchsorted(v[, side, sorter, method]) – Name-vectorized version of array method searchsorted.
select() – Wraps this struct in a selection, enabling functional-style mutations.
sort([axis, kind, order, stable, descending]) – Name-vectorized version of array method sort.
squeeze([axis]) – Name-vectorized version of array method squeeze.
std([axis, dtype, out, ddof, keepdims, ...]) – Name-vectorized version of array method std.
sum([axis, dtype, out, keepdims, initial, ...]) – Name-vectorized version of array method sum.
swapaxes(axis1, axis2) – Name-vectorized version of array method swapaxes.
tag_prefix(*axis_order) – Attaches names to the first positional axes in an array or view.
take(indices[, axis, out, mode, ...]) – Name-vectorized version of array method take.
trace([offset, axis1, axis2, dtype, out]) – Name-vectorized version of array method trace.
transpose(*args) – Name-vectorized version of array method transpose.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
untag_prefix(*axis_order) – Adds the requested axes to the front of the array's positional axes.
var([axis, dtype, out, ddof, keepdims, ...]) – Name-vectorized version of array method var.
view([dtype, type]) – Name-vectorized version of array method view.

- classmethod wrap(array: jax.typing.ArrayLike, *names: AxisName) -> NamedArray
Wraps a positional array as a NamedArray.

- Parameters:
  array – Array to wrap.
  *names – Optional names for the axes of the array. If provided, must be the same length as the array’s shape. This is a convenience wrapper so that you can call wrap(array, "foo", "bar") instead of wrap(array).tag("foo", "bar").
- Returns:
  An equivalent NamedArray for the given array. If names is provided, the resulting array will have those names assigned to the corresponding axes. Otherwise, all of the resulting array’s axes will be positional.