name_to_name_device_put

penzai.toolshed.sharding_util.name_to_name_device_put(tree: PyTreeOfNamedArrays, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = None) -> PyTreeOfNamedArrays

Shards a tree of pz.nx.NamedArray objects based on their axis names.

Parameters:
  • tree – A PyTree of NamedArrays.

  • mesh – The Mesh to shard the tree over.

  • axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. An array axis whose name is not in the mapping is not sharded. If the mapped value is a tuple of mesh axis names, the array axis is sharded across all of those mesh axes. If the dictionary is not provided, an “identity” mapping is inferred: each array axis is sharded over the mesh axis with the same name, if one exists.

Returns:

A PyTree with the same structure as the input tree, but with all NamedArrays put onto the appropriate devices according to the mesh.
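
Example:

The mapping rules described under Parameters can be illustrated with a small pure-Python sketch. This is not penzai's implementation: the helper resolve_axis_sharding and the mesh axis names ("batch", "fsdp", "model") are hypothetical, chosen only to show which mesh axes each named array axis would end up on under the documented rules.

```python
def resolve_axis_sharding(array_axis_names, mesh_axis_names,
                          axis_name_to_mesh_name=None):
    """Hypothetical helper mirroring the documented mapping rules.

    Returns a dict from each array axis name to the mesh axis name (or
    tuple of mesh axis names) it would be sharded over, or None if the
    axis would not be sharded.
    """
    if axis_name_to_mesh_name is None:
        # Inferred "identity" mapping: an array axis is sharded only if
        # the mesh has an axis with the same name.
        axis_name_to_mesh_name = {name: name for name in mesh_axis_names}
    # Axis names absent from the mapping are left unsharded (None).
    return {name: axis_name_to_mesh_name.get(name)
            for name in array_axis_names}


# Assumed mesh axis names, for illustration only.
mesh_axes = ("batch", "fsdp", "model")
array_axes = ("batch", "seq", "embed")

# No mapping given: only "batch" matches a mesh axis name.
inferred = resolve_axis_sharding(array_axes, mesh_axes)

# Explicit mapping: "batch" is split over two mesh axes, "embed" over one,
# and "seq" is not mentioned, so it stays unsharded.
explicit = resolve_axis_sharding(
    array_axes, mesh_axes,
    {"batch": ("batch", "fsdp"), "embed": "model"},
)
```

Under these assumptions, passing the same explicit dictionary as axis_name_to_mesh_name to name_to_name_device_put would be expected to shard the "batch" axis across both the "batch" and "fsdp" mesh axes, shard "embed" over "model", and leave "seq" unsharded.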