order_like

penzai.core.named_axes.order_like(value_tree: Any, reference_tree: Any)

Orders a tree of named arrays to match the structure of another tree.

Given two PyTrees, this function rearranges each NamedArray in value_tree so that it has the same PyTree structure as the corresponding NamedArray in reference_tree. This is useful when passing NamedArrays through JAX transformations that require identical PyTree structures, such as the two branches of jax.lax.cond or the carry of jax.lax.scan.

This is the equivalent of NamedArrayBase.order_like for trees of arrays. Leaves that are not NamedArrays or NamedArrayViews are left unchanged.

Parameters:
  • value_tree – The tree to order.

  • reference_tree – The tree to match. Must have the same structure as value_tree, except that NamedArrays may have their axes in a different order.

Returns:

A tree with exactly the same PyTree structure as reference_tree, containing the array data from value_tree.