Struct
- class penzai.core.struct.Struct
Bases: object
Base class for penzai PyTree structures.
Struct is the common base class for most objects in penzai. Structs are both frozen dataclasses and JAX PyTree nodes, with their field annotations specifying which attributes contain JAX-traversable subtrees or numeric data and which attributes do not.
Struct is heavily inspired by equinox’s equinox.Module, and works in much the same way. However, there are a few differences:
- Every non-abstract Struct must be explicitly registered as a dataclass pytree using the decorator penzai.pytree_dataclass, so that readers of the code can tell the class’s semantics differ from those of an ordinary Python class.
- The pytree_dataclass decorator supports additional configuration via keyword arguments, similar to the original dataclass decorator, and adds a few other features as well. In particular, by default attributes are NOT inherited from parent dataclasses, and __init__ is modified to allow easier assignment to immutable fields; see penzai.pytree_dataclass for details.
- __init__ follows normal dataclass rules: an __init__ will be generated unless __init__ is defined or init=False is passed, similar to ordinary dataclass wrappers and in line with common typechecker expectations. However, to prevent accidentally overwriting their parent class’s __init__ instead of inheriting it, subclasses of classes with custom __init__ implementations must explicitly opt in to this behavior by setting overwrite_parent_init=True, or opt out with init=False. (Equinox modules instead try to always inherit __init__ in this case.)
- Some convenient common methods for building, destructuring, and visualizing structs are defined by default.
- Some equinox-specific features are not supported. Specifically, bound methods are not wrapped in Partial, and Equinox’s “wrapped modules” are not supported.
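The __init__ rule in the list above guards against a pitfall that ordinary standard-library dataclasses already have. The sketch below uses only the `dataclasses` module, not penzai, and its class names are hypothetical. It shows how a subclass quietly replaces a parent's hand-written `__init__` with a generated one unless `init=False` is passed, which is the choice penzai asks you to make explicitly.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Scaled:
    """Hypothetical parent with a hand-written __init__."""
    value: float

    def __init__(self, raw: float, scale: float = 2.0):
        # Frozen dataclasses block normal assignment, so bypass the guard.
        object.__setattr__(self, "value", raw * scale)


@dataclasses.dataclass(frozen=True)
class Overwrites(Scaled):
    """Gets a generated __init__(value, name), hiding Scaled.__init__."""
    name: str = "x"


@dataclasses.dataclass(frozen=True, init=False)
class Inherits(Scaled):
    """init=False keeps Scaled.__init__; `name` stays at its class default."""
    name: str = "x"


parent_value = Scaled(3.0).value          # custom __init__ scales the input
overwritten_value = Overwrites(3.0).value  # generated __init__ stores it as-is
inherited_value = Inherits(3.0).value      # parent's custom __init__ is reused
```

With plain dataclasses the first subclass changes constructor behavior without any warning, which is the situation the `overwrite_parent_init=True` opt-in is described as preventing.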
Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
Inherited Methods
__init__()
- final attributes_dict() -> dict[str, Any]
Constructs a dictionary with all of the fields in the class.
The result of this should be passable back to from_attributes to rebuild (a copy of) the object.
- Returns:
A dictionary containing all of the dataclass fields of the object.
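As a rough illustration of the round trip described above, the sketch below mimics `attributes_dict` on plain standard-library dataclasses. The `Point` and `Segment` classes and the free-standing `attributes_dict` helper are hypothetical stand-ins, not penzai code. It also assumes the dictionary is shallow, so nested objects are kept as-is instead of being converted.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float


@dataclasses.dataclass(frozen=True)
class Segment:
    start: Point
    end: Point
    label: str


def attributes_dict(obj: Any) -> dict[str, Any]:
    """Maps each dataclass field name to its current value (one level deep)."""
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}


segment = Segment(Point(0.0, 0.0), Point(1.0, 2.0), label="edge")
attrs = attributes_dict(segment)

# Passing the dictionary back as keyword arguments rebuilds an equal copy.
rebuilt = Segment(**attrs)
```

Unlike `dataclasses.asdict`, this helper does not recurse, so `attrs["end"]` is still the original `Point` object.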
- final classmethod from_attributes(**field_values) -> T
Directly instantiates a struct given all of its fields.
Structs can override __init__ to have arbitrary custom behavior, but this may make it difficult to construct new instances of structs with particular field values. This function makes it possible to directly instantiate an instance of a struct with given attributes.
(Note: Overriding __init__ in a Struct subclass is usually discouraged.)
The main purpose of this method is to enable easier serialization and deserialization of structs. Callers of this method are responsible for maintaining any invariants expected by the class.
- Parameters:
**field_values – Values for each of the struct’s fields.
- Returns:
A new instance of the class.
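To make the motivation concrete, here is a minimal standard-library sketch of building a frozen dataclass directly from its field values when `__init__` has been overridden. The `Name` class and this free-standing `from_attributes` helper are hypothetical, and penzai's actual implementation may differ.

```python
import dataclasses
from typing import Any, TypeVar

T = TypeVar("T")


@dataclasses.dataclass(frozen=True)
class Name:
    normalized: str

    def __init__(self, raw: str):
        # Custom constructor: callers pass raw text, the field stores a
        # normalized form.
        object.__setattr__(self, "normalized", raw.strip().lower())


def from_attributes(cls: type[T], **field_values: Any) -> T:
    """Builds an instance from exact field values, skipping cls.__init__."""
    expected = {f.name for f in dataclasses.fields(cls)}
    if set(field_values) != expected:
        raise ValueError(
            f"expected fields {sorted(expected)}, got {sorted(field_values)}"
        )
    obj = object.__new__(cls)
    for name, value in field_values.items():
        # object.__setattr__ bypasses the frozen-dataclass assignment guard.
        object.__setattr__(obj, name, value)
    return obj


via_init = Name("  Ada ")
direct = from_attributes(Name, normalized="ada")

# No normalization happens here: the caller owns the invariant.
unchecked = from_attributes(Name, normalized="  ADA ")
```

The last line shows why callers are responsible for invariants: the direct path stores whatever it is given, even a value the custom constructor would never produce.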
- key_for_field(field_name: str) -> Hashable
Generates a JAX PyTree key for a given field name.
This can be overridden if more control over JAX key paths is needed.
- Parameters:
field_name – The field name to construct a key for.
- Returns:
A hashable key to use in JAX PyTree paths.
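The following sketch shows, in plain Python, how a per-field key hook like this can shape the paths produced by keyed flattening. `AttrKey` is a hypothetical stand-in for an attribute-style key such as `jax.tree_util.GetAttrKey`, and `Node`, `Linear`, and `Pair` are illustrative classes, not part of penzai.

```python
import dataclasses
from typing import Any, Hashable


@dataclasses.dataclass(frozen=True)
class AttrKey:
    """Hypothetical stand-in for an attribute-style path key."""
    name: str

    def __str__(self) -> str:
        return f".{self.name}"


@dataclasses.dataclass(frozen=True)
class Node:
    """Hypothetical base class whose keyed flattening consults key_for_field."""

    def key_for_field(self, field_name: str) -> Hashable:
        return AttrKey(field_name)

    def flatten_with_keys(self) -> list[tuple[Hashable, Any]]:
        return [
            (self.key_for_field(f.name), getattr(self, f.name))
            for f in dataclasses.fields(self)
        ]


@dataclasses.dataclass(frozen=True)
class Linear(Node):
    weights: tuple[float, ...]
    bias: float


@dataclasses.dataclass(frozen=True)
class Pair(Node):
    first: float
    second: float

    def key_for_field(self, field_name: str) -> Hashable:
        # Override: expose the fields as positions 0 and 1 instead of names.
        return ("first", "second").index(field_name)


default_keys = [str(k) for k, _ in Linear((1.0, 2.0), 0.5).flatten_with_keys()]
custom_keys = [k for k, _ in Pair(3.0, 4.0).flatten_with_keys()]
```

Because only the hook is overridden, the flattening logic itself stays shared between both classes.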
- final select() -> selectors.Selection[Struct]
Wraps this struct in a selection, enabling functional-style mutations.
This is a convenience wrapper around selectors.select to enable easier selection, using syntax like:
    struct.select().at(lambda b: b.foo[3].bar).apply(baz)
    struct.select().at_instances_of(Sequential).at_children().apply(qux)
See documentation for selectors.Selection for supported attributes.
- Returns:
A singleton selection containing this struct.
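For readers unfamiliar with functional-style mutation, the sketch below imitates the effect of selecting every instance of a class and applying a function, using only the standard library. The helper `apply_at_instances_of` and the `Scale` class are hypothetical. This is not how `selectors.Selection` is implemented, only an illustration of returning a modified copy while leaving the original untouched.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class Scale:
    factor: float


@dataclasses.dataclass(frozen=True)
class Sequential:
    layers: tuple[Any, ...]


def apply_at_instances_of(tree: Any, cls: type, fn: Callable[[Any], Any]) -> Any:
    """Returns a copy of `tree` with `fn` applied to every instance of `cls`."""
    if isinstance(tree, cls):
        return fn(tree)
    if isinstance(tree, tuple):
        return tuple(apply_at_instances_of(item, cls, fn) for item in tree)
    if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
        updates = {
            f.name: apply_at_instances_of(getattr(tree, f.name), cls, fn)
            for f in dataclasses.fields(tree)
        }
        return dataclasses.replace(tree, **updates)
    return tree  # Leaves (numbers, strings, ...) are returned unchanged.


model = Sequential((Scale(2.0), Sequential((Scale(3.0),))))
doubled = apply_at_instances_of(
    model, Scale, lambda s: dataclasses.replace(s, factor=s.factor * 2)
)
```

Since every class here is frozen, the update has to build new objects, which is the same constraint that makes a selection-based API convenient for structs.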
- final tree_flatten() -> tuple[Sequence[Any], Any]
Flattens this tree node.
See jax.tree_util.register_pytree_with_keys_class.
This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation. If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.
- Returns:
Children of the node, along with static metadata.
- final tree_flatten_with_keys() -> tuple[Sequence[tuple[Any, Any]], Any]
Flattens this tree node with keys.
See jax.tree_util.register_pytree_with_keys_class.
This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation (and in particular, on key_for_field producing the JAX keypath keys for each field). If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.
- Returns:
(key, child) pairs for the node, along with static metadata.
- final classmethod tree_unflatten(aux_data: Any, children: Sequence[Any]) -> Struct
Unflattens this tree node.
See jax.tree_util.register_pytree_with_keys_class.
This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation. If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.
- Parameters:
aux_data – Auxiliary data, as returned in the second element of the result of tree_flatten_with_keys (or tree_flatten).
children – Sequence of children, corresponding to the first element of the result of tree_flatten_with_keys (or tree_flatten).
- Returns:
An instance of the struct.
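The flatten and unflatten methods above form a round trip: children go out, possibly get transformed, and come back in alongside unchanged static metadata. The sketch below reproduces that contract with standard-library dataclasses only. The free-standing functions and the `Dense` class are hypothetical, and the `{"pytree_node": False}` field metadata used to mark a static field is an assumption about the convention, not something stated on this page.

```python
import dataclasses
from typing import Any, Sequence


@dataclasses.dataclass(frozen=True)
class Dense:
    weights: list[float]
    bias: float
    # Assumed convention: metadata flags a field as static (not a child).
    name: str = dataclasses.field(metadata={"pytree_node": False})


def tree_flatten(node: Any) -> tuple[Sequence[Any], Any]:
    """Splits fields into children and hashable static metadata."""
    children, static = [], []
    for f in dataclasses.fields(node):
        if f.metadata.get("pytree_node", True):
            children.append(getattr(node, f.name))
        else:
            static.append((f.name, getattr(node, f.name)))
    return tuple(children), tuple(static)


def tree_unflatten(cls: type, aux_data: Any, children: Sequence[Any]) -> Any:
    """Rebuilds an instance from static metadata plus (possibly new) children."""
    child_names = [
        f.name
        for f in dataclasses.fields(cls)
        if f.metadata.get("pytree_node", True)
    ]
    values = dict(aux_data)
    values.update(zip(child_names, children))
    return cls(**values)


layer = Dense(weights=[1.0, 2.0], bias=0.5, name="dense_0")
children, aux_data = tree_flatten(layer)

# Transform only the children, then rebuild with the same static metadata.
new_children = ([w * 10 for w in children[0]], children[1] + 1.0)
rebuilt = tree_unflatten(Dense, aux_data, new_children)
```

Keeping the static part hashable and separate from the children is what lets tree-manipulation code treat two nodes with equal metadata as having the same structure.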
- treescope_color() -> str | tuple[str, str]
Computes a CSS color to display for this object in treescope.
This function can be overridden to change the color for a particular object in treescope, without having to register a new handler.
- Returns:
A CSS color string to use as a background/highlight color for this object. Alternatively, a tuple of (border, fill) CSS colors.
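The two return forms can be illustrated without treescope itself. In the sketch below, `Attention` and `Residual` are hypothetical classes that override `treescope_color`, and `border_and_fill` is a hypothetical consumer written only to show how either form might be handled. How treescope actually interprets a single color string is not specified here, so reusing it for both border and fill is an assumption.

```python
from typing import Union

CssColor = Union[str, tuple[str, str]]


class Attention:
    """Hypothetical object that returns one CSS color string."""

    def treescope_color(self) -> CssColor:
        return "cornflowerblue"


class Residual:
    """Hypothetical object that returns separate (border, fill) colors."""

    def treescope_color(self) -> CssColor:
        return ("#444444", "#ffe9b3")


def border_and_fill(obj) -> tuple[str, str]:
    """Hypothetical consumer: normalizes either return form to (border, fill)."""
    color = obj.treescope_color()
    if isinstance(color, str):
        # Assumption: a single color is used for both roles.
        return (color, color)
    border, fill = color
    return (border, fill)


single = border_and_fill(Attention())
pair = border_and_fill(Residual())
```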