ArraySpec
- class penzai.core.shapecheck.ArraySpec
Bases: Struct
A non-leaf marker for a (named) array structure.
This is like a jax.ShapeDtypeStruct, but it is an empty PyTree node instead of being a leaf, so flattening it produces no children. It also supports named axes and dimension variables.
ArraySpec is used for shape checking as well as to annotate the expected shape and dtype of uninitialized parameters. It may appear in a model PyTree either inside an uninitialized parameter or inside shape-checking layers.
Note that the named_shape attribute specifically refers to the named shape of a Penzai NamedArray or NamedArrayView. Some internal JAX transforms (e.g. the deprecated xmap) can produce JAX values with their own internal named_shape attribute, but this will not be checked against the named_shape of an ArraySpec.
- Variables:
shape (tuple[int | DimVar | MultiDimVar, ...]) – Positional shape of the eventual array that will be inserted here. Can include DimVar or MultiDimVar if it is being used for shape-checking.
dtype (jax.typing.DTypeLike) – Dtype of the eventual array that will be inserted here. Can be an abstract dtype (e.g. np.floating, which is actually an abstract base class and has type type) or a concrete array dtype (e.g. np.dtype("float32"), which has type np.dtype). Abstract dtypes accept any concrete subdtype.
named_shape (Mapping[named_axes.AxisName | MultiDimVar, int | DimVar | RemainingAxisPlaceholder]) – Named shape of the eventual array that will be inserted here. Can include DimVar instances as values if it is being used for shape-checking. Can also include MultiDimVar instances as keys with the RemainingAxisPlaceholder sentinel as the value, to indicate an arbitrary collection of names.
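
To make the role of dimension variables in a positional shape more concrete, here is a minimal pure-Python sketch of matching a concrete shape against a spec. The names Var and match_positional are hypothetical stand-ins invented for this illustration; they are not Penzai's DimVar or its actual matching logic, and the binding rule shown (one variable name maps to one size) is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a dimension variable (not penzai's DimVar)."""
    name: str


def match_positional(spec_shape, actual_shape):
    """Match a concrete shape against a spec and return variable bindings.

    Integers in the spec must match exactly. A Var may match any size, but
    every occurrence of the same Var name must match the same size.
    """
    if len(spec_shape) != len(actual_shape):
        raise ValueError(
            f"rank mismatch: expected {len(spec_shape)}, got {len(actual_shape)}"
        )
    bindings = {}
    for axis, (want, got) in enumerate(zip(spec_shape, actual_shape)):
        if isinstance(want, Var):
            # setdefault binds the variable on first use and returns the
            # earlier binding on later uses.
            if bindings.setdefault(want.name, got) != got:
                raise ValueError(f"axis {axis}: {want.name} is already bound")
        elif want != got:
            raise ValueError(f"axis {axis}: expected {want}, got {got}")
    return bindings


bindings = match_positional((Var("batch"), 128), (32, 128))
print(bindings)  # {'batch': 32}
```

In this sketch a spec with only integers behaves like an exact shape comparison, while a spec containing variables also reports which sizes the variables took.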
Methods
__init__([shape, dtype, named_shape])
floating_named(named_shape) – Returns an ArraySpec with this named shape and np.floating dtype.
into_pytree() – Converts an ArraySpec into a (possibly wrapped) PyTree leaf.
Attributes
positional_shape
shape
named_shape
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- classmethod floating_named(named_shape: Mapping[named_axes.AxisName | MultiDimVar, int | DimVar | RemainingAxisPlaceholder]) -> ArraySpec
Returns an ArraySpec with this named shape and np.floating dtype.
- into_pytree() -> jax.ShapeDtypeStruct | named_axes.NamedArray
Converts an ArraySpec into a (possibly wrapped) PyTree leaf.
By default, an ArraySpec has no PyTree children. This method can be used to convert it into a PyTree subtree that contains one leaf, a jax.ShapeDtypeStruct. This can then be used to e.g. restore parameters into a structure.
If this structure has a named shape, this method returns a NamedArray wrapping a ShapeDtypeStruct, with the parameter order inferred from the order of names in this ArraySpec. Otherwise, it returns an ordinary ShapeDtypeStruct.
- Returns:
A PyTree whose structure matches this structure.
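
The two return cases described for into_pytree can be sketched with plain dataclasses. Everything here is a hypothetical stand-in: FakeShapeDtype, FakeNamedArray, and spec_to_leaf are invented for illustration and are not the jax or Penzai types. The layout of the wrapped shape (positional axes first, then named axes in the spec's order) is an assumption of this sketch, not something the text above states.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeShapeDtype:
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""
    shape: tuple
    dtype: str


@dataclass(frozen=True)
class FakeNamedArray:
    """Hypothetical stand-in for a named array wrapping an abstract leaf."""
    axis_names: tuple
    data: FakeShapeDtype


def spec_to_leaf(shape, dtype, named_shape):
    """Return a plain leaf, or a named wrapper when named axes are present."""
    if not named_shape:
        # No named axes: an ordinary shape/dtype leaf.
        return FakeShapeDtype(tuple(shape), dtype)
    # Axis order follows the insertion order of the mapping's keys.
    names = tuple(named_shape)
    # Assumption: positional axes come first, followed by the named axes.
    full_shape = tuple(shape) + tuple(named_shape[name] for name in names)
    return FakeNamedArray(names, FakeShapeDtype(full_shape, dtype))


plain = spec_to_leaf((3,), "float32", {})
named = spec_to_leaf((3,), "float32", {"batch": 2, "features": 8})
print(plain)
print(named.axis_names, named.data.shape)
```

Either result contains exactly one leaf, which is what allows it to act as a placeholder when restoring values into a structure.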