NamedArray
- class penzai.core.named_axes.NamedArray
Bases: NamedArrayBase, Struct

A multidimensional array with a combination of positional and named axes.
Conceptually, NamedArrays can have positional axes like an ordinary jax.Array, but can also have explicit named axes. Operations on NamedArrays always act only on the positional axes, and are vectorized (or “lifted”) over the named axes. To apply operations to a named axis, you can first untag that named axis as a positional axis, then apply the operation as normal. (This is intentional to avoid having to re-define separate “named” versions of every JAX or NumPy function.)

Internally, a NamedArray stores its array data in the data_array attribute. The positional axes always appear as a prefix of the data array’s shape; this means stacking NamedArrays along their first axis (e.g. axis=0) or iterating over this axis (e.g. with jax.lax.scan) will work correctly for NamedArrays with a nonempty positional shape. However, to avoid unnecessary transpositions, some operations on a NamedArray will produce a NamedArrayView instead of a NamedArray. (NamedArrayViews define the same methods as NamedArrays but have a more complex data representation.)

Operations on NamedArrays generally involve constructing positional views of the axes you need to operate on:

- To run primitive operations (like jnp.sum or jax.nn.softmax) on a NamedArray, you can call untag to mark specific axes as positional, then use named_axes.nmap (or wrapped instance methods like .sum) to run the primitive positional operation over that locally-positional view. Any positional axes produced by the operation can be rebound to names using tag.
- To slice an axis out of a NamedArray (e.g. as input to jax.lax.scan), you can first move the given axis names to the front of the array using .untag(...).with_positional_prefix(), then do a tree_map over the internal data_array and slice the first axis (which is what jax.lax.scan does internally).
- To stack NamedArrays together along an axis (e.g. for the output of jax.lax.scan), you can just stack them normally, then use tag to give a name to the new axis. (If you want to stack NamedArrayViews, convert them to NamedArrays first using with_positional_prefix.)
Note that it’s only safe to manipulate the prefix axes of data_array that do not have names. Any operations on the named axes should first assign positions using untag, and then either do the operation inside nmap or move those positions to the prefix of data_array using with_positional_prefix.

The internal ordering of the named axes is part of the PyTree structure of a NamedArray or NamedArrayView, which means passing them through JAX operations sometimes requires re-ordering the named axes to ensure a consistent structure. To make one named array have the same PyTree structure as another, you can use first.order_like(second) or, for trees,

    jax.tree_util.tree_map(
        lambda a, b: a.order_like(b),
        tree1, tree2,
        is_leaf=pz.nx.is_namedarray,
    )
You can also use canonicalize to reorder the internal axes in a canonical order; any named arrays with equivalent shapes will have the same PyTree structure after calling canonicalize.

- Variables:
  named_axes (collections.OrderedDict[AxisName, int]) – An ordered map from axis names to their lengths. The values must be a suffix of data_array.shape, and the keys give names to each of the suffix dimensions of the array. Usually, this will be the same length as data_array.shape.
  data_array (jax.Array) – The underlying positional-indexed array.
Inherited Attributes

T – Name-vectorized version of array method T.
at – Helper property for index update functionality.
imag – Name-vectorized version of array method imag.
mT – Name-vectorized version of array method mT.
real – Name-vectorized version of array method real.

Methods
__init__(named_axes, data_array)
as_namedarrayview()
check_valid()
tag(*names)
untag(*axis_order)
unwrap(*names)
with_positional_prefix()
wrap(array, *names) – Wraps a positional array as a NamedArray.

Attributes
T – Name-vectorized version of array method T.
at – Helper property for index update functionality.
dtype
imag – Name-vectorized version of array method imag.
mT – Name-vectorized version of array method mT.
named_shape
positional_shape
real – Name-vectorized version of array method real.
named_axes
data_array

