order_like
- penzai.core.named_axes.order_like(value_tree: Any, reference_tree: Any)
Orders a tree of named arrays to match the structure of another tree.
This function takes two PyTrees and makes each NamedArray in value_tree have the same structure (axis ordering) as the corresponding NamedArray in reference_tree. This can be used when passing NamedArrays through JAX transformations that require identical PyTree structures, since two NamedArrays whose axes are stored in different orders have different PyTree structures even when their named shapes agree.
This is the equivalent of NamedArrayBase.order_like for trees of arrays. Leaves that are not NamedArrays or NamedArrayViews are left unchanged.
- Parameters:
value_tree – The tree to order.
reference_tree – The tree to match. Must have the same structure as value_tree, except that NamedArrays may have differently-ordered axes.
- Returns:
A tree with exactly the same PyTree structure as reference_tree, with array data from value_tree.