partitioning

Tags and utility functions for working with partitioned models.

Partitioning a model allows different sets of its leaves to be manipulated independently by JAX transformations, optimizers, and other logic. Penzai's partitioning system is inspired by the similar system in Equinox, with a few differences:

  • All partitions are created with selectors, specifically by calling pz.core.selectors.Selection.partition, so a model can be partitioned by any use-case-specific criterion that a selector can express.

  • Nodes removed by a partition are replaced with the sentinel NotInThisPartition, which makes it obvious which parts of each partition have been removed. (By default, Equinox uses None for this, but None may also have other meanings in a PyTree structure.)

  • Partitions should be recombined with the combine function in this module rather than with Equinox's combine, so that NotInThisPartition nodes are identified correctly.
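The ideas in the list above can be illustrated without Penzai or JAX. The following is a minimal pure-Python sketch under stated assumptions, not Penzai's implementation: nested dicts, lists, and tuples stand in for a PyTree, a predicate over (path, node) pairs stands in for a selector, and the sentinel class, NOT_HERE, and partition_by are hypothetical names defined locally. When the predicate matches a node, the whole subtree goes to the first partition and the other partition holds the sentinel in its place.

```python
class NotInThisPartition:
    """Local stand-in for Penzai's sentinel (hypothetical, not the real class)."""

    def __repr__(self):
        return "NotInThisPartition()"


# A single shared instance marks every removed subtree in this sketch.
NOT_HERE = NotInThisPartition()


def partition_by(tree, predicate, path=()):
    """Split `tree` into (selected, remainder) using `predicate(path, node)`.

    Hypothetical helper: the predicate plays the role of a selector and is
    called on containers as well as leaves. Each node ends up in exactly one
    output; the other output holds NOT_HERE at that position.
    """
    if predicate(path, tree):
        # The whole subtree goes to the selected partition.
        return tree, NOT_HERE
    if isinstance(tree, dict):
        pairs = {k: partition_by(v, predicate, path + (k,)) for k, v in tree.items()}
        return ({k: s for k, (s, _) in pairs.items()},
                {k: r for k, (_, r) in pairs.items()})
    if isinstance(tree, (list, tuple)):
        pairs = [partition_by(v, predicate, path + (i,)) for i, v in enumerate(tree)]
        return (type(tree)(s for s, _ in pairs),
                type(tree)(r for _, r in pairs))
    # An unselected leaf stays entirely in the remainder partition.
    return NOT_HERE, tree


model = {"dense": {"w": [1.0, 2.0], "b": 0.5}, "step": 3}

# Criterion 1: select a whole subtree by its path.
params, rest = partition_by(model, lambda path, node: path == ("dense",))
print(params)  # {'dense': {'w': [1.0, 2.0], 'b': 0.5}, 'step': NotInThisPartition()}
print(rest)    # {'dense': NotInThisPartition(), 'step': 3}

# Criterion 2: select individual leaves by type.
floats, others = partition_by(model, lambda path, node: isinstance(node, float))
```

Both outputs keep the shape of the original tree down to the point where a subtree was removed, so each can be transformed on its own and merged back afterwards.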

Classes

NotInThisPartition

Sentinel object that identifies subtrees removed by partition.
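The advantage of a dedicated sentinel over None shows up as soon as a model legitimately contains None, for example an optional bias that is unset. (JAX also treats None as an empty subtree rather than a leaf.) The sketch below is hypothetical and does not use Penzai: NotInThisPartition is a local stand-in class, REMOVED is its shared instance, and merge_two is an illustrative helper for flat dicts that requires every key to be present in exactly one of the two partitions.

```python
class NotInThisPartition:
    """Local stand-in for Penzai's sentinel; only identity comparisons matter."""

    def __repr__(self):
        return "NotInThisPartition()"


REMOVED = NotInThisPartition()


def merge_two(first, second, hole):
    """Hypothetical merge of two flat partitions that share the same keys.

    `hole` is the marker used for removed entries. Each key must hold a real
    value in exactly one partition and `hole` in the other.
    """
    merged = {}
    for key in first:
        a, b = first[key], second[key]
        if (a is hole) == (b is hole):
            raise ValueError(f"leaf {key!r} must be in exactly one partition")
        merged[key] = b if a is hole else a
    return merged


# Partition 1 keeps "w"; partition 2 keeps "bias", whose real value is None.
try:
    # With None as the marker, the genuine None bias looks removed from both.
    merge_two({"w": 1.0, "bias": None}, {"w": None, "bias": None}, hole=None)
except ValueError as err:
    print("None as marker:", err)
    # None as marker: leaf 'bias' must be in exactly one partition

# With a dedicated sentinel, the None bias is unambiguous.
print(merge_two({"w": 1.0, "bias": REMOVED}, {"w": REMOVED, "bias": None}, hole=REMOVED))
# {'w': 1.0, 'bias': None}
```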

Functions

combine(*partitions)

Combines leaves from multiple partitions.
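The behavior described for combine can be approximated with a short recursive merge. This is a sketch under assumptions, not Penzai's source: combine_sketch, MISSING, and the local NotInThisPartition class are hypothetical; trees are nested dicts, lists, and tuples; and raising an error when zero or several partitions supply the same leaf is a choice made for this sketch rather than documented behavior of combine. Because the sentinel may replace an entire subtree, the merge takes a subtree wholesale whenever only one partition supplies it.

```python
class NotInThisPartition:
    """Local stand-in for Penzai's sentinel (hypothetical)."""

    def __repr__(self):
        return "NotInThisPartition()"


MISSING = NotInThisPartition()


def combine_sketch(*partitions):
    """Rebuild one tree from partitions produced by splitting a single tree.

    Assumption of this sketch: at every position, exactly one partition holds
    a real leaf or subtree, and the rest hold MISSING (or share the container).
    """
    present = [p for p in partitions if p is not MISSING]
    if not present:
        raise ValueError("no partition supplies a value at this position")
    if len(present) == 1:
        # Only one partition kept this leaf or subtree: take it as-is.
        return present[0]
    first = present[0]
    if isinstance(first, dict):
        return {k: combine_sketch(*(p[k] for p in present)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(
            combine_sketch(*(p[i] for p in present)) for i in range(len(first))
        )
    # Several partitions hold a real value for the same leaf.
    raise ValueError("more than one partition supplies a value for the same leaf")


# Three partitions of {"dense": {"w": [1.0, 2.0], "b": 0.5}, "step": 3}.
a = {"dense": {"w": [1.0, MISSING], "b": MISSING}, "step": MISSING}
b = {"dense": {"w": [MISSING, 2.0], "b": 0.5}, "step": MISSING}
c = {"dense": MISSING, "step": 3}
print(combine_sketch(a, b, c))
# {'dense': {'w': [1.0, 2.0], 'b': 0.5}, 'step': 3}
```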