Inherited Methods
all([axis, out, keepdims, where]) – Name-vectorized version of array method all.
any([axis, out, keepdims, where]) – Name-vectorized version of array method any.
argmax([axis, out, keepdims]) – Name-vectorized version of array method argmax.
argmin([axis, out, keepdims]) – Name-vectorized version of array method argmin.
argpartition(kth[, axis]) – Name-vectorized version of array method argpartition.
argsort([axis, kind, order, stable, descending]) – Name-vectorized version of array method argsort.
astype(dtype[, copy, device]) – Name-vectorized version of array method astype.
attributes_dict() – Constructs a dictionary with all of the fields in the class.
broadcast_like(other) – Broadcasts a named array to be compatible with another.
broadcast_to([positional_shape, named_shape]) – Broadcasts a named array to a possibly-larger shape.
canonicalize() – Ensures that the named axes are stored in a canonical order.
choose(choices[, out, mode]) – Name-vectorized version of array method choose.
clip([min, max]) – Name-vectorized version of array method clip.
compress(condition[, axis, out, size, ...]) – Name-vectorized version of array method compress.
conj() – Name-vectorized version of array method conj.
conjugate() – Name-vectorized version of array method conjugate.
cumprod([axis, dtype, out]) – Name-vectorized version of array method cumprod.
cumsum([axis, dtype, out]) – Name-vectorized version of array method cumsum.
diagonal([offset, axis1, axis2]) – Name-vectorized version of array method diagonal.
dot(b, *[, precision, preferred_element_type]) – Name-vectorized version of array method dot.
flatten([order]) – Name-vectorized version of array method flatten.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
item(*args) – Name-vectorized version of array method item.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
max([axis, out, keepdims, initial, where]) – Name-vectorized version of array method max.
mean([axis, dtype, out, keepdims, where]) – Name-vectorized version of array method mean.
min([axis, out, keepdims, initial, where]) – Name-vectorized version of array method min.
nonzero(*[, fill_value, size]) – Name-vectorized version of array method nonzero.
order_as(*axis_order) – Ensures that the named axes are stored in this order, keeping them named.
order_like(other) – Ensures that this array's PyTree structure matches another array's.
prod([axis, dtype, out, keepdims, initial, ...]) – Name-vectorized version of array method prod.
ptp([axis, out, keepdims]) – Name-vectorized version of array method ptp.
ravel([order]) – Name-vectorized version of array method ravel.
repeat(repeats[, axis, total_repeat_length]) – Name-vectorized version of array method repeat.
reshape(*args[, order]) – Name-vectorized version of array method reshape.
round([decimals, out]) – Name-vectorized version of array method round.
searchsorted(v[, side, sorter, method]) – Name-vectorized version of array method searchsorted.
select() – Wraps this struct in a selection, enabling functional-style mutations.
sort([axis, kind, order, stable, descending]) – Name-vectorized version of array method sort.
squeeze([axis]) – Name-vectorized version of array method squeeze.
std([axis, dtype, out, ddof, keepdims, ...]) – Name-vectorized version of array method std.
sum([axis, dtype, out, keepdims, initial, ...]) – Name-vectorized version of array method sum.
swapaxes(axis1, axis2) – Name-vectorized version of array method swapaxes.
tag_prefix(*axis_order) – Attaches names to the first positional axes in an array or view.
take(indices[, axis, out, mode, ...]) – Name-vectorized version of array method take.
trace([offset, axis1, axis2, dtype, out]) – Name-vectorized version of array method trace.
transpose(*args) – Name-vectorized version of array method transpose.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
untag_prefix(*axis_order) – Adds the requested axes to the front of the array's positional axes.
var([axis, dtype, out, ddof, keepdims, ...]) – Name-vectorized version of array method var.
view([dtype, type]) – Name-vectorized version of array method view.

- classmethod wrap(array: jax.typing.ArrayLike, *names: AxisName) -> NamedArray
Wraps a positional array as a NamedArray.

- Parameters:
  array – Array to wrap.
  *names – Optional names for the axes of the array. If provided, must be the same length as the array’s shape. This is a convenience wrapper so that you can call wrap(array, "foo", "bar") instead of wrap(array).tag("foo", "bar").
- Returns:
  An equivalent NamedArray for the given array. If names is provided, the resulting array will have those names assigned to the corresponding axes. Otherwise, all of the resulting array's axes will be positional.