name_to_name_sharding

penzai.toolshed.sharding_util.name_to_name_sharding(tree: PyTreeOfNamedArrays, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[pz.nx.AxisName, str | tuple[str, ...]] | None = None, ignore_unnamed_arrays: bool = False, as_shape_dtype_struct: bool = False) → PyTreeOfShardings

Builds shardings for a tree of pz.nx.NamedArray objects based on their axis names.

Parameters:
  • tree – A PyTree of pz.nx.NamedArray instances, with the same structure as the tree you want to shard. It is OK for the NamedArray instances to have invalid or missing data arrays; the data is not used.

  • mesh – The jax.sharding.Mesh to shard the tree to.

  • axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. If an array axis name is not present in the mapping, that axis will not be sharded. If a mesh axis name is a tuple, the corresponding array axis will be sharded across multiple mesh axes. If this dictionary is not provided, it will be inferred as an “identity” mapping, where each array axis is sharded across the mesh axis with the same name (if present).

  • ignore_unnamed_arrays – Whether to ignore non-NamedArray leaves. If True, any leaf that is not a NamedArray will be given a None sharding, usually indicating that JAX should infer a sharding. If False, a ValueError will be raised if any leaf is not a NamedArray.

  • as_shape_dtype_struct – If True, instead of directly returning a PyTree of NamedSharding, return a PyTree of jax.ShapeDtypeStruct where the .sharding attribute is the NamedSharding. This can be useful for building inputs to orbax.checkpoint, for instance.
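
The axis_name_to_mesh_name rules above can be sketched in plain Python. This is only an illustration of the documented mapping behavior, not penzai's implementation: the helper resolve_partition_entries is a hypothetical name, and it works on plain tuples of axis names rather than on NamedArray or Mesh objects. Each entry of the result is what one would expect to pass, per array axis, to a jax.sharding.PartitionSpec.

```python
from typing import Hashable, Mapping, Optional, Sequence, Tuple, Union

# A mesh target is either one mesh axis name or a tuple of several.
MeshTarget = Union[str, Tuple[str, ...]]


def resolve_partition_entries(
    array_axis_names: Sequence[Hashable],
    mesh_axis_names: Sequence[str],
    axis_name_to_mesh_name: Optional[Mapping[Hashable, MeshTarget]] = None,
) -> Tuple[Optional[MeshTarget], ...]:
    """Hypothetical helper: one partition entry per array axis."""
    if axis_name_to_mesh_name is None:
        # Assumed "identity" default: an array axis is sharded only if
        # the mesh has an axis with the same name.
        axis_name_to_mesh_name = {name: name for name in mesh_axis_names}
    # Array axes missing from the mapping get None (not sharded).
    return tuple(axis_name_to_mesh_name.get(name) for name in array_axis_names)


# Identity mapping: "batch" exists on the mesh, "embed" does not.
identity = resolve_partition_entries(("batch", "embed"), ("batch", "model"))

# Explicit mapping: "batch" is split across two mesh axes at once.
explicit = resolve_partition_entries(
    ("batch", "embed"),
    ("data", "fsdp", "model"),
    {"batch": ("data", "fsdp"), "embed": "model"},
)

print(identity)
print(explicit)
```

Here identity leaves the "embed" axis unsharded because the mesh has no axis of that name, while explicit shards "batch" over both "data" and "fsdp".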

Returns:

A PyTree with the same structure as the input tree, but with all pz.nx.NamedArray instances replaced with versions that have jax.sharding.NamedSharding leaves in place of their actual data arrays. This is suitable for passing as the in_shardings or out_shardings for jax.jit, or as the sharding for jax.device_put.