NamedArrayView

class penzai.core.named_axes.NamedArrayView

Bases: NamedArrayBase, Struct

A possibly-transposed view of an array with positional and named axes.

This view identifies a particular set of axes in a data array as “virtual positional axes”, which can be operated on by positional logic through nmap. Unlike NamedArray, the positional axes can be stored anywhere in the data array, not just as a prefix of its shape.

Instances of NamedArrayView are generally constructed by calling untag on a NamedArray, or are returned by an nmap-ed function. A view is then typically used in one of three ways: passed to a positional operation using nmap or one of the wrapped instance methods; converted back to a NamedArray by reassigning names to its positional axes with tag; or converted to a NamedArray with with_positional_prefix, which moves any remaining positional axes to the front.
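Because a view records where each logical positional axis lives in the data array, untag and tag can be modeled as pure bookkeeping changes that leave the data untouched. The sketch below illustrates that idea with two hypothetical helpers, toy_untag and toy_tag (not part of penzai), which operate only on the data_axis_for_name and data_axis_for_logical_axis mappings described under Variables. As a simplifying assumption, toy_untag starts from a view in which every data axis is named.

```python
def toy_untag(data_axis_for_name, *names):
    """Hypothetical model of untag: turn the given named axes into positional axes.

    Only the axis bookkeeping changes; the data array itself is never touched.
    Assumes the starting view has no positional axes (every data axis is named).
    """
    if len(set(names)) != len(names):
        raise ValueError(f"repeated axis names: {names}")
    missing = [n for n in names if n not in data_axis_for_name]
    if missing:
        raise ValueError(f"unknown axis names: {missing}")
    # Logical positional axis k is stored at data axis data_axis_for_name[names[k]].
    data_axis_for_logical_axis = tuple(data_axis_for_name[n] for n in names)
    remaining = {n: i for n, i in data_axis_for_name.items() if n not in names}
    return data_axis_for_logical_axis, remaining


def toy_tag(data_axis_for_logical_axis, data_axis_for_name, *names):
    """Hypothetical model of tag: name every logical positional axis, in order."""
    if len(names) != len(data_axis_for_logical_axis):
        raise ValueError("tag needs exactly one name per positional axis")
    new_names = dict(data_axis_for_name)
    for name, data_axis in zip(names, data_axis_for_logical_axis):
        if name in new_names:
            raise ValueError(f"duplicate axis name: {name!r}")
        new_names[name] = data_axis
    # No positional axes remain after tagging.
    return (), new_names


# A data array of shape (2, 3, 5) with every axis named.
data_shape = (2, 3, 5)
names = {"batch": 0, "seq": 1, "embed": 2}

logical, remaining = toy_untag(names, "embed", "batch")
print(logical)                                # (2, 0)
print(tuple(data_shape[i] for i in logical))  # positional shape: (5, 2)
print(remaining)                              # {'seq': 1}

# Re-tag the two positional axes with new names; the data is still not moved.
print(toy_tag(logical, remaining, "features", "examples"))
# ((), {'seq': 1, 'features': 2, 'examples': 0})
```

Note that the logical positional axes (5, 2) appear in the order passed to untag, even though they are stored at data axes 2 and 0; this out-of-order storage is exactly what NamedArray cannot express and a view can.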

Directly modifying the shape of data_array is not allowed, as it will break the invariants used in named-axis tracking. If you need to slice axes of a named array, first use with_positional_prefix to make the logical positional axes a prefix of data_array's shape.
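To see why slicing data_array directly is unsafe, consider a hypothetical view whose single logical positional axis is stored at data axis 2. The NumPy snippet below uses made-up field values (the variable names mirror the fields listed under Variables, and a NumPy array stands in for jax.Array). It shows that slicing the leading axis of the raw array actually slices a named axis and leaves the recorded data_shape stale.

```python
import numpy as np

# Made-up bookkeeping for a view; names mirror NamedArrayView's fields.
data_shape = (2, 3, 5)
data_axis_for_logical_axis = (2,)            # logical positional axis 0 -> data axis 2
data_axis_for_name = {"batch": 0, "seq": 1}
data_array = np.arange(30).reshape(data_shape)

# "Slice the first axis" of the raw array: this hits data axis 0, which is the
# named axis "batch", not the logical positional axis.
sliced = data_array[:1]
hit = [name for name, axis in data_axis_for_name.items() if axis == 0]
print(hit)                                        # ['batch']
print(sliced.shape)                               # (1, 3, 5)

# The positional axis was not sliced at all.
print(sliced.shape[data_axis_for_logical_axis[0]])  # 5

# The recorded data_shape no longer matches the array, so the size of "batch"
# computed from the bookkeeping (2) disagrees with the data (1).
print(sliced.shape == data_shape)                 # False
```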

Variables:
  • data_shape (tuple[int, ...]) – The required shape of the data array.

  • data_axis_for_logical_axis (tuple[int, ...]) – Maps the logical positional axes for this view to the true indices into data_array’s shape.

  • data_axis_for_name (dict[AxisName, int]) – Maps axis names to indices into data_array’s shape.

  • data_array (jax.Array) – The underlying positional-indexed array.
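The four fields above determine the view's positional and named shapes. The following minimal sketch models them with a hypothetical ToyNamedArrayView dataclass (not penzai's class), using a NumPy array in place of jax.Array. Its check_valid encodes an assumed reading of the field descriptions, namely that the array matches data_shape and that every data axis is claimed exactly once, either by a logical positional axis or by a name; it is not a copy of penzai's own check_valid.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyNamedArrayView:
    """Hypothetical stand-in for NamedArrayView's fields (not penzai's class)."""

    data_shape: tuple[int, ...]
    data_axis_for_logical_axis: tuple[int, ...]
    data_axis_for_name: dict[str, int]
    data_array: np.ndarray

    def check_valid(self) -> None:
        # The stored array must have exactly the recorded shape.
        if self.data_array.shape != self.data_shape:
            raise ValueError(
                f"shape {self.data_array.shape} != data_shape {self.data_shape}")
        # Every data axis must be claimed exactly once, either by a logical
        # positional axis or by a name.
        used = list(self.data_axis_for_logical_axis) + list(
            self.data_axis_for_name.values())
        if sorted(used) != list(range(len(self.data_shape))):
            raise ValueError(
                f"axes {used} are not a permutation of range({len(self.data_shape)})")

    @property
    def positional_shape(self) -> tuple[int, ...]:
        # Logical positional axes, in logical order, looked up in the data shape.
        return tuple(self.data_shape[i] for i in self.data_axis_for_logical_axis)

    @property
    def named_shape(self) -> dict[str, int]:
        return {name: self.data_shape[i] for name, i in self.data_axis_for_name.items()}


# Data array of shape (2, 3, 5): axis 0 is named "batch", axis 2 is named
# "seq", and the single logical positional axis lives in the middle (axis 1).
view = ToyNamedArrayView(
    data_shape=(2, 3, 5),
    data_axis_for_logical_axis=(1,),
    data_axis_for_name={"batch": 0, "seq": 2},
    data_array=np.zeros((2, 3, 5)),
)
view.check_valid()
print(view.positional_shape)  # (3,)
print(view.named_shape)       # {'batch': 2, 'seq': 5}
```

The middle axis being positional is the case that distinguishes a view from a NamedArray, where positional axes must form a prefix of the data shape.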

Inherited Attributes

T

Name-vectorized version of array method T.

at

Helper property for index update functionality.

imag

Name-vectorized version of array method imag.

mT

Name-vectorized version of array method mT.

real

Name-vectorized version of array method real.

Methods

__init__(data_shape, ...)

as_namedarrayview()

check_valid()

tag(*names)

untag(*axis_order)

unwrap(*names)

with_positional_prefix()

Converts a view into a proper NamedArray by moving positional axes.

Attributes

T

Name-vectorized version of array method T.

at

Helper property for index update functionality.

dtype

imag

Name-vectorized version of array method imag.

mT

Name-vectorized version of array method mT.

named_shape

positional_shape

real

Name-vectorized version of array method real.

data_shape

data_axis_for_logical_axis

data_axis_for_name

data_array

Inherited Methods


all([axis, out, keepdims, where])

Name-vectorized version of array method all.

any([axis, out, keepdims, where])

Name-vectorized version of array method any.

argmax([axis, out, keepdims])

Name-vectorized version of array method argmax.

argmin([axis, out, keepdims])

Name-vectorized version of array method argmin.

argpartition(kth[, axis])

Name-vectorized version of array method argpartition.

argsort([axis, kind, order, stable, descending])

Name-vectorized version of array method argsort.

astype(dtype[, copy, device])

Name-vectorized version of array method astype.

attributes_dict()

Constructs a dictionary with all of the fields in the class.

broadcast_like(other)

Broadcasts a named array to be compatible with another.

broadcast_to([positional_shape, named_shape])

Broadcasts a named array to a possibly-larger shape.

canonicalize()

Ensures that the named axes are stored in a canonical order.

choose(choices[, out, mode])

Name-vectorized version of array method choose.

clip([min, max])

Name-vectorized version of array method clip.

compress(condition[, axis, out, size, ...])

Name-vectorized version of array method compress.

conj()

Name-vectorized version of array method conj.

conjugate()

Name-vectorized version of array method conjugate.

cumprod([axis, dtype, out])

Name-vectorized version of array method cumprod.

cumsum([axis, dtype, out])

Name-vectorized version of array method cumsum.

diagonal([offset, axis1, axis2])

Name-vectorized version of array method diagonal.

dot(b, *[, precision, preferred_element_type])

Name-vectorized version of array method dot.

flatten([order])

Name-vectorized version of array method flatten.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

item(*args)

Name-vectorized version of array method item.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

max([axis, out, keepdims, initial, where])

Name-vectorized version of array method max.

mean([axis, dtype, out, keepdims, where])

Name-vectorized version of array method mean.

min([axis, out, keepdims, initial, where])

Name-vectorized version of array method min.

nonzero(*[, size, fill_value])

Name-vectorized version of array method nonzero.

order_as(*axis_order)

Ensures that the named axes are stored in this order, keeping them named.

order_like(other)

Ensures that this array's PyTree structure matches another array's.

prod([axis, dtype, out, keepdims, initial, ...])

Name-vectorized version of array method prod.

ptp([axis, out, keepdims])

Name-vectorized version of array method ptp.

ravel([order])

Name-vectorized version of array method ravel.

repeat(repeats[, axis, total_repeat_length])

Name-vectorized version of array method repeat.

reshape(*args[, order])

Name-vectorized version of array method reshape.

round([decimals, out])

Name-vectorized version of array method round.

searchsorted(v[, side, sorter, method])

Name-vectorized version of array method searchsorted.

select()

Wraps this struct in a selection, enabling functional-style mutations.

sort([axis, kind, order, stable, descending])

Name-vectorized version of array method sort.

squeeze([axis])

Name-vectorized version of array method squeeze.

std([axis, dtype, out, ddof, keepdims, ...])

Name-vectorized version of array method std.

sum([axis, dtype, out, keepdims, initial, ...])

Name-vectorized version of array method sum.

swapaxes(axis1, axis2)

Name-vectorized version of array method swapaxes.

tag_prefix(*axis_order)

Attaches names to the first positional axes in an array or view.

take(indices[, axis, out, mode, ...])

Name-vectorized version of array method take.

trace([offset, axis1, axis2, dtype, out])

Name-vectorized version of array method trace.

transpose(*args)

Name-vectorized version of array method transpose.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

untag_prefix(*axis_order)

Adds the requested axes to the front of the array's positional axes.

var([axis, dtype, out, ddof, keepdims, ...])

Name-vectorized version of array method var.

view([dtype, type])

Name-vectorized version of array method view.

with_positional_prefix() → NamedArray

Converts a view into a proper NamedArray by moving positional axes.

The resulting NamedArray has the same named and positional shapes as this view, but the data array may be transposed so that all the positional axes are at the front. This makes it possible to safely manipulate those prefix axes using jax.tree_util, or to scan or map over them using JAX control-flow primitives.

Returns:

An equivalent NamedArray for this view.
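The transposition described above can be sketched with a hypothetical helper, toy_with_positional_prefix (not penzai's implementation), using a NumPy array in place of jax.Array. It builds a permutation that puts the logical positional axes first, in logical order, followed by the named axes. Keeping the named axes in their original relative data order is an assumption of this sketch; penzai may order them differently.

```python
import numpy as np


def toy_with_positional_prefix(data_array, data_axis_for_logical_axis, data_axis_for_name):
    """Hypothetical model of with_positional_prefix (not penzai's implementation).

    Returns the transposed data and a mapping from each name to its size.
    """
    # Named axes, kept in their original relative order in the data array.
    named_in_data_order = sorted(data_axis_for_name.items(), key=lambda kv: kv[1])
    perm = tuple(data_axis_for_logical_axis) + tuple(i for _, i in named_in_data_order)
    new_data = np.transpose(data_array, perm)
    # After the transpose, the named axes follow the positional prefix.
    offset = len(data_axis_for_logical_axis)
    named_axes = {
        name: new_data.shape[offset + k]
        for k, (name, _) in enumerate(named_in_data_order)
    }
    return new_data, named_axes


# A view whose one positional axis (size 5) is stored last in the data array.
data_array = np.arange(30).reshape(2, 3, 5)
new_data, named_axes = toy_with_positional_prefix(
    data_array, data_axis_for_logical_axis=(2,), data_axis_for_name={"batch": 0, "seq": 1})
print(new_data.shape)      # (5, 2, 3)
print(named_axes)          # {'batch': 2, 'seq': 3}

# The positional axis is now axis 0, so ordinary slicing touches only it.
print(new_data[:4].shape)  # (4, 2, 3)
```

The positional shape (5,) and the named shape {'batch': 2, 'seq': 3} are the same before and after, matching the description above; only the physical layout of the data changes.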