sharded_init
- penzai.toolshed.sharding_util.sharded_init(initializer: Callable[..., Any], *init_args, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = None, **init_kwargs) → Any
Initializes a model, with constants and variables sharded based on a mesh.
- Parameters:
initializer – The initializer to call. Should return a PyTree of arrays, Parameters, and StateVariables. All of the Parameters and StateVariables should be new variables created by the initializer.
*init_args – Positional arguments to pass to the initializer.
mesh – The Mesh to shard the tree to.
axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. If an axis name is not present, that axis will not be sharded. If a mesh axis name is a tuple, the corresponding axis will be sharded to multiple mesh axes. If this dictionary is not provided, it will be inferred as an “identity” mapping, where each axis is sharded to a mesh axis with the same name (if present).
**init_kwargs – Keyword arguments to pass to the initializer.
- Returns:
An initialized version of model, with all arrays, Parameters, and StateVariables sharded according to the mesh. Parameters and StateVariables will not share their states with any other variables defined outside the initializer.
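The rules for axis_name_to_mesh_name described above can be sketched in plain Python. This is an illustration of the documented behavior only, not penzai's implementation: the helper name resolve_axis_mapping and the example axis names ("batch", "features", "data", "model") are hypothetical, and the sketch assumes each array axis resolves independently to a mesh axis name, a tuple of mesh axis names, or None (unsharded).

```python
from typing import Mapping, Optional, Sequence, Tuple, Union

# A mesh assignment for one array axis: a single mesh axis or several.
MeshAxes = Union[str, Tuple[str, ...]]


def resolve_axis_mapping(
    array_axis_names: Sequence[str],
    mesh_axis_names: Sequence[str],
    axis_name_to_mesh_name: Optional[Mapping[str, MeshAxes]] = None,
) -> dict:
    """Hypothetical helper (not part of penzai): resolves each named array
    axis to the mesh axes it would be sharded over, or None if unsharded."""
    if axis_name_to_mesh_name is None:
        # No mapping given: fall back to the "identity" mapping, which only
        # covers array axes whose name also exists on the mesh.
        axis_name_to_mesh_name = {
            name: name for name in array_axis_names if name in mesh_axis_names
        }
    # Array axes absent from the mapping are left unsharded (None).
    return {name: axis_name_to_mesh_name.get(name) for name in array_axis_names}


# Assumed mesh with two named axes.
mesh_axes = ("data", "model")

# Identity mapping: only the array axis named "model" matches a mesh axis.
identity = resolve_axis_mapping(("batch", "model"), mesh_axes)

# Explicit mapping: "batch" is split over both mesh axes via a tuple,
# while "features" is not listed and therefore stays unsharded.
explicit = resolve_axis_mapping(
    ("batch", "features"), mesh_axes, {"batch": ("data", "model")}
)
```

The resolved values have the same shape as the per-dimension entries accepted by jax.sharding.PartitionSpec (None, a mesh axis name, or a tuple of names), which is likely how such a mapping is consumed, though the actual construction inside sharded_init may differ.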