pytree_dataclass
- penzai.core.struct.pytree_dataclass(cls: type[Any] | None = None, /, *, has_implicitly_inherited_fields: bool = False, use_mutable_proxy_in_init: bool = True, overwrite_parent_init: bool = False, init: bool = True, repr: bool | Literal['auto'] = 'auto', eq: bool = True, order: bool = False, match_args: bool = False, kw_only: bool = False) -> type[Any] | Callable[[type[Any]], type[Any]]
Decorator for constructing a frozen PyTree dataclass.
- This decorator:
  - transforms the provided class into a frozen dataclass,
  - registers it with JAX as a PyTree node class with keys (but doesn't actually define the required methods),
  - runs some safety checks to avoid common dataclass pitfalls,
  - optionally, wraps any user-provided __init__ so that it can set its attributes normally even though the dataclass is frozen.
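One of the pitfalls these safety checks target can be reproduced with the standard library alone. The sketch below uses plain dataclasses.dataclass (not pytree_dataclass) and two hypothetical classes, Base and Derived, to show that a re-annotated inherited field keeps its original position instead of the position where the subclass lists it. The final comparison only illustrates the kind of check involved; it is an assumption for demonstration and not penzai's actual implementation.

```python
import dataclasses


@dataclasses.dataclass
class Base:
    a: int
    b: int


@dataclasses.dataclass
class Derived(Base):
    # The subclass lists `c` first and then re-annotates the inherited `a`.
    c: int
    a: int


# Order in which the subclass body lists its attributes.
listed = list(Derived.__annotations__)

# Order that dataclasses actually uses for fields and __init__ arguments.
inferred = [field.name for field in dataclasses.fields(Derived)]

print(listed)    # ['c', 'a']
print(inferred)  # ['a', 'b', 'c']

# Illustrative check only (hypothetical, not penzai's code): flag classes
# whose listed attributes differ from the fields dataclasses inferred.
fields_match = listed == inferred
print(fields_match)  # False

# Positional arguments follow the inferred order, not the listed one.
instance = Derived(1, 2, 3)
print(instance.a, instance.b, instance.c)  # 1 2 3
```

Listing every field explicitly, in the inferred order, makes the two lists agree, which is the style the decorator encourages.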
Registration with JAX uses jax.tree_util.register_pytree_with_keys_class. This means that cls must define an instance method tree_flatten_with_keys and a class method tree_unflatten, as described in the JAX documentation. (If applying pytree_dataclass to a subclass of Struct, implementations of these methods are provided for you, and shouldn't be overridden.)

If has_implicitly_inherited_fields is False, this decorator prevents a common pitfall with dataclass inheritance by making sure that the attributes listed on the class exactly match the fields inferred by dataclasses. This protects against some unintuitive behavior of dataclasses.dataclass: by default, dataclasses inherit un-annotated fields from parent dataclasses, and may also have annotated fields re-ordered to match parent classes.

If overwrite_parent_init is False, we try to prevent a rare but tricky footgun of dataclass inheritance: if a parent class defines an __init__ that was not generated by dataclasses, dataclasses will happily overwrite __init__ with a generated one, even though the caller may be expecting to inherit it. In particular, we raise an error if overwrite_parent_init is False, init is True, the class does not define __init__ itself, and the inherited __init__ was neither object.__init__ nor an implementation generated by pytree_dataclass. This error can be silenced either by setting overwrite_parent_init=True or init=False, depending on which behavior the author intended, or by defining __init__ directly.

The argument use_mutable_proxy_in_init determines whether __init__ should be modified to allow in-place mutation of the dataclass fields, inspired by a similar feature in equinox. This is implemented by constructing a separate mutable subclass of the class (stored as cls._MutableInitProxy), using the user-specified __init__ for that subclass, and then copying the fields (and only the fields) from this mutable proxy type into the original object. This transformation is only done on manually-provided __init__ implementations, not for the one generated by dataclasses.dataclass. (Although a bit indirect, this is in some ways safer than writing __init__ normally for a frozen dataclass, since that can expose a partially-constructed dataclass type that won't work properly with JAX.)

Note that this transformation stores penzai-specific arguments in the attribute cls.__penzai_pytree_dataclass_info__, which is also used to check whether a class has been wrapped in this decorator already.
- Parameters:
cls – The class to wrap. If provided, transforms this class and returns a transformed copy. If not provided, returns a decorator which can be applied to a class, similar to the ordinary dataclasses decorator.
has_implicitly_inherited_fields – Whether this dataclass is explicitly opting into inheriting dataclass fields from its parent class(es). Usually, classes wrapped by @pytree_dataclass are encouraged to explicitly list out all fields, including inherited ones.
use_mutable_proxy_in_init – Whether to wrap any user-provided __init__ so that assignments to self are possible inside it. See the more detailed description above.
overwrite_parent_init – Whether it's OK to overwrite an inherited __init__ defined in a parent class, even if that __init__ was not generated by pytree_dataclass (and was not just object.__init__).
init – Whether to generate __init__; see dataclasses.dataclass. You should usually set this to True unless you want to inherit a specific __init__ implementation from a parent class. Ignored if your class defines __init__ directly.
repr – Whether to generate __repr__. If "auto", generates a __repr__ unless this is a subclass of Struct, which already uses Treescope to represent the object. See dataclasses.dataclass.
eq – Whether to generate __eq__; see dataclasses.dataclass.
order – Whether to generate ordering methods; see dataclasses.dataclass.
match_args – Whether to define __match_args__; see dataclasses.dataclass.
kw_only – Whether to declare fields as keyword-only; see dataclasses.dataclass.
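The mutable-proxy idea behind use_mutable_proxy_in_init can be sketched with the standard library alone. The helper with_mutable_init and the class Scaled below are hypothetical names made up for this illustration, and the code is a simplified approximation of the described technique (run the user's __init__ on a mutable subclass, then copy only the fields across). It is not penzai's actual implementation.

```python
import dataclasses
import functools


def with_mutable_init(cls):
    """Hypothetical helper: let a frozen dataclass's own __init__ assign fields."""
    user_init = cls.__dict__["__init__"]

    # Mutable stand-in: a subclass that runs the user's __init__ but allows
    # ordinary attribute assignment.
    proxy = type(
        "_MutableProxy",
        (cls,),
        {"__init__": user_init, "__setattr__": object.__setattr__},
    )

    @functools.wraps(user_init)
    def frozen_init(self, *args, **kwargs):
        scratch = proxy(*args, **kwargs)
        # Copy the dataclass fields, and only the fields, into the frozen object.
        for field in dataclasses.fields(cls):
            object.__setattr__(self, field.name, getattr(scratch, field.name))

    cls.__init__ = frozen_init
    return cls


@with_mutable_init
@dataclasses.dataclass(frozen=True)
class Scaled:
    base: float
    doubled: float

    def __init__(self, base):
        # Plain assignments work because they run on the mutable proxy.
        self.base = base
        self.doubled = 2 * base
        self.scratch_note = "not a field"  # never copied to the real object


scaled = Scaled(3.0)
print(scaled.base, scaled.doubled)      # 3.0 6.0
print(hasattr(scaled, "scratch_note"))  # False

try:
    scaled.base = 1.0
except dataclasses.FrozenInstanceError:
    print("still frozen after construction")
```

Because only declared fields are copied, stray attributes set during __init__ never reach the frozen instance, and the real object is never observed in a partially constructed state.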
- Returns:
A transformed version of cls if provided, or a decorator which can be applied to a class to transform it.
- Raises:
ValueError – If this class is already a pytree dataclass, or if it already has values for reserved properties.
PyTreeDataclassSafetyError – If has_implicitly_inherited_fields is False and the class does not explicitly list its fields in the correct order, or if overwrite_parent_init is False but we're about to overwrite a custom __init__.
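The overwritten-__init__ situation that triggers the second error can be reproduced with plain dataclasses.dataclass. The classes Parent, Child, and KeepsParentInit below are hypothetical and exist only for this illustration. The sketch shows the standard-library behavior that the check guards against, plus the init=False alternative that keeps the inherited __init__.

```python
import dataclasses


class Parent:
    # Hand-written __init__, not generated by dataclasses.
    def __init__(self, size):
        self.values = [0] * size


@dataclasses.dataclass
class Child(Parent):
    name: str = "child"


# dataclasses generated Child.__init__(self, name="child"), silently
# replacing the inherited Parent.__init__.
child = Child()
print(Child.__init__ is Parent.__init__)  # False
print(hasattr(child, "values"))           # False: Parent.__init__ never ran


# With init=False, no __init__ is generated and the inherited one is kept.
@dataclasses.dataclass(init=False)
class KeepsParentInit(Parent):
    name: str = "child"


kept = KeepsParentInit(size=3)
print(KeepsParentInit.__init__ is Parent.__init__)  # True
print(kept.values)                                  # [0, 0, 0]
print(kept.name)                                    # child (class-level default)
```

Either outcome can be what the author wants, which is why the decorator asks for an explicit choice instead of guessing.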