auto_order_types


Helper class to define automatic arbitrary orderings across different types.

Dictionaries in JAX need to have sortable keys in order to be manipulated with jax.tree_util or passed through JAX transformations. This makes it difficult to build a dictionary whose keys have different types, since Python 3 raises a TypeError when ordering values of unrelated types (for example, a string and an integer). However, such patterns can be useful for storing multiple kinds of data in a single dictionary. In particular, Penzai requires slots and variables to have unique labels, and frequently uses these labels as dictionary keys. Labels can be kept unique by giving each category of label its own type, but those types must then be mutually comparable so that dictionaries keyed by them can pass through JAX transformations.
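The underlying problem can be reproduced in plain Python, without JAX. In this minimal sketch, PlainLabel is a hypothetical dataclass standing in for a label type that defines no ordering against other types:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class PlainLabel:  # hypothetical label type with no cross-type ordering
  name: str


# A dictionary keyed by both a string and a dataclass instance.
mixed = {"x": 1, PlainLabel("y"): 2}

# Sorting the keys (which JAX needs to do for dictionary keys) fails,
# because neither str nor PlainLabel knows how to compare against the other.
try:
  sorted(mixed)
  sort_error = None
except TypeError as err:
  sort_error = err
```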

This module defines a mixin class, AutoOrderedAcrossTypes, that implements the __lt__, __gt__, __le__, and __ge__ methods for dataclasses so that:

  • comparisons within a single type follow the ordinary dataclass rules (i.e., instances are ordered like a tuple of their field values),

  • comparisons between two different subclasses of AutoOrderedAcrossTypes are ordered arbitrarily, such that for two types A and B, either all instances of A are less than all instances of B or vice versa,

  • any subclass of AutoOrderedAcrossTypes is always greater than any string or ordinary tuple.
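The three rules above can be sketched as a small mixin. This is a minimal illustration under stated assumptions, not Penzai's actual implementation: the names AutoOrderedSketch, SlotLabel, and VariableLabel are hypothetical, and ranking types by module and qualified name is just one possible arbitrary-but-deterministic choice.

```python
from __future__ import annotations

import dataclasses
from typing import Any


class AutoOrderedSketch:
  """Hypothetical stand-in for AutoOrderedAcrossTypes (not Penzai's code).

  Subclasses are assumed to be dataclasses declared without order=True,
  since order=True would generate comparison methods that override these.
  """

  def _order_key(self) -> tuple[str, str, tuple[Any, ...]]:
    cls = type(self)
    # Field values in declaration order, as with a dataclass using order=True.
    values = tuple(getattr(self, f.name) for f in dataclasses.fields(self))
    # Types are ranked by (module, qualified name), so all instances of one
    # type sort together; field values are only compared when these match.
    return (cls.__module__, cls.__qualname__, values)

  def _cmp(self, other: Any) -> Any:
    """Returns -1, 0, or 1, or NotImplemented for unsupported operands."""
    if isinstance(other, (str, tuple)):
      return 1  # Always greater than any string or ordinary tuple.
    if isinstance(other, AutoOrderedSketch):
      a, b = self._order_key(), other._order_key()
      return (a > b) - (a < b)
    return NotImplemented

  def __lt__(self, other: Any) -> Any:
    c = self._cmp(other)
    return c if c is NotImplemented else c < 0

  def __le__(self, other: Any) -> Any:
    c = self._cmp(other)
    return c if c is NotImplemented else c <= 0

  def __gt__(self, other: Any) -> Any:
    c = self._cmp(other)
    return c if c is NotImplemented else c > 0

  def __ge__(self, other: Any) -> Any:
    c = self._cmp(other)
    return c if c is NotImplemented else c >= 0


@dataclasses.dataclass(frozen=True)
class SlotLabel(AutoOrderedSketch):  # hypothetical label type
  name: str


@dataclasses.dataclass(frozen=True)
class VariableLabel(AutoOrderedSketch):  # hypothetical label type
  name: str
  index: int


# A dictionary whose keys mix a string with two label types.
params = {
    VariableLabel("w", 1): 0.5,
    "bias": 0.0,
    SlotLabel("b"): 2.0,
    VariableLabel("w", 0): 1.5,
    SlotLabel("a"): 3.0,
}
ordered_keys = sorted(params)
```

Because str.__lt__ returns NotImplemented for a label, Python falls back to the label's reflected __gt__, so "bias" sorts before every label. In this sketch, as in plain Python, strings and tuples still cannot be compared with each other.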

Classes

AutoOrderedAcrossTypes

Mixin to define an arbitrary total ordering across dataclass subclasses.