ArraySpec

class penzai.core.shapecheck.ArraySpec

Bases: Struct

A non-leaf marker for a (named) array structure.

This is like a jax.ShapeDtypeStruct, but it is an empty PyTree node instead of being a leaf, so flattening it produces no children. It also supports named axes and dimension variables.

ArraySpec is used for shape checking as well as to annotate the expected shape and dtype of uninitialized parameters. It may appear in a model PyTree either inside an uninitialized parameter or inside shape-checking layers.

Note that the named_shape attribute specifically refers to the named shape of a Penzai NamedArray or NamedArrayView. Some internal JAX transforms (e.g. the deprecated xmap) can produce JAX values with their own internal named_shape attribute, but this will not be checked against the named_shape of an ArraySpec.

Variables:
  • shape (tuple[int | DimVar | MultiDimVar, ...]) – Positional shape of the eventual array that will be inserted here. Can include DimVar or MultiDimVar if it is being used for shape-checking.

  • dtype (jax.typing.DTypeLike) – Dtype of the eventual array that will be inserted here. Can be an abstract dtype (e.g. np.floating, which is actually an abstract base class and has type type) or a concrete array dtype (e.g. np.dtype("float32") which has type np.dtype). Abstract dtypes accept any concrete subdtype.

  • named_shape (Mapping[named_axes.AxisName | MultiDimVar, int | DimVar | RemainingAxisPlaceholder]) – Named shape of the eventual array that will be inserted here. Can include DimVar instances as values if it is being used for shape-checking. Can also include MultiDimVar instances as keys, with the RemainingAxisPlaceholder sentinel as the value, to indicate an arbitrary collection of names.
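
The distinction between abstract and concrete dtypes described for the dtype field can be checked with NumPy alone. The sketch below uses np.issubdtype to illustrate what "accepts any concrete subdtype" means; it is an assumption that ArraySpec's own dtype check behaves equivalently, and the variable names are illustrative.

```python
import numpy as np

abstract = np.floating          # abstract dtype: a Python class
concrete = np.dtype("float32")  # concrete dtype: an np.dtype instance

# The abstract dtype is a class, while the concrete one is a dtype object.
abstract_is_class = isinstance(abstract, type)
concrete_is_dtype = isinstance(concrete, np.dtype)

# float32 is a sub-dtype of np.floating, but int32 is not.
float32_matches = np.issubdtype(concrete, abstract)
int32_matches = np.issubdtype(np.dtype("int32"), abstract)

print(abstract_is_class, concrete_is_dtype, float32_matches, int32_matches)
```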

Methods

__init__([shape, dtype, named_shape])

floating_named(named_shape)

Returns an ArraySpec with this named shape and np.floating dtype.

into_pytree()

Converts an ArraySpec into a (possibly wrapped) PyTree leaf.

Attributes

positional_shape

shape

named_shape

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

dtype

alias of numpy.generic

classmethod floating_named(named_shape: Mapping[named_axes.AxisName | MultiDimVar, int | DimVar | RemainingAxisPlaceholder]) → ArraySpec

Returns an ArraySpec with this named shape and np.floating dtype.

into_pytree() → jax.ShapeDtypeStruct | named_axes.NamedArray

Converts an ArraySpec into a (possibly wrapped) PyTree leaf.

By default, an ArraySpec has no PyTree children. This method can be used to convert it into a PyTree subtree that contains one leaf, a jax.ShapeDtypeStruct. This can then be used to e.g. restore parameters into a structure.

If this structure has a named shape, this method returns a NamedArray wrapping a ShapeDtypeStruct, with the parameter order inferred from the order of names in this ArraySpec. Otherwise, it returns an ordinary ShapeDtypeStruct.

Returns:

A PyTree (a ShapeDtypeStruct, or a NamedArray wrapping one) whose structure matches this ArraySpec.