combine

penzai.core.partitioning.combine(*partitions: Any) -> Any

Combines leaves from multiple partitions.

This function can be used to reverse the result of calling partition on a selection. It walks two or more PyTrees in parallel and combines their leaves by detecting each NotInThisPartition sentinel and replacing it with the value that another partition holds at that position. A sentinel can replace any subtree in a partition, not just a single leaf.
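To make the walking behavior concrete, here is a minimal pure-Python sketch of the same idea. It is not penzai's implementation: it assumes trees built only from plain dicts, lists, and tuples (rather than arbitrary registered PyTree nodes), and the NotInThisPartition class below is a hypothetical stand-in for penzai's sentinel. In the usage lines, a single sentinel in selected covers the entire decoder subtree.

```python
from typing import Any


class NotInThisPartition:
    """Hypothetical stand-in for penzai's sentinel (not the real class)."""

    def __repr__(self) -> str:
        return "NotInThisPartition()"


def _is_missing(x: Any) -> bool:
    return isinstance(x, NotInThisPartition)


def combine(*partitions: Any) -> Any:
    """Merges partitions of a tree built from dicts, lists and tuples.

    In this sketch a sentinel may replace a single leaf or an entire subtree.
    """
    # Partitions holding a sentinel here contribute nothing to this subtree.
    present = [p for p in partitions if not _is_missing(p)]
    if not present:
        raise ValueError("No partition has a value at this position.")
    if len(present) == 1:
        # Exactly one partition owns this whole subtree; take it unchanged.
        return present[0]
    # Several partitions reach this node, so it must be a container with the
    # same type and shape in each of them; merge it child by child.
    first = present[0]
    if isinstance(first, dict):
        if any(not isinstance(p, dict) or p.keys() != first.keys() for p in present):
            raise ValueError("Partitions have mismatched dict keys.")
        return {k: combine(*(p[k] for p in present)) for k in first}
    if isinstance(first, (list, tuple)):
        if any(type(p) is not type(first) or len(p) != len(first) for p in present):
            raise ValueError("Partitions have mismatched sequences.")
        merged = [combine(*(p[i] for p in present)) for i in range(len(first))]
        return tuple(merged) if isinstance(first, tuple) else merged
    # Two or more partitions hold a concrete leaf at the same position.
    raise ValueError("More than one partition has a leaf at this position.")


missing = NotInThisPartition()
selected = {"encoder": {"w": 1.0, "b": missing}, "decoder": missing}
rest = {"encoder": {"w": missing, "b": 2.0}, "decoder": [3.0, 4.0]}
print(combine(selected, rest))
# {'encoder': {'w': 1.0, 'b': 2.0}, 'decoder': [3.0, 4.0]}
```

Raising on zero or multiple owners is a design choice of this sketch, meant to surface violations of the "exactly one partition has each leaf" requirement early rather than silently picking a value.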

A common use case for this function is to recombine parts of an input model after splitting them to handle them differently in JAX transformations. See the documentation for penzai.core.selectors.Selection.partition for more details.
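The partition-then-combine workflow described above can be sketched as follows. Everything here is hypothetical and simplified: partition is a toy helper that selects leaves with a predicate (not penzai's Selection.partition API), sentinels only ever appear at leaf positions, and the "special treatment" is plain Python doubling of the selected floats rather than a JAX transformation.

```python
from typing import Any, Callable, Tuple


class NotInThisPartition:
    """Hypothetical sentinel; a stand-in for penzai's class."""


MISSING = NotInThisPartition()


def tree_map(fn: Callable[..., Any], *trees: Any) -> Any:
    """Applies fn leaf-wise across trees with identical dict/list/tuple structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        mapped = [tree_map(fn, *(t[i] for t in trees)) for i in range(len(first))]
        return tuple(mapped) if isinstance(first, tuple) else mapped
    return fn(*trees)


def partition(tree: Any, predicate: Callable[[Any], bool]) -> Tuple[Any, Any]:
    """Splits leaves into (selected, rest); the other side gets a sentinel."""
    selected = tree_map(lambda x: x if predicate(x) else MISSING, tree)
    rest = tree_map(lambda x: MISSING if predicate(x) else x, tree)
    return selected, rest


def combine(*partitions: Any) -> Any:
    """Leaf-level combine: keeps the single non-sentinel value at each leaf."""

    def pick(*leaves: Any) -> Any:
        owners = [x for x in leaves if not isinstance(x, NotInThisPartition)]
        if len(owners) != 1:
            raise ValueError(f"expected exactly one owner, got {len(owners)}")
        return owners[0]

    return tree_map(pick, *partitions)


model = {"weights": [1.0, 2.0], "bias": 0.5, "step": 3}
floats, others = partition(model, lambda x: isinstance(x, float))
# Handle only the selected partition specially: double its float leaves.
scaled = tree_map(lambda x: x if isinstance(x, NotInThisPartition) else x * 2, floats)
print(combine(scaled, others))
# {'weights': [2.0, 4.0], 'bias': 1.0, 'step': 3}
```

Because this toy partition always places sentinels at leaves, its combine can assume every partition shares the same structure all the way down; penzai's combine, as described above, also accepts sentinels that stand in for whole subtrees.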

This function is inspired by Equinox’s eqx.combine, but uses NotInThisPartition as the sentinel instead of None, and supports NotInThisPartition at arbitrary subtree locations instead of only having None at leaves. (Partitioning is also somewhat less important in Penzai than in Equinox because all PyTree leaves are array-like by convention; partitioning is only necessary when different parts of the tree need special treatment.)

Parameters:

*partitions – Partitions to combine. All partitions should have the same PyTree structure except that, for each PyTree leaf, exactly one of the input partitions actually has that leaf, and every other partition has NotInThisPartition sentinels in place of that leaf or one of its ancestors.

Returns:

A single PyTree that merges all partitions, taking each leaf (or subtree) from the one partition that holds a concrete value there instead of a NotInThisPartition sentinel.