Struct

class penzai.core.struct.Struct

Bases: object

Base class for penzai PyTree structures.

Struct is the common base class for most objects in penzai. Structs are both frozen dataclasses and JAX PyTree nodes, with their field annotations specifying which attributes contain JAX-traversable subtrees or numeric data and which attributes do not.

Struct is heavily inspired by Equinox's equinox.Module and works in much the same way. However, there are a few differences:

  • Every non-abstract Struct must be explicitly registered as a dataclass pytree using the decorator penzai.pytree_dataclass, so that readers of the code can tell that the class's semantics differ from those of an ordinary Python class.

  • The pytree_dataclass decorator supports additional configuration via keyword arguments, similar to the original dataclass decorator, and adds a few other features as well. In particular, by default attributes are NOT inherited from parent dataclasses, and __init__ is modified to allow easier assignment to immutable fields; see penzai.pytree_dataclass for details.

  • __init__ follows normal dataclass rules: an __init__ will be generated unless __init__ is defined or init=False is passed, similar to ordinary dataclass wrappers and in line with common typechecker expectations. However, to prevent accidentally overwriting their parent class’s __init__ instead of inheriting it, subclasses of classes with custom __init__ implementations must explicitly opt in to this behavior by setting overwrite_parent_init=True, or opt out with init=False. (Equinox modules instead try to always inherit __init__ in this case.)

  • Some convenient common methods for building, destructuring, and visualizing structs are defined by default.

  • Some equinox-specific features are not supported. Specifically, bound methods are not wrapped in Partial, and Equinox’s “wrapped modules” are not supported.

Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

Inherited Methods

__init__()

final attributes_dict() → dict[str, Any]

Constructs a dictionary with all of the fields in the class.

The result of this should be passable back to from_attributes to rebuild (a copy of) the object.

Returns:

A dictionary containing all of the dataclass fields of the object.

final classmethod from_attributes(**field_values) → T

Directly instantiates a struct given all of its fields.

Structs can override __init__ to have arbitrary custom behavior, but this may make it difficult to construct new instances of structs with particular field values. This function makes it possible to directly instantiate an instance of a struct with given attributes.

(Note: Overriding __init__ in a Struct subclass is usually discouraged.)

The main purpose of this method is to enable easier serialization and deserialization of structs. Callers of this method are responsible for maintaining any invariants expected by the class.

Parameters:

**field_values – Values for each of the struct’s fields.

Returns:

A new instance of the class.

key_for_field(field_name: str) → Hashable

Generates a JAX PyTree key for a given field name.

This can be overridden if more control over JAX key paths is needed.

Parameters:

field_name – The field name to construct a key for.

Returns:

A hashable key to use in JAX PyTree paths.

final select() → selectors.Selection[Struct]

Wraps this struct in a selection, enabling functional-style mutations.

This is a convenience wrapper around selectors.select to enable easier selection, using syntax like:

struct.select().at(lambda b: b.foo[3].bar).apply(baz)
struct.select().at_instances_of(Sequential).at_children().apply(qux)

See the documentation for selectors.Selection for supported attributes.

Returns:

A singleton selection containing this struct.

final tree_flatten() → tuple[Sequence[Any], Any]

Flattens this tree node.

See jax.tree_util.register_pytree_with_keys_class.

This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation. If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.

Returns:

Children of the node, along with static metadata.

final tree_flatten_with_keys() → tuple[Sequence[tuple[Any, Any]], Any]

Flattens this tree node with keys.

See jax.tree_util.register_pytree_with_keys_class.

This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation (and in particular, on key_for_field producing the JAX keypath keys for each field). If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.

Returns:

(key, child) pairs for the node, along with static metadata.

final classmethod tree_unflatten(aux_data: Any, children: Sequence[Any]) → Struct

Unflattens this tree node.

See jax.tree_util.register_pytree_with_keys_class.

This method should not be overridden by subclasses, since struct-manipulation code should be able to rely on this implementation. If you must override this for an advanced use case, consider using pytree_dataclass without subclassing Struct.

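Parameters:

aux_data – Static metadata for the node, as returned by tree_flatten.

children – Children of the node, as returned by tree_flatten.

Returns: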

An instance of the struct.

treescope_color() → str | tuple[str, str]

Computes a CSS color to display for this object in treescope.

This function can be overridden to change the color for a particular object in treescope, without having to register a new handler.

Returns:

A CSS color string to use as a background/highlight color for this object. Alternatively, a tuple of (border, fill) CSS colors.