pytree_dataclass
- penzai.core.struct.pytree_dataclass(cls: type[Any] | None = None, /, *, has_implicitly_inherited_fields: bool = False, use_mutable_proxy_in_init: bool = True, overwrite_parent_init: bool = False, init: bool = True, repr: bool | Literal['auto'] = 'auto', eq: bool = True, order: bool = False, match_args: bool = False, kw_only: bool = False) -> type[Any] | Callable[[type[Any]], type[Any]]
Decorator for constructing a frozen PyTree dataclass.
This decorator:
- transforms the provided class into a frozen dataclass,
- registers it with JAX as a PyTree node class with keys (but doesn’t actually define the required methods),
- runs some safety checks to avoid common dataclass pitfalls,
- optionally wraps any user-provided __init__ so that it can set its attributes normally even though the dataclass is frozen.
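The need for the last step can be illustrated with the standard library alone. The following is a minimal sketch, not penzai's actual implementation: it shows a frozen dataclass rejecting ordinary assignment inside a hand-written __init__, and then one way a mutable stand-in subclass could run that logic and have its fields copied back. The names Naive, Doubled, _user_init, and _MutableDoubled are hypothetical and exist only for this illustration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Naive:
    value: float

    # dataclasses leaves an __init__ defined directly on the class alone.
    def __init__(self, raw: float):
        self.value = raw * 2  # rejected by the frozen __setattr__


def try_naive() -> str:
    """Report what happens when Naive's __init__ assigns to self."""
    try:
        Naive(1.0)
    except dataclasses.FrozenInstanceError:
        return "FrozenInstanceError"
    return "ok"


@dataclasses.dataclass(frozen=True)
class Doubled:
    value: float

    def _user_init(self, raw: float):
        # Written as if the instance were mutable.
        self.value = raw * 2

    def __init__(self, *args, **kwargs):
        # Run the user's logic on a mutable stand-in, then copy the
        # dataclass fields (and only the fields) onto the frozen object.
        proxy = object.__new__(_MutableDoubled)
        Doubled._user_init(proxy, *args, **kwargs)
        for field in dataclasses.fields(Doubled):
            object.__setattr__(self, field.name, getattr(proxy, field.name))


class _MutableDoubled(Doubled):
    # Hypothetical proxy type: restores ordinary attribute assignment.
    __setattr__ = object.__setattr__


naive_result = try_naive()
doubled = Doubled(1.0)
```

In this sketch the object the caller receives is still an instance of the frozen class, so later assignments to its fields continue to fail.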
Registration with JAX uses jax.tree_util.register_pytree_with_keys_class. This means that cls must define an instance method tree_flatten_with_keys and a class method tree_unflatten, as described in the JAX documentation. (If applying pytree_dataclass to a subclass of Struct, implementations of these methods are provided for you, and shouldn’t be overridden.)
If has_implicitly_inherited_fields is False, this decorator prevents a common pitfall with dataclass inheritance, by making sure that the list of attributes listed on the class exactly matches the list of fields inferred by dataclasses. This protects against some unintuitive behavior of dataclasses.dataclass: by default dataclasses inherit un-annotated fields from parent dataclasses and may have annotated fields re-ordered to match parent classes also.
If overwrite_parent_init is False, we try to prevent a rare but tricky footgun of dataclass inheritance: if a parent class defines an __init__ that was not generated by dataclasses, dataclasses will happily overwrite __init__ with a generated one, even though the caller may be expecting to inherit it. In particular, we raise an error if overwrite_parent_init is False, init is True, the class does not define __init__ itself, and the inherited __init__ was neither object.__init__ nor an implementation generated by pytree_dataclass. This error can be silenced either by setting overwrite_parent_init=True or init=False, depending on which behavior the author intended, or by defining __init__ directly.
The argument use_mutable_proxy_in_init determines whether __init__ should be modified to allow in-place mutation of the dataclass fields, inspired by a similar feature in equinox. This is implemented by constructing a separate mutable subclass of the class (stored as cls._MutableInitProxy), using the user-specified __init__ for that subclass, and then copying the fields (and only the fields) from this mutable proxy type into the original object. This transformation is only done on manually-provided __init__ implementations, not for the one generated by dataclasses.dataclass. (Although a bit indirect, this is in some ways safer than writing __init__ normally for a frozen dataclass, since that can expose a partially-constructed dataclass type that won’t work properly with JAX.)
Note that this transformation stores penzai-specific arguments in the attribute cls.__penzai_pytree_dataclass_info__, which is also used to check whether a class has been wrapped in this decorator already.
- Parameters:
cls – The class to wrap. If provided, transforms this class and returns a transformed copy. If not provided, returns a decorator which can be applied to a class, similar to the ordinary dataclasses decorator.
has_implicitly_inherited_fields – Whether this dataclass is explicitly opting into inheriting dataclass fields from its parent class(es). Usually, classes wrapped by @pytree_dataclass are encouraged to explicitly list out all fields, including inherited ones.
use_mutable_proxy_in_init – Whether to wrap any user-provided __init__ so that assignments to self are possible inside it. See more detailed description above.
overwrite_parent_init – Whether it’s OK to overwrite an inherited __init__ defined in a parent class, even if that __init__ was not generated by pytree_dataclass (and was not just object.__init__).
init – Whether to generate __init__; see dataclasses.dataclass. You should usually set this to True unless you want to inherit a specific __init__ implementation from a parent class. Ignored if your class defines __init__ directly.
repr – Whether to generate __repr__. If “auto”, generates a __repr__ unless the class is a subclass of Struct, which already uses Treescope to represent the object. See dataclasses.dataclass.
eq – Whether to generate __eq__; see dataclasses.dataclass.
order – Whether to generate ordering methods; see dataclasses.dataclass.
match_args – Whether to define __match_args__; see dataclasses.dataclass.
kw_only – Whether to declare fields as keyword-only; see dataclasses.dataclass.
- Returns:
A transformed version of cls if provided, or a decorator which can be applied to a class to transform it.
- Raises:
ValueError – If this class is already a pytree dataclass, or if it already has values for reserved properties.
PyTreeDataclassSafetyError – If has_implicitly_inherited_fields is False and the class does not explicitly list its fields in the correct order, or if overwrite_parent_init is False but we’re about to overwrite a custom __init__.
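The two inheritance pitfalls that the safety checks guard against can be reproduced with plain dataclasses.dataclass. The sketch below uses only the standard library and does not call penzai; the class names (Base, Child, KeepsInit, Parent, Sub) are hypothetical. It shows the unguarded behavior of dataclasses that the has_implicitly_inherited_fields and overwrite_parent_init checks are described as catching.

```python
import dataclasses


class Base:
    # A hand-written __init__ that a subclass author may expect to inherit.
    def __init__(self, name: str):
        self.name = name.upper()


@dataclasses.dataclass
class Child(Base):
    name: str  # init=True (default): the generated __init__ replaces Base's


@dataclasses.dataclass(init=False)
class KeepsInit(Base):
    name: str  # init=False: Base.__init__ is inherited unchanged


child_name = Child("abc").name       # Base.__init__ never runs
keeps_name = KeepsInit("abc").name   # Base.__init__ runs


@dataclasses.dataclass
class Parent:
    a: int
    b: int


@dataclasses.dataclass
class Sub(Parent):
    c: int
    a: int  # re-declared field keeps the position it had in Parent


# Names annotated on Sub itself versus the fields dataclasses infers.
declared = list(Sub.__annotations__)
inferred = [f.name for f in dataclasses.fields(Sub)]
```

Here declared is ['c', 'a'] while inferred is ['a', 'b', 'c']: field b is inherited without being listed on Sub, and a is moved ahead of c. Listing a, b, c explicitly on Sub, in that order, would make the two lists agree.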