pytree_dataclass


penzai.core.struct.pytree_dataclass(cls: type[Any] | None = None, /, *, has_implicitly_inherited_fields: bool = False, use_mutable_proxy_in_init: bool = True, overwrite_parent_init: bool = False, init: bool = True, repr: bool | Literal['auto'] = 'auto', eq: bool = True, order: bool = False, match_args: bool = False, kw_only: bool = False) → type[Any] | Callable[[type[Any]], type[Any]]

Decorator for constructing a frozen PyTree dataclass.

This decorator:
  • transforms the provided class into a frozen dataclass,

  • registers it with JAX as a PyTree node class with keys (but doesn’t actually define the required methods),

  • runs some safety checks to avoid common dataclass pitfalls,

  • optionally, wraps any user-provided __init__ so that it can set its attributes normally even though the dataclass is frozen.

Registration with JAX uses jax.tree_util.register_pytree_with_keys_class. This means that cls must define an instance method tree_flatten_with_keys and a class method tree_unflatten, as described in the JAX documentation. (If applying pytree_dataclass to a subclass of Struct, implementations of these methods are provided for you, and shouldn’t be overridden.)
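A class that does not subclass Struct must supply these two methods itself. The following sketch shows the contract using only `dataclasses` and `jax.tree_util.register_pytree_with_keys_class` directly, without penzai. Point is a hypothetical class. Treating every field as a child, with no static aux data, is a simplifying assumption of this sketch.

```python
import dataclasses

import jax
from jax.tree_util import GetAttrKey, register_pytree_with_keys_class


@register_pytree_with_keys_class
@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float

    def tree_flatten_with_keys(self):
        # One (key, child) pair per field; no static aux data in this sketch.
        children = tuple(
            (GetAttrKey(f.name), getattr(self, f.name))
            for f in dataclasses.fields(self)
        )
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild through the dataclass-generated __init__.
        return cls(*children)


moved = jax.tree_util.tree_map(lambda v: v + 1, Point(1.0, 2.0))
paths = [
    jax.tree_util.keystr(path)
    for path, _ in jax.tree_util.tree_flatten_with_path(Point(1.0, 2.0))[0]
]
print(moved)  # Point(x=2.0, y=3.0)
print(paths)  # ['.x', '.y']
```

The keys matter because JAX uses them to build the paths reported by `tree_flatten_with_path`.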

If has_implicitly_inherited_fields is False, this decorator guards against a common pitfall of dataclass inheritance by checking that the attributes annotated on the class itself exactly match, in order, the fields inferred by dataclasses. This protects against some unintuitive behavior of dataclasses.dataclass. By default, a dataclass inherits fields from its parent dataclasses even when it does not annotate them itself. A field that it does re-annotate keeps the position it had in the parent class rather than the position where it appears in the subclass.

If overwrite_parent_init is False, we try to prevent a rare but tricky footgun of dataclass inheritance: if a parent class defines an __init__ that was not generated by dataclasses, dataclasses will happily overwrite __init__ with a generated one, even though the caller may be expecting to inherit it. In particular, we raise an error if overwrite_parent_init is False, init is True, the class does not define __init__ itself, and the inherited __init__ was neither object.__init__ nor an implementation generated by pytree_dataclass. This error can be silenced either by setting overwrite_parent_init=True or init=False, depending on which behavior the author intended, or by defining __init__ directly.
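The following standard-library sketch shows the footgun itself. Base, Overwritten, and Inherited are hypothetical names. Plain dataclasses silently replaces an inherited, hand-written __init__, while init=False keeps it.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Base:
    x: int

    # Hand-written __init__; dataclasses keeps it because it is defined here.
    def __init__(self, text: str):
        object.__setattr__(self, "x", int(text))


@dataclasses.dataclass(frozen=True)
class Overwritten(Base):
    pass  # dataclasses generates a fresh __init__(self, x) for this class


@dataclasses.dataclass(frozen=True, init=False)
class Inherited(Base):
    pass  # init=False leaves Base.__init__ in effect


print(Base("3").x)               # 3
print(repr(Overwritten("3").x))  # '3' (a str: Base's parsing was bypassed)
print(Inherited("3").x)          # 3
```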

The argument use_mutable_proxy_in_init determines whether __init__ should be modified to allow in-place mutation of the dataclass fields, inspired by a similar feature in equinox. This is implemented by constructing a separate mutable subclass of the class (stored as cls._MutableInitProxy), running the user-specified __init__ on an instance of that subclass, and then copying the fields (and only the fields) from this mutable proxy onto the original object. This transformation is only applied to manually-provided __init__ implementations, not to the one generated by dataclasses.dataclass. (Although a bit indirect, this is in some ways safer than writing __init__ normally for a frozen dataclass, since that can expose a partially-constructed dataclass instance that won’t work properly with JAX.)
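The proxy idea can be approximated with the standard library alone. The sketch below is not penzai's implementation. `with_mutable_init_proxy` and Interval are hypothetical names, and details such as how the proxy type is built are assumptions. The sketch runs the user's __init__ on an instance of a mutable subclass, then copies only the dataclass fields onto the frozen object.

```python
import dataclasses
import functools


def with_mutable_init_proxy(cls):
    """Hypothetical helper (not penzai's code) for an already-frozen dataclass."""
    user_init = cls.__init__

    class _MutableInitProxy(cls):
        # Override the frozen __setattr__ so `self.field = ...` works here.
        def __setattr__(self, name, value):
            object.__setattr__(self, name, value)

    @functools.wraps(user_init)
    def __init__(self, *args, **kwargs):
        # Run the user's __init__ on a mutable stand-in object...
        proxy = object.__new__(_MutableInitProxy)
        user_init(proxy, *args, **kwargs)
        # ...then copy only the dataclass fields onto the real frozen instance.
        for field in dataclasses.fields(cls):
            object.__setattr__(self, field.name, getattr(proxy, field.name))

    cls._MutableInitProxy = _MutableInitProxy
    cls.__init__ = __init__
    return cls


@with_mutable_init_proxy
@dataclasses.dataclass(frozen=True)
class Interval:
    lo: float
    hi: float

    def __init__(self, center: float, radius: float):
        self.lo = center - radius
        self.hi = center + radius
        self.scratch = "temporary"  # not a field, so it is not copied over


iv = Interval(1.0, 0.5)
```

The resulting `iv` has `lo == 0.5` and `hi == 1.5`, has no `scratch` attribute, and still rejects assignment after construction.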

Note that this transformation stores penzai-specific arguments in the attribute cls.__penzai_pytree_dataclass_info__, which is also used to check whether a class has been wrapped in this decorator already.

Parameters:
  • cls – The class to wrap. If provided, transforms this class and returns a transformed copy. If not provided, returns a decorator which can be applied to a class, similar to the ordinary dataclasses decorator.

  • has_implicitly_inherited_fields – Whether this dataclass is explicitly opting into inheriting dataclass fields from its parent class(es). Usually, classes wrapped by @pytree_dataclass are encouraged to explicitly list out all fields, including inherited ones.

  • use_mutable_proxy_in_init – Whether to wrap any user-provided __init__ so that assignments to self are possible inside it. See more detailed description above.

  • overwrite_parent_init – Whether it’s OK to overwrite an inherited __init__ defined in a parent class, even if that __init__ was not generated by pytree_dataclass (and was not just object.__init__).

  • init – Whether to generate __init__; see dataclasses.dataclass. You should usually set this to True unless you want to inherit a specific __init__ implementation from a parent class. Ignored if your class defines __init__ directly.

  • repr – Whether to generate __repr__. If “auto”, generates a __repr__ unless the class is a subclass of Struct, which already uses Treescope to represent its instances. See dataclasses.dataclass.

  • eq – Whether to generate __eq__; see dataclasses.dataclass.

  • order – Whether to generate ordering methods; see dataclasses.dataclass.

  • match_args – Whether to define __match_args__; see dataclasses.dataclass.

  • kw_only – Whether to declare fields as keyword-only; see dataclasses.dataclass.

Returns:

A transformed version of cls if provided, or a decorator which can be applied to a class to transform it.

Raises:
  • ValueError – If this class is already a pytree dataclass, or if it already has values for reserved properties.

  • PyTreeDataclassSafetyError – If has_implicitly_inherited_fields is False and the class does not explicitly list its fields in the correct order, or if overwrite_parent_init is False but we’re about to overwrite a custom __init__.