name_to_name_device_put
- penzai.toolshed.sharding_util.name_to_name_device_put(tree: PyTreeOfNamedArrays, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = None) -> PyTreeOfNamedArrays
Shards a tree of pz.nx.NamedArray objects based on their axis names.
- Parameters:
tree – A PyTree of NamedArrays.
mesh – The Mesh to shard the tree to.
axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. If an array axis name is not present in the mapping, that axis will not be sharded. If a mesh axis name is a tuple, the corresponding array axis will be sharded across multiple mesh axes. If this dictionary is not provided, it will be inferred as an “identity” mapping, where each array axis is sharded to the mesh axis with the same name (if the mesh has one).
- Returns:
A PyTree with the same structure as the input tree, but with all NamedArrays put onto the appropriate devices according to the mesh.
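The rules for axis_name_to_mesh_name can be illustrated with a small pure-Python sketch. It does not call penzai or JAX and is not the library's implementation: resolve_axis_mapping is a hypothetical helper, and the axis names ("batch", "features", "data", "model") are assumptions chosen for illustration. It only shows which mesh axes each array axis would be assigned under the behavior described above.

```python
from typing import Optional, Union

# A mesh assignment is one mesh axis name or a tuple of several.
MeshAxes = Union[str, tuple[str, ...]]


def resolve_axis_mapping(
    array_axis_names: tuple[str, ...],
    mesh_axis_names: tuple[str, ...],
    axis_name_to_mesh_name: Optional[dict[str, MeshAxes]] = None,
) -> dict[str, Optional[MeshAxes]]:
    """Hypothetical helper: report the mesh axes each array axis is sharded over."""
    if axis_name_to_mesh_name is None:
        # Inferred "identity" mapping: an array axis is sharded only if the
        # mesh has an axis with the same name.
        axis_name_to_mesh_name = {
            name: name for name in array_axis_names if name in mesh_axis_names
        }
    # Array axes missing from the mapping stay unsharded (None).
    return {name: axis_name_to_mesh_name.get(name) for name in array_axis_names}


# No mapping given: only "batch" matches a mesh axis name.
identity = resolve_axis_mapping(("batch", "features"), ("batch", "model"))

# Explicit mapping: "batch" is split across two mesh axes at once.
explicit = resolve_axis_mapping(
    ("batch", "features"), ("data", "model"), {"batch": ("data", "model")}
)

print(identity)
print(explicit)
```

In both cases "features" is left unsharded because nothing maps it to a mesh axis. Following the signature above, the corresponding real call for the second case would pass {"batch": ("data", "model")} as axis_name_to_mesh_name.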