order_like
- penzai.core.named_axes.order_like(value_tree: Any, reference_tree: Any)
Orders a tree of named arrays to match the structure of another tree.
This function takes two PyTrees and makes each NamedArray or NamedArrayView in
value_tree have the same PyTree structure as the corresponding NamedArray in reference_tree, transposing its data if necessary. This is useful when passing NamedArrays through JAX transformations that require identical PyTree structures (for example, the carry of jax.lax.scan or the outputs of the branches of jax.lax.cond).
This is the equivalent of NamedArrayBase.order_like for trees of arrays. Leaves that are not NamedArrays or NamedArrayViews are left unchanged.
- Parameters:
value_tree – The tree to order.
reference_tree – The tree to match. Must have the same structure as
value_tree, except that NamedArrays may have differently-ordered axes.
- Returns:
A tree with exactly the same PyTree structure as
reference_tree, with array data from value_tree.