NamedArrayView

class penzai.core.named_axes.NamedArrayView

Bases: NamedArrayBase, Struct
A possibly-transposed view of an array with positional and named axes.

This view identifies a particular set of axes in a data array as “virtual positional axes”, which can be operated on by positional logic through nmap. Unlike NamedArray, the positional axes can be stored anywhere in the data array, not just as a prefix of its shape.

Instances of NamedArrayView are generally constructed by calling untag on a NamedArray, or as the return value of an nmap-ed function. Views are then either passed to a positional operation using nmap or wrapped instance methods, converted back to a NamedArray by reassigning names with tag, or converted to a NamedArray while moving any remaining positional axes to the front with with_positional_prefix.

Directly modifying the shape of data_array is not allowed, as it will break the invariants used in named axis tracking. If you need to slice axes of a named array, first make sure the logical positional axes are a prefix of data_array using with_positional_prefix.

Variables:
- data_shape (tuple[int, ...]) – The required shape of the data array.
- data_axis_for_logical_axis (tuple[int, ...]) – Maps the logical positional axes for this view to the true indices into data_array’s shape.
- data_axis_for_name (dict[AxisName, int]) – Maps axis names to indices into data_array’s shape.
- data_array (jax.Array) – The underlying positional-indexed array.
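To make the field bookkeeping concrete, here is a minimal, data-free sketch of how the four fields relate. `ToyView` and its `check` method are hypothetical names, not part of penzai, and the sketch assumes two invariants: the data array's shape equals `data_shape`, and every data axis is claimed exactly once, either by a logical positional axis or by a name. It only mirrors index arithmetic; it is not the library's `check_valid`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyView:
    """Hypothetical, data-free model of NamedArrayView's index bookkeeping."""

    data_shape: tuple[int, ...]
    data_axis_for_logical_axis: tuple[int, ...]
    data_axis_for_name: dict[str, int]

    @property
    def positional_shape(self) -> tuple[int, ...]:
        # Logical positional axis i is stored at data axis data_axis_for_logical_axis[i].
        return tuple(self.data_shape[d] for d in self.data_axis_for_logical_axis)

    @property
    def named_shape(self) -> dict[str, int]:
        # Each name is stored at data axis data_axis_for_name[name].
        return {n: self.data_shape[d] for n, d in self.data_axis_for_name.items()}

    def check(self, array_shape: tuple[int, ...]) -> None:
        # Assumed invariant 1: the data array has exactly the recorded shape.
        if tuple(array_shape) != self.data_shape:
            raise ValueError(f"shape {array_shape} != data_shape {self.data_shape}")
        # Assumed invariant 2: every data axis is claimed exactly once,
        # either by a logical positional axis or by a name.
        claimed = list(self.data_axis_for_logical_axis) + list(
            self.data_axis_for_name.values()
        )
        if sorted(claimed) != list(range(len(self.data_shape))):
            raise ValueError(f"data axes {claimed} are not a permutation of the axes")


# A (batch=2, seq=5, feature=3) array viewed with "seq" as its only positional
# axis: the positional axis sits in the middle of the data array, not at the front.
view = ToyView(
    data_shape=(2, 5, 3),
    data_axis_for_logical_axis=(1,),
    data_axis_for_name={"batch": 0, "feature": 2},
)
view.check((2, 5, 3))
print(view.positional_shape)  # (5,)
print(view.named_shape)       # {'batch': 2, 'feature': 3}
```

Slicing `data_array` directly, say down to shape (2, 4, 3), would make `view.check((2, 4, 3))` raise, because `data_shape` no longer describes the array. That is the kind of broken invariant the class description warns about.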
Inherited Attributes
- T – Name-vectorized version of array method T.
- at – Helper property for index update functionality.
- imag – Name-vectorized version of array method imag.
- mT – Name-vectorized version of array method mT.
- real – Name-vectorized version of array method real.

Methods
- __init__(data_shape, ...)
- as_namedarrayview()
- check_valid()
- tag(*names)
- untag(*axis_order)
- unwrap(*names)
- with_positional_prefix() – Converts a view into a proper NamedArray by moving positional axes.

Attributes
- T – Name-vectorized version of array method T.
- at – Helper property for index update functionality.
- dtype
- imag – Name-vectorized version of array method imag.
- mT – Name-vectorized version of array method mT.
- named_shape
- positional_shape
- real – Name-vectorized version of array method real.
- data_shape
- data_axis_for_logical_axis
- data_axis_for_name
- data_array
Inherited Methods
- all([axis, out, keepdims, where]) – Name-vectorized version of array method all.
- any([axis, out, keepdims, where]) – Name-vectorized version of array method any.
- argmax([axis, out, keepdims]) – Name-vectorized version of array method argmax.
- argmin([axis, out, keepdims]) – Name-vectorized version of array method argmin.
- argpartition(kth[, axis]) – Name-vectorized version of array method argpartition.
- argsort([axis, kind, order, stable, descending]) – Name-vectorized version of array method argsort.
- astype(dtype[, copy, device]) – Name-vectorized version of array method astype.
- attributes_dict() – Constructs a dictionary with all of the fields in the class.
- broadcast_like(other) – Broadcasts a named array to be compatible with another.
- broadcast_to([positional_shape, named_shape]) – Broadcasts a named array to a possibly-larger shape.
- canonicalize() – Ensures that the named axes are stored in a canonical order.
- choose(choices[, out, mode]) – Name-vectorized version of array method choose.
- clip([min, max]) – Name-vectorized version of array method clip.
- compress(condition[, axis, out, size, ...]) – Name-vectorized version of array method compress.
- conj() – Name-vectorized version of array method conj.
- conjugate() – Name-vectorized version of array method conjugate.
- cumprod([axis, dtype, out]) – Name-vectorized version of array method cumprod.
- cumsum([axis, dtype, out]) – Name-vectorized version of array method cumsum.
- diagonal([offset, axis1, axis2]) – Name-vectorized version of array method diagonal.
- dot(b, *[, precision, preferred_element_type]) – Name-vectorized version of array method dot.
- flatten([order]) – Name-vectorized version of array method flatten.
- from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
- item(*args) – Name-vectorized version of array method item.
- key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
- max([axis, out, keepdims, initial, where]) – Name-vectorized version of array method max.
- mean([axis, dtype, out, keepdims, where]) – Name-vectorized version of array method mean.
- min([axis, out, keepdims, initial, where]) – Name-vectorized version of array method min.
- nonzero(*[, size, fill_value]) – Name-vectorized version of array method nonzero.
- order_as(*axis_order) – Ensures that the named axes are stored in this order, keeping them named.
- order_like(other) – Ensures that this array's PyTree structure matches another array's.
- prod([axis, dtype, out, keepdims, initial, ...]) – Name-vectorized version of array method prod.
- ptp([axis, out, keepdims]) – Name-vectorized version of array method ptp.
- ravel([order]) – Name-vectorized version of array method ravel.
- repeat(repeats[, axis, total_repeat_length]) – Name-vectorized version of array method repeat.
- reshape(*args[, order]) – Name-vectorized version of array method reshape.
- round([decimals, out]) – Name-vectorized version of array method round.
- searchsorted(v[, side, sorter, method]) – Name-vectorized version of array method searchsorted.
- select() – Wraps this struct in a selection, enabling functional-style mutations.
- sort([axis, kind, order, stable, descending]) – Name-vectorized version of array method sort.
- squeeze([axis]) – Name-vectorized version of array method squeeze.
- std([axis, dtype, out, ddof, keepdims, ...]) – Name-vectorized version of array method std.
- sum([axis, dtype, out, keepdims, initial, ...]) – Name-vectorized version of array method sum.
- swapaxes(axis1, axis2) – Name-vectorized version of array method swapaxes.
- tag_prefix(*axis_order) – Attaches names to the first positional axes in an array or view.
- take(indices[, axis, out, mode, ...]) – Name-vectorized version of array method take.
- trace([offset, axis1, axis2, dtype, out]) – Name-vectorized version of array method trace.
- transpose(*args) – Name-vectorized version of array method transpose.
- tree_flatten() – Flattens this tree node.
- tree_flatten_with_keys() – Flattens this tree node with keys.
- tree_unflatten(aux_data, children) – Unflattens this tree node.
- treescope_color() – Computes a CSS color to display for this object in treescope.
- untag_prefix(*axis_order) – Adds the requested axes to the front of the array's positional axes.
- var([axis, dtype, out, ddof, keepdims, ...]) – Name-vectorized version of array method var.
- view([dtype, type]) – Name-vectorized version of array method view.

with_positional_prefix() → NamedArray
Converts a view into a proper NamedArray by moving positional axes.

The resulting NamedArray has the same named and positional shapes as this view, but the data array may be transposed so that all the positional axes are at the front. This makes it possible to manipulate those prefix axes safely using jax.tree_util, or to scan or map over them using JAX control-flow primitives such as jax.lax.scan and jax.lax.map.

Returns:
An equivalent NamedArray for this view.
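The transposition described above can be sketched as pure index arithmetic: build a permutation that lists the positional data axes first, in logical order, followed by the named axes, then read off the new shape and where each name ends up. `positional_prefix_layout` is a hypothetical helper, not penzai API, and it assumes the named axes keep their existing relative order; the library may order them differently.

```python
def positional_prefix_layout(
    data_shape: tuple[int, ...],
    data_axis_for_logical_axis: tuple[int, ...],
    data_axis_for_name: dict[str, int],
) -> tuple[tuple[int, ...], tuple[int, ...], dict[str, int]]:
    """Hypothetical helper: permute a view's axes into a positional prefix.

    Returns (perm, new_shape, new_axis_for_name), where perm is in the form
    accepted as the `axes` argument of jax.numpy.transpose.
    """
    # Positional axes first, in logical order.
    positional = tuple(data_axis_for_logical_axis)
    # Named axes afterwards, keeping their current relative order in the
    # data array (an assumption of this sketch).
    named = tuple(sorted(data_axis_for_name.values()))
    perm = positional + named
    # New axis j holds what used to be data axis perm[j].
    new_shape = tuple(data_shape[old] for old in perm)
    # The axis that was at data position `old` now lives at perm.index(old).
    new_axis_for_name = {
        name: perm.index(old) for name, old in data_axis_for_name.items()
    }
    return perm, new_shape, new_axis_for_name


# A (batch=2, seq=5, feature=3) view with "seq" as its only positional axis.
perm, new_shape, new_names = positional_prefix_layout(
    (2, 5, 3), (1,), {"batch": 0, "feature": 2}
)
print(perm)       # (1, 0, 2)
print(new_shape)  # (5, 2, 3)
print(new_names)  # {'batch': 1, 'feature': 2}
```

Applying `perm` with `jax.numpy.transpose(data_array, perm)` would give an array whose leading axes are the positional ones. Once they form a prefix, slicing them or scanning over them no longer disturbs the positions of the named axes.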