ArraySpec
- class penzai.core.shapecheck.ArraySpec
Bases: Struct
A non-leaf marker for a (named) array structure.
This is like a jax.ShapeDtypeStruct, but it is an empty PyTree node instead of being a leaf, so flattening it produces no children. It also supports named axes and dimension variables. ArraySpec is used for shape checking as well as to annotate the expected shape and dtype of uninitialized parameters. It may appear in a model PyTree either inside an uninitialized parameter or inside shape-checking layers.

Note that the named_shape attribute specifically refers to the named shape of a Penzai NamedArray or NamedArrayView. Some internal JAX transforms (e.g. the deprecated xmap) can produce JAX values with their own internal named_shape attribute, but this will not be checked against the named_shape of an ArraySpec.
- Variables:
shape (tuple[int | DimVar | MultiDimVar, ...]) – Positional shape of the eventual array that will be inserted here. Can include DimVar or MultiDimVar instances if the ArraySpec is being used for shape checking.
dtype (jax.typing.DTypeLike) – Dtype of the eventual array that will be inserted here. Can be an abstract dtype (e.g. np.floating, which is actually an abstract base class and therefore has type type) or a concrete array dtype (e.g. np.dtype("float32"), which has type np.dtype). Abstract dtypes accept any concrete subdtype.
named_shape (Mapping[named_axes.AxisName | MultiDimVar, int | DimVar | RemainingAxisPlaceholder]) – Named shape of the eventual array that will be inserted here. Can include DimVar instances as values if the ArraySpec is being used for shape checking. Can also include MultiDimVar instances as keys, with the RemainingAxisPlaceholder sentinel as the value, to indicate an arbitrary collection of names.
Methods
- __init__([shape, dtype, named_shape])
- floating_named(named_shape): Returns an ArraySpec with this named shape and np.floating dtype.
- into_pytree(): Converts an ArraySpec into a (possibly wrapped) PyTree leaf.
Attributes
- positional_shape
- shape
- named_shape
Inherited Methods
- attributes_dict(): Constructs a dictionary with all of the fields in the class.
- from_attributes(**field_values): Directly instantiates a struct given all of its fields.
- key_for_field(field_name): Generates a JAX PyTree key for a given field name.
- select(): Wraps this struct in a selection, enabling functional-style mutations.
- tree_flatten(): Flattens this tree node.
- tree_flatten_with_keys(): Flattens this tree node with keys.
- tree_unflatten(aux_data, children): Unflattens this tree node.
- treescope_color(): Computes a CSS color to display for this object in treescope.
- classmethod floating_named(named_shape: Mapping[named_axes.AxisName | MultiDimVar, int | DimVar | RemainingAxisPlaceholder]) → ArraySpec
Returns an ArraySpec with this named shape and np.floating dtype.
- into_pytree() → jax.ShapeDtypeStruct | named_axes.NamedArray
Converts an ArraySpec into a (possibly wrapped) PyTree leaf.

By default, an ArraySpec has no PyTree children. This method can be used to convert it into a PyTree subtree that contains one leaf, a jax.ShapeDtypeStruct. This can then be used, for example, to restore parameters into a structure.

If this ArraySpec has a named shape, the method returns a NamedArray wrapping a ShapeDtypeStruct, with the order of the wrapped array's named axes inferred from the order of names in this ArraySpec. Otherwise, it returns an ordinary ShapeDtypeStruct.
- Returns:
A PyTree whose structure matches this structure.