partitioning
Tags and utility functions for working with partitioned models.
Partitioning models allows different sets of leaves to be manipulated independently by JAX transformations, optimizers, and other logic. The partitioning logic in Penzai is inspired by the similar system in Equinox, but:
All partitions are created using selectors, in particular by calling pz.core.selectors.Selection.partition, making it possible to partition based on many use-case-specific criteria.
Nodes that are removed by the partition are replaced with the sentinel NotInThisPartition. This makes it obvious which parts of a partition have been removed. (By default, Equinox uses None for this, but None may also have other meanings in a PyTree structure.)
Partitions should be combined with the combine function in this module rather than with Equinox's combine, so that NotInThisPartition nodes are identified correctly.
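The reason for a dedicated sentinel can be illustrated with a small pure-Python sketch. This is not Penzai's implementation and does not use its API: the names MISSING, partition, and combine below are hypothetical stand-ins defined only for this example. It works on plain nested dicts and lists and replaces individual leaves only. The point it shows is that a tree containing a genuine None value still round-trips correctly, because removed leaves are marked with a distinct object instead of None.

```python
from typing import Any, Callable


class _Missing:
    """Hypothetical stand-in for a sentinel such as NotInThisPartition."""

    def __repr__(self) -> str:
        return "MISSING"


MISSING = _Missing()


def partition(tree: Any, keep: Callable[[Any], bool]) -> tuple[Any, Any]:
    """Splits a nested dict/list tree into (selected, remainder) by a leaf predicate."""
    if isinstance(tree, dict):
        pairs = {k: partition(v, keep) for k, v in tree.items()}
        return ({k: p[0] for k, p in pairs.items()},
                {k: p[1] for k, p in pairs.items()})
    if isinstance(tree, list):
        pairs = [partition(v, keep) for v in tree]
        return [p[0] for p in pairs], [p[1] for p in pairs]
    # Leaf: it lands in exactly one side; the other side gets the sentinel.
    return (tree, MISSING) if keep(tree) else (MISSING, tree)


def combine(*partitions: Any) -> Any:
    """Merges same-structure partitions, taking the one non-sentinel leaf."""
    first = partitions[0]
    if isinstance(first, dict):
        return {k: combine(*(p[k] for p in partitions)) for k in first}
    if isinstance(first, list):
        return [combine(*items) for items in zip(*partitions)]
    present = [leaf for leaf in partitions if leaf is not MISSING]
    if len(present) > 1:
        raise ValueError("leaf is present in more than one partition")
    return present[0] if present else MISSING


# "activation" is a real None value, not a removed node.
model = {"weights": [1.0, 2.0], "bias": 0.5, "activation": None, "name": "dense"}

# Select float leaves (for example, values an optimizer would update).
trainable, frozen = partition(model, lambda leaf: isinstance(leaf, float))

restored = combine(trainable, frozen)
```

In this sketch, trainable["activation"] is MISSING while frozen["activation"] is still None, so the two cases stay distinguishable. If None were used as the removal marker, combine could not tell a removed leaf from a leaf whose value is None.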
Classes
NotInThisPartition | Sentinel object that identifies subtrees removed by a partition.
Functions
combine | Combines leaves from multiple partitions.