sharded_init

penzai.toolshed.sharding_util.sharded_init(initializer: Callable[..., Any], *init_args, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = None, **init_kwargs) -> Any

Initializes a model, sharding its constants and variables across the devices of a mesh.

Parameters:
  • initializer – The initializer to call. Should return a PyTree of arrays, Parameters, and StateVariables. All of the Parameters and StateVariables should be new variables created by the initializer.

  • *init_args – Positional arguments to pass to the initializer.

  • mesh – The Mesh to shard the tree to.

  • axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. If an axis name is not present, that axis will not be sharded. If a mesh axis name is a tuple, the corresponding axis will be sharded to multiple mesh axes. If this dictionary is not provided, it will be inferred as an “identity” mapping, where each axis is sharded to a mesh axis with the same name (if present).

  • **init_kwargs – Keyword arguments to pass to the initializer.
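The mapping rules for axis_name_to_mesh_name can be illustrated with a small plain-JAX sketch. The helper named_axes_to_partition_spec below is hypothetical and is not part of penzai. It assumes every array dimension carries a name (positional axes are ignored) and shows how the three documented rules (unmapped axes stay unsharded, a tuple value spans several mesh axes, and the default "identity" mapping) translate into a jax.sharding.PartitionSpec.

```python
from jax.sharding import PartitionSpec


def named_axes_to_partition_spec(
    axis_names, mesh_axis_names, axis_name_to_mesh_name=None
):
  """Hypothetical helper (not penzai's implementation).

  Builds a PartitionSpec for an array whose dimensions are labeled, in order,
  by `axis_names`. `mesh_axis_names` plays the role of `mesh.axis_names`.
  """
  if axis_name_to_mesh_name is None:
    # "Identity" mapping: shard an axis only if a mesh axis has the same name.
    axis_name_to_mesh_name = {
        name: name for name in axis_names if name in mesh_axis_names
    }
  spec = []
  for name in axis_names:
    # Axes absent from the mapping map to None, i.e. they are not sharded.
    # A tuple value shards this one array axis over several mesh axes.
    spec.append(axis_name_to_mesh_name.get(name))
  return PartitionSpec(*spec)
```

For example, with mesh axes ("batch", "model"), an array with axes ("batch", "features") gets PartitionSpec("batch", None) under the identity mapping, while passing {"features": ("batch", "model")} gives PartitionSpec(None, ("batch", "model")).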

Returns:

An initialized version of the model, with all arrays, Parameters, and StateVariables sharded according to the mesh. Parameters and StateVariables will not share their states with any other variables defined outside the initializer.
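Sharded initialization means creating each array directly in its target layout, rather than materializing the whole model on one device and resharding it afterward. The sketch below shows one common way to do this in plain JAX: trace the initializer with jax.eval_shape, pick a NamedSharding for each output leaf, and run the initializer under jax.jit with out_shardings. It does not use penzai and is not a description of how sharded_init is implemented; init_params, ROWS, and the per-leaf sharding rule are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over every available device, with a single mesh axis "data".
mesh = Mesh(np.array(jax.devices()), ("data",))
# Keep the sharded dimension divisible by the number of devices.
ROWS = 2 * len(jax.devices())


def init_params(key):
  """Hypothetical initializer returning a PyTree (dict) of arrays."""
  return {
      "w": jax.random.normal(key, (ROWS, 4)),
      "b": jnp.zeros((4,)),
  }


def pick_sharding(leaf):
  # Illustrative rule (an assumption): shard the first axis of matrices
  # over "data" and replicate everything else.
  if len(leaf.shape) >= 2:
    return NamedSharding(mesh, PartitionSpec("data"))
  return NamedSharding(mesh, PartitionSpec())


key = jax.random.key(0)
# Abstract evaluation: yields ShapeDtypeStructs and allocates no arrays.
abstract = jax.eval_shape(init_params, key)
shardings = jax.tree_util.tree_map(pick_sharding, abstract)

# Under jit, each output is produced directly with its requested sharding.
params = jax.jit(init_params, out_shardings=shardings)(key)
```

On a multi-device mesh, no single device ever holds the full w matrix, which is what makes this pattern useful for models too large for one accelerator.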