name_to_name_sharding
- penzai.toolshed.sharding_util.name_to_name_sharding(tree: PyTreeOfNamedArrays, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[pz.nx.AxisName, str | tuple[str, ...]] | None = None, ignore_unnamed_arrays: bool = False, as_shape_dtype_struct: bool = False) -> PyTreeOfShardings
Computes shardings for a tree of pz.nx.NamedArray objects based on their axis names.
- Parameters:
tree – A PyTree of pz.nx.NamedArray instances with the same structure as the tree you want to shard. The NamedArray instances may have invalid or missing data arrays, since the data itself is not used.
mesh – The jax.sharding.Mesh to shard the tree to.
axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. An axis whose name is not in the mapping is not sharded. If a mesh axis name is a tuple, the corresponding array axis is sharded over multiple mesh axes. If this dictionary is not provided, it is inferred as an "identity" mapping, where each axis is sharded to the mesh axis with the same name (if the mesh has one).
ignore_unnamed_arrays – Whether to ignore non-NamedArray leaves. If True, any leaf that is not a NamedArray is given a None sharding, which usually tells JAX to infer a sharding. If False, a ValueError is raised if any leaf is not a NamedArray.
as_shape_dtype_struct – If True, return a PyTree of jax.ShapeDtypeStruct instead of a PyTree of NamedSharding, with each struct's .sharding attribute set to the corresponding NamedSharding. This can be useful, for instance, when building inputs to orbax.checkpoint.
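The mapping rules for axis_name_to_mesh_name can be modeled in plain Python. This is a minimal sketch, not penzai's implementation: partition_entries is a hypothetical helper that assumes it is given the names of an array's axes in order plus the mesh's axis names, and returns one jax.sharding.PartitionSpec-style entry per axis (a mesh axis name, a tuple of mesh axis names, or None for an unsharded axis).

```python
from typing import Optional, Union

# A mesh target is one mesh axis name or a tuple of mesh axis names.
MeshName = Union[str, tuple[str, ...]]


def partition_entries(
    axis_names: tuple[str, ...],
    mesh_axis_names: tuple[str, ...],
    axis_name_to_mesh_name: Optional[dict[str, MeshName]] = None,
) -> tuple[Optional[MeshName], ...]:
  """Hypothetical helper: one PartitionSpec-style entry per array axis."""
  if axis_name_to_mesh_name is None:
    # Inferred "identity" mapping: shard an axis over the mesh axis with the
    # same name, but only if the mesh actually has such an axis.
    axis_name_to_mesh_name = {
        name: name for name in axis_names if name in mesh_axis_names
    }
  # Axes missing from the mapping get None, i.e. they are not sharded.
  return tuple(axis_name_to_mesh_name.get(name) for name in axis_names)


# Identity mapping: "embedding" has no matching mesh axis, so it is unsharded.
print(partition_entries(("batch", "embedding"), ("batch", "model")))
# Explicit mapping: "batch" is sharded over two mesh axes at once.
print(partition_entries(("batch", "embedding"), ("data", "model"),
                        {"batch": ("data", "model")}))
```

In this model, an explicit mapping fully replaces the identity default, so any axis left out of the dictionary stays unsharded even if the mesh has an axis with the same name.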
- Returns:
A PyTree with the same structure as the input tree, but with every pz.nx.NamedArray instance replaced by a version that has a jax.sharding.NamedSharding leaf in place of its data array. This is suitable for passing as the in_shardings or out_shardings of jax.jit, or as the sharding argument of jax.device_put.
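The structure-preserving replacement described above, together with the ignore_unnamed_arrays behavior, can be modeled without JAX. This is a simplified sketch under stated assumptions: FakeNamedArray and sharding_tree are hypothetical names, the identity axis mapping is assumed, only dicts, lists, and tuples are traversed (unlike jax.tree_util), and each NamedArray leaf becomes a tuple of PartitionSpec-style entries rather than a real jax.sharding.NamedSharding.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class FakeNamedArray:
  """Hypothetical stand-in for pz.nx.NamedArray: only the axis names."""
  axis_names: tuple[str, ...]


def sharding_tree(tree: Any, mesh_axis_names: tuple[str, ...],
                  ignore_unnamed_arrays: bool = False) -> Any:
  """Hypothetical model: same structure, NamedArray leaves -> spec entries."""
  if isinstance(tree, FakeNamedArray):
    # Identity mapping: shard an axis only if the mesh has an axis of that name.
    return tuple(
        name if name in mesh_axis_names else None for name in tree.axis_names
    )
  if tree is None:
    # jax.tree_util treats None as an empty subtree, not as a leaf.
    return None
  if isinstance(tree, dict):
    return {k: sharding_tree(v, mesh_axis_names, ignore_unnamed_arrays)
            for k, v in tree.items()}
  if isinstance(tree, (list, tuple)):
    # Plain lists and tuples only; namedtuples are not handled in this sketch.
    return type(tree)(sharding_tree(v, mesh_axis_names, ignore_unnamed_arrays)
                      for v in tree)
  # Any other leaf is not a (fake) NamedArray.
  if ignore_unnamed_arrays:
    return None  # Mirrors the documented None sharding for unnamed leaves.
  raise ValueError(f"Leaf is not a NamedArray: {tree!r}")


params = {"w": FakeNamedArray(("embedding", "heads")), "step": 3}
print(sharding_tree(params, ("heads",), ignore_unnamed_arrays=True))
```

With ignore_unnamed_arrays=False, the same call raises ValueError because the "step" leaf is not a NamedArray.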