Struct
- class penzai.core.struct.Struct
Bases: object
Base class for penzai PyTree structures.
Struct is the common base class for most objects in penzai. Structs are both frozen dataclasses and JAX PyTree nodes, and their field annotations specify which attributes contain JAX-traversable subtrees or numeric data and which attributes do not.
Struct is heavily inspired by equinox’s equinox.Module and works in much the same way. However, there are a few differences:
- Every non-abstract Struct must be explicitly registered as a dataclass pytree using the decorator penzai.pytree_dataclass, so that readers of the code can tell that the class’s semantics differ from those of an ordinary Python class.
- The pytree_dataclass decorator supports additional configuration via keyword arguments, similar to the original dataclass decorator, and adds a few other features as well. In particular, by default attributes are NOT inherited from parent dataclasses, and __init__ is modified to allow easier assignment to immutable fields; see penzai.pytree_dataclass for details.
- __init__ follows normal dataclass rules: an __init__ will be generated unless __init__ is defined or init=False is passed, similar to ordinary dataclass wrappers and in line with common typechecker expectations. However, to prevent accidentally overwriting their parent class’s __init__ instead of inheriting it, subclasses of classes with custom __init__ implementations must explicitly opt in to this behavior by setting overwrite_parent_init=True, or opt out with init=False. (Equinox modules instead try to always inherit __init__ in this case.)
- Some convenient common methods for building, destructuring, and visualizing structs are defined by default.
- Some equinox-specific features are not supported. Specifically, bound methods are not wrapped in Partial, and Equinox’s “wrapped modules” are not supported.
Methods
- attributes_dict(): Constructs a dictionary with all of the fields in the class.
- from_attributes(**field_values): Directly instantiates a struct given all of its fields.
- key_for_field(field_name): Generates a JAX PyTree key for a given field name.
- select(): Wraps this struct in a selection, enabling functional-style mutations.
- tree_flatten(): Flattens this tree node.
- tree_flatten_with_keys(): Flattens this tree node with keys.
- tree_unflatten(aux_data, children): Unflattens this tree node.
- treescope_color(): Computes a CSS color to display for this object in treescope.
Inherited Methods
- __init__()
- final attributes_dict() → dict[str, Any]
Constructs a dictionary with all of the fields in the class.
The result of this should be passable back to from_attributes to rebuild (a copy of) the object.
- Returns:
A dictionary containing all of the dataclass fields of the object.
- final classmethod from_attributes(**field_values) → T
Directly instantiates a struct given all of its fields.
Structs can override __init__ to have arbitrary custom behavior, but this may make it difficult to construct new instances of structs with particular field values. This function makes it possible to directly instantiate an instance of a struct with given attributes. (Note: Overriding __init__ in a Struct subclass is usually discouraged.)
The main purpose of this method is to enable easier serialization and deserialization of structs. Callers of this method are responsible for maintaining any invariants expected by the class.
- Parameters:
**field_values – Values for each of the struct’s fields.
- Returns:
A new instance of the class.
- key_for_field(field_name: str) → Hashable
Generates a JAX PyTree key for a given field name.
This can be overridden if more control over JAX key paths is needed.
- Parameters:
field_name – The field name to construct a key for.
- Returns:
A hashable key to use in JAX PyTree paths.
- final select() → selectors.Selection[Struct]
Wraps this struct in a selection, enabling functional-style mutations.
This is a convenience wrapper around selectors.select to enable easier selection, using syntax like:
    struct.select().at(lambda b: b.foo[3].bar).apply(baz)
    struct.select().at_instances_of(Sequential).at_children().apply(qux)
See the documentation for selectors.Selection for supported attributes.
- Returns:
A singleton selection containing this struct.
- final tree_flatten() → tuple[Sequence[Any], Any]
Flattens this tree node.
See jax.tree_util.register_pytree_with_keys_class.
This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation. If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.
- Returns:
Children of the node, along with static metadata.
- final tree_flatten_with_keys() → tuple[Sequence[tuple[Any, Any]], Any]
Flattens this tree node with keys.
See jax.tree_util.register_pytree_with_keys_class.
This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation (and in particular, on key_for_field producing the JAX keypath keys for each field). If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.
- Returns:
(key, child) pairs for the node, along with static metadata.
- final classmethod tree_unflatten(aux_data: Any, children: Sequence[Any]) → Struct
Unflattens this tree node.
See jax.tree_util.register_pytree_with_keys_class.
This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation. If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.
- Parameters:
aux_data – Auxiliary data: the second element of the tuple returned by tree_flatten_with_keys (or tree_flatten).
children – Sequence of children: the first element of the tuple returned by tree_flatten, or equivalently the children (without their keys) from tree_flatten_with_keys.
- Returns:
An instance of the struct.
- treescope_color() → str | tuple[str, str]
Computes a CSS color to display for this object in treescope.
This function can be overridden to change the color for a particular object in treescope, without having to register a new handler.
- Returns:
A CSS color string to use as a background/highlight color for this object. Alternatively, a tuple of (border, fill) CSS colors.