partitioning
Tags and utility functions for working with partitioned models.
Partitioning models allows different sets of leaves to be manipulated independently by JAX transformations, optimizers, and other logic. The partitioning logic in Penzai is inspired by the similar system in Equinox, but:
- All partitions are created using selectors, in particular by calling pz.core.selectors.Selection.partition, which makes it possible to partition a model based on many use-case-specific criteria.
- Nodes removed from a partition are replaced with the sentinel NotInThisPartition, which tags them as removed. This makes it obvious which parts of a partition have been removed. (By default, Equinox uses None for this, but None may also have other meanings in a PyTree structure.)
- Partitions should be combined with the combine function in this module rather than with the one in Equinox, so that NotInThisPartition nodes are identified correctly.
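The partition-and-combine round trip can be illustrated with a minimal, self-contained sketch in plain Python. This is not Penzai's implementation: the helpers partition and combine, the nested dict/list/tuple tree format, and the stand-in sentinel class are all hypothetical, and a per-leaf predicate plays the role that a selector plays in Penzai. Removed leaves are replaced with a sentinel rather than None, so the model's own None leaf ("bias" below) survives the round trip, and combine accepts any number of partitions.

```python
from typing import Any, Callable, Tuple


class NotInThisPartition:
    """Hypothetical stand-in for the sentinel: marks a leaf removed from a partition."""

    def __repr__(self) -> str:
        return "NotInThisPartition()"

    # All sentinels compare equal, so partitions can be compared structurally.
    def __eq__(self, other: object) -> bool:
        return isinstance(other, NotInThisPartition)

    def __hash__(self) -> int:
        return hash(NotInThisPartition)


def partition(tree: Any, predicate: Callable[[Any], bool]) -> Tuple[Any, Any]:
    """Hypothetical helper: splits a nested dict/list/tuple tree into (selected, rest).

    Every non-container value is a leaf; `predicate` decides which side keeps it,
    and the other side gets a NotInThisPartition sentinel in its place.
    """
    if isinstance(tree, dict):
        pairs = {key: partition(value, predicate) for key, value in tree.items()}
        return ({k: p[0] for k, p in pairs.items()},
                {k: p[1] for k, p in pairs.items()})
    if isinstance(tree, (list, tuple)):
        pairs = [partition(value, predicate) for value in tree]
        return (type(tree)(p[0] for p in pairs),
                type(tree)(p[1] for p in pairs))
    if predicate(tree):
        return tree, NotInThisPartition()
    return NotInThisPartition(), tree


def combine(*partitions: Any) -> Any:
    """Hypothetical helper: merges same-shaped partitions back into one tree.

    At each leaf position exactly one partition must hold a real value;
    all the others must hold the sentinel.
    """
    first = partitions[0]
    if isinstance(first, dict):
        return {key: combine(*(p[key] for p in partitions)) for key in first}
    if isinstance(first, (list, tuple)):
        return type(first)(combine(*items) for items in zip(*partitions))
    present = [p for p in partitions if not isinstance(p, NotInThisPartition)]
    if len(present) != 1:
        raise ValueError(
            f"expected exactly 1 partition to hold this leaf, got {len(present)}")
    return present[0]


# Usage: None is a genuine leaf of this model, so it must not double as the marker.
model = {"w": [1.0, 2.0], "bias": None, "name": "layer"}
floats, rest = partition(model, lambda leaf: isinstance(leaf, float))
strings, others = partition(rest, lambda leaf: isinstance(leaf, str))
assert combine(floats, rest) == model
assert combine(floats, strings, others) == model
```

The exactly-one check in combine is a design choice of this sketch: it turns a mismatched set of partitions into an error instead of silently preferring one of them.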
Classes
Sentinel object that identifies subtrees removed by
Functions
Combines leaves from multiple partitions.
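A sentinel that contributes no leaves is what allows a single partition to be handed to a JAX transformation. The sketch below assumes, hypothetically, that the sentinel is registered as a pytree node with no children; the class here is a stand-in, not Penzai's own. jax.tree_util then treats it as an empty subtree (exactly as it treats None), so jax.grad differentiates only the leaves that are actually present in the partition.

```python
import jax
import jax.numpy as jnp


class NotInThisPartition:
    """Hypothetical sentinel for removed subtrees (a stand-in, not Penzai's class)."""

    def __repr__(self) -> str:
        return "NotInThisPartition()"


# Register the sentinel as a pytree node with no children and no auxiliary data,
# so JAX sees it as an empty subtree that contributes no leaves.
jax.tree_util.register_pytree_node(
    NotInThisPartition,
    lambda node: ((), None),                     # flatten: no children, no aux data
    lambda aux, children: NotInThisPartition(),  # unflatten: rebuild the sentinel
)


def loss(params):
    # Only the "w" leaf is present in this partition; "frozen" was removed.
    return jnp.sum(params["w"] ** 2)


params = {"w": jnp.array([1.0, 2.0]), "frozen": NotInThisPartition()}
print(jax.tree_util.tree_leaves(params))  # only the "w" array is listed as a leaf
grads = jax.grad(loss)(params)            # the sentinel passes through untouched
print(grads["frozen"])
```

Because JAX handles None the same way, the sentinel's advantage is one of meaning: it cannot be confused with a None that the model itself uses.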