initialize_parameters_sharded
- penzai.deprecated.v1.toolshed.sharding_util.initialize_parameters_sharded(model: Any, prng_key: jax.Array, mesh: jax.sharding.Mesh, axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = None) → Any
Initializes the parameters of a model, sharded according to a mesh.
- Parameters:
model – A model whose parameters we should initialize. Should usually contain pz.nn.UninitializedParameter instances.
prng_key – Key to use to initialize parameters.
mesh – The Mesh to shard the tree to.
axis_name_to_mesh_name – A mapping from array axis names to mesh axis names. If an axis name is not present, that axis will not be sharded. If a mesh axis name is a tuple, the corresponding axis will be sharded across multiple mesh axes. If this dictionary is not provided, it will be inferred as an “identity” mapping, where each axis is sharded to a mesh axis with the same name (if present).
- Returns:
An initialized version of model, with all uninitialized parameters replaced with initialized parameters, sharded according to the arguments.
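
The rules for axis_name_to_mesh_name can be illustrated with a small pure-Python sketch. The helper resolve_axis_sharding below is hypothetical and is not part of penzai. It only mirrors the mapping behavior described in the parameter documentation (missing names are unsharded, tuples pass through, and an omitted dictionary falls back to the identity mapping), and it does not create or shard any arrays. The axis names "batch", "embedding", "data", and "model" are made up for the example.

```python
from __future__ import annotations


def resolve_axis_sharding(
    array_axis_names: tuple[str, ...],
    mesh_axis_names: tuple[str, ...],
    axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = None,
) -> tuple[str | tuple[str, ...] | None, ...]:
    """Hypothetical helper: pick the mesh axis (or axes) for each array axis.

    Returns one entry per array axis: a mesh axis name, a tuple of mesh axis
    names, or None when the array axis is left unsharded.
    """
    if axis_name_to_mesh_name is None:
        # Identity mapping: an array axis is sharded only if the mesh has
        # an axis with the same name.
        axis_name_to_mesh_name = {
            name: name for name in array_axis_names if name in mesh_axis_names
        }
    # Array axes missing from the mapping are not sharded (None).
    return tuple(axis_name_to_mesh_name.get(name) for name in array_axis_names)


mesh_axes = ("data", "model")

# No mapping given: only the axis named like a mesh axis is sharded.
identity = resolve_axis_sharding(("batch", "embedding", "model"), mesh_axes)

# Explicit mapping: "batch" goes to the "data" mesh axis, "embedding" is unsharded.
explicit = resolve_axis_sharding(
    ("batch", "embedding"), mesh_axes, {"batch": "data"}
)

# Tuple value: "batch" is sharded across both mesh axes.
multi = resolve_axis_sharding(
    ("batch", "embedding"), mesh_axes, {"batch": ("data", "model")}
)
```

In this sketch the result has the same shape as the per-axis entries of a partition specification, which is one plausible way to read the documented behavior; the actual function may organize this differently internally.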