combine
penzai.core.partitioning.combine(*partitions: Any) -> Any
Combines leaves from multiple partitions.

This function can be used to reverse the result of calling partition on a selector. It walks two or more PyTrees in parallel and combines their leaves by detecting and replacing instances of NotInThisPartition. A NotInThisPartition sentinel can replace any subtree of a partition, not just a single leaf.

A common use case for this function is to recombine parts of an input model after splitting them so that each part can be handled differently by JAX transformations. See the documentation for penzai.core.selectors.Selection.partition for more details.

This function is inspired by Equinox's eqx.combine, but it uses NotInThisPartition as the sentinel instead of None, and it supports NotInThisPartition at arbitrary subtree locations rather than only having None at leaves. (Partitioning is also somewhat less important in Penzai than in Equinox because, by convention, all PyTree leaves are arraylike; partitioning is only necessary when different parts of the tree need special treatment.)

Parameters:
*partitions – Partitions to combine. All partitions should have the same PyTree structure, except that for each PyTree leaf, exactly one of the input partitions actually has that leaf, and every other partition has a NotInThisPartition sentinel in place of that leaf or one of its ancestors.

Returns:
A combined version of all partitions, in which each NotInThisPartition sentinel is replaced by the concrete value from the one partition that actually holds that leaf or subtree.
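The merge rule described above (exactly one partition supplies each leaf, and a sentinel may stand in for a leaf or for a whole subtree) can be illustrated without JAX. The following is a minimal sketch, not penzai's implementation: it assumes PyTrees are plain nested dicts, lists, and tuples, and it uses a hypothetical sentinel class NotInThisPartitionSketch and a hypothetical function combine_sketch in place of penzai's NotInThisPartition and combine.

```python
from typing import Any


class NotInThisPartitionSketch:
  """Hypothetical stand-in for penzai's NotInThisPartition sentinel."""

  def __repr__(self) -> str:
    return "NotInThisPartitionSketch()"


def _is_sentinel(x: Any) -> bool:
  return isinstance(x, NotInThisPartitionSketch)


def _contains_sentinel(x: Any) -> bool:
  """Returns True if a sentinel appears anywhere inside x."""
  if _is_sentinel(x):
    return True
  if type(x) is dict:
    return any(_contains_sentinel(v) for v in x.values())
  if type(x) in (list, tuple):
    return any(_contains_sentinel(v) for v in x)
  return False


def _combine(nodes: list[Any]) -> Any:
  # Drop partitions that have a sentinel at this position (or above it).
  present = [n for n in nodes if not _is_sentinel(n)]
  if not present:
    raise ValueError("no partition provides a value at this position")
  if len(present) == 1:
    # Exactly one partition owns this whole subtree (or leaf): take it as-is,
    # but reject it if it still has holes that no other partition fills.
    if _contains_sentinel(present[0]):
      raise ValueError("no partition provides a value at this position")
    return present[0]
  # Several partitions are non-sentinel here, so this must be a container
  # whose children are split among them; recurse child by child.
  first = present[0]
  if any(type(n) is not type(first) for n in present):
    raise ValueError("partitions disagree on the tree structure")
  if type(first) is dict:
    if any(n.keys() != first.keys() for n in present):
      raise ValueError("partitions disagree on dict keys")
    return {k: _combine([n[k] for n in present]) for k in first}
  if type(first) in (list, tuple):
    if any(len(n) != len(first) for n in present):
      raise ValueError("partitions disagree on sequence length")
    return type(first)(
        _combine([n[i] for n in present]) for i in range(len(first))
    )
  raise ValueError("leaf is present in more than one partition")


def combine_sketch(*partitions: Any) -> Any:
  """Merges two or more partitions into a single tree (hypothetical helper)."""
  if len(partitions) < 2:
    raise ValueError("combine_sketch expects two or more partitions")
  return _combine(list(partitions))


S = NotInThisPartitionSketch()
model = {"w": [1.0, 2.0], "b": 3.0, "meta": ("relu", 4)}
# A sentinel may replace a whole subtree ("b" and "meta" in part_a)...
part_a = {"w": [1.0, 2.0], "b": S, "meta": S}
part_b = {"w": S, "b": 3.0, "meta": ("relu", 4)}
assert combine_sketch(part_a, part_b) == model
# ...or individual leaves inside a subtree that both partitions share.
part_c = {"w": [1.0, S], "b": S, "meta": S}
part_d = {"w": [S, 2.0], "b": 3.0, "meta": ("relu", 4)}
assert combine_sketch(part_c, part_d) == model
```

Because a sentinel can replace an entire subtree, part_a above never has to spell out the internal structure of "meta". Penzai's actual combine works on JAX PyTrees rather than only on plain containers, which this sketch does not attempt to model.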