initialize_parameters_sharded

penzai.toolshed.sharding_util.initialize_parameters_sharded(model: Any, prng_key: jax.Array, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = None) → Any

Initializes the parameters of a model, sharded according to a mesh.

Parameters:
  • model – A model whose parameters we should initialize. Should usually contain pz.nn.UninitializedParameter instances.

  • prng_key – Key to use to initialize parameters.

  • mesh – The Mesh to shard the tree to.

  • axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. If an array axis name is not present in the mapping, that axis will not be sharded. If a mesh axis name is a tuple, the corresponding array axis will be sharded across multiple mesh axes. If this dictionary is not provided, it will be inferred as an “identity” mapping, where each array axis is sharded to the mesh axis with the same name (if present).

Returns:

An initialized version of model, with all uninitialized parameters replaced with initialized parameters, sharded according to the arguments.
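
The mapping rules described for axis_name_to_mesh_name can be illustrated with plain JAX sharding objects. The sketch below is not penzai's implementation: partition_spec_for is a hypothetical helper, the array axis names ("batch", "embed", "model") and the mesh axis names ("data", "model") are made up for illustration, and the 1x1 mesh is built from a single device only so that the snippet runs anywhere.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def partition_spec_for(axis_names, mesh_axis_names, axis_name_to_mesh_name=None):
    """Hypothetical helper (not part of penzai): maps array axis names to a PartitionSpec."""
    if axis_name_to_mesh_name is None:
        # Assumed reading of the "identity" default: an array axis is sharded
        # only if a mesh axis with the same name exists.
        axis_name_to_mesh_name = {name: name for name in mesh_axis_names}
    # Array axes missing from the mapping resolve to None, i.e. not sharded.
    return PartitionSpec(*(axis_name_to_mesh_name.get(name) for name in axis_names))


# A 1x1 mesh over the first available device, with illustrative axis names.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))

# Identity mapping: "model" matches a mesh axis, "batch" does not.
identity_spec = partition_spec_for(("batch", "model"), mesh.axis_names)

# Explicit mapping: "batch" is sharded over "data", "embed" is left unsharded.
explicit_spec = partition_spec_for(
    ("batch", "embed"), mesh.axis_names, {"batch": "data"}
)

# Tuple value: "batch" is sharded across both mesh axes.
multi_spec = partition_spec_for(
    ("batch", "embed"), mesh.axis_names, {"batch": ("data", "model")}
)

# A PartitionSpec becomes a concrete sharding once paired with the mesh.
sharding = NamedSharding(mesh, explicit_spec)
```

A dictionary such as {"batch": "data"} is the kind of value that would be passed as axis_name_to_mesh_name, alongside the same mesh, when calling initialize_parameters_sharded.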