name_to_name_sharding
- penzai.toolshed.sharding_util.name_to_name_sharding(tree: PyTreeOfNamedArrays, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[pz.nx.AxisName, str | tuple[str, ...]] | None = None, ignore_unnamed_arrays: bool = False, as_shape_dtype_struct: bool = False) -> PyTreeOfShardings
Computes shardings for a tree of pz.nx.NamedArray objects based on their axis names.
- Parameters:
tree – A PyTree of pz.nx.NamedArray instances with the same structure as the tree you want to shard. The NamedArray instances may have invalid or missing data arrays; the data is not used.
mesh – The jax.sharding.Mesh to shard the tree to.
axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. If an axis name is not present in the mapping, that axis is not sharded. If a mesh axis name is a tuple, the corresponding array axis is sharded across multiple mesh axes. If this dictionary is not provided, it is inferred as an "identity" mapping, where each axis is sharded to the mesh axis with the same name (if one exists).
ignore_unnamed_arrays – Whether to ignore leaves that are not NamedArrays. If True, any such leaf is given a None sharding, which usually means JAX should infer a sharding. If False, a ValueError is raised if any leaf is not a NamedArray.
as_shape_dtype_struct – If True, return a PyTree of jax.ShapeDtypeStruct whose .sharding attribute is the NamedSharding, instead of returning a PyTree of NamedSharding directly. This can be useful, for instance, for building inputs to orbax.checkpoint.
- Returns:
A PyTree with the same structure as the input tree, but with all pz.nx.NamedArray instances replaced with versions that have jax.sharding.NamedSharding leaves in place of their actual data arrays. This is suitable for passing as the in_shardings or out_shardings of jax.jit, or as the sharding for jax.device_put.
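To make the axis_name_to_mesh_name rules concrete, the following is a minimal sketch in plain JAX of how one array's ordered axis names could be turned into a jax.sharding.PartitionSpec: unmapped axes are replicated, a tuple value spreads one axis over several mesh axes, and a missing mapping falls back to matching names. The helper spec_for_axes and the names "batch", "features", "data" and "model" are hypothetical. This illustrates the documented rules; it is not penzai's implementation, and it ignores positional axes and non-NamedArray leaves.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def spec_for_axes(axis_names, mesh, axis_name_to_mesh_name=None):
    """Hypothetical helper (not part of penzai): builds a PartitionSpec for one
    array whose data axes carry `axis_names`, in order."""
    if axis_name_to_mesh_name is None:
        # "Identity" mapping: an axis is sharded only if the mesh has an axis
        # with exactly the same name.
        axis_name_to_mesh_name = {
            name: name for name in axis_names if name in mesh.axis_names
        }
    # Axes absent from the mapping get None, so they are replicated.
    # A tuple value shards one array axis across several mesh axes.
    return PartitionSpec(*(axis_name_to_mesh_name.get(name) for name in axis_names))


# A 1x1 mesh so the sketch runs on a single (e.g. CPU) device.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("data", "model"))

# Identity mapping: "data" matches a mesh axis, "features" does not.
identity_spec = spec_for_axes(("data", "features"), mesh)

# Explicit mapping: "batch" is split over both mesh axes, "features" is replicated.
explicit_spec = spec_for_axes(("batch", "features"), mesh, {"batch": ("data", "model")})

# The resulting spec can be used to place an array on the mesh.
x = jax.device_put(np.ones((8, 4), dtype=np.float32), NamedSharding(mesh, explicit_spec))
```

Each mesh axis may appear at most once in a single PartitionSpec, so a real mapping should not send two axes of the same array to the same mesh axis.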