sharding_util

Utilities for working with sharded arrays and parameters in Penzai.

Classes

ConstrainSharding

A layer that constrains the sharding of a tree of arrays.

ConstrainShardingByName

A layer that constrains the sharding of a tree of NamedArrays by name.
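In plain JAX, a sharding constraint on an intermediate value is usually expressed with jax.lax.with_sharding_constraint inside a traced function. The following is a minimal sketch of the two ideas behind these layers, not Penzai's implementation. It applies one sharding to every leaf of a tree, and it builds a PartitionSpec from per-axis names. The helpers constrain_tree and constrain_by_name, and the convention of passing axis names as a tuple next to a positional array, are hypothetical. Penzai's layers work on pz.nx.NamedArray objects directly.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def constrain_tree(tree, sharding):
  # Hypothetical helper: apply one sharding constraint to every array leaf.
  return jax.tree_util.tree_map(
      lambda a: jax.lax.with_sharding_constraint(a, sharding), tree)


def constrain_by_name(x, axis_names, mesh, axis_name_to_mesh_name):
  # Hypothetical helper: translate per-axis names into a PartitionSpec.
  # Axes whose names are missing from the mapping stay unsharded (None).
  spec = P(*(axis_name_to_mesh_name.get(name) for name in axis_names))
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))


# A 1-D mesh over all available devices (a single CPU device also works).
mesh = Mesh(np.array(jax.devices()), ("data",))
n_dev = mesh.devices.size


@jax.jit
def forward(tree):
  # Constrain every leaf to be split along its first axis over "data".
  tree = constrain_tree(tree, NamedSharding(mesh, P("data")))
  # Treat "x" as having axes ("batch", "features"); shard only "batch".
  x = constrain_by_name(
      tree["x"] * 2.0, ("batch", "features"), mesh, {"batch": "data"})
  return {"x": x, "y": tree["y"]}


out = forward({"x": jnp.ones((2 * n_dev, 3)), "y": jnp.arange(2 * n_dev)})
```

Array dimensions are chosen as multiples of the device count so the sharded axis divides evenly on any mesh size.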

Functions

initialize_parameters_sharded(model, ...[, ...])

Initializes the parameters of a model, sharded according to a mesh.

name_to_name_device_put(tree, mesh[, ...])

Shards a tree of pz.nx.NamedArray objects based on their axis names and transfers them to the devices of the mesh.

name_to_name_sharding(tree, mesh[, ...])

Computes shardings for a tree of pz.nx.NamedArray objects based on their axis names, without transferring any data.
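The following is a rough plain-JAX analogue of the three functions listed above. It derives a sharding for each array from its axis names. It then uses those shardings as jit out_shardings, so parameters are created already sharded instead of first landing on a single device. Finally, it places an existing array with jax.device_put. This is a sketch, not Penzai's implementation. The names AXIS_TO_MESH, PARAM_AXES, names_to_sharding and init_params are hypothetical, and axis names are kept in a separate dict here, whereas Penzai's functions read them from pz.nx.NamedArray trees.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
n_dev = mesh.devices.size

# Hypothetical mapping from array axis names to mesh axis names.
AXIS_TO_MESH = {"vocab": "data"}

# Hypothetical axis names for each parameter.
PARAM_AXES = {"embed": ("vocab", "features"), "bias": ("features",)}


def names_to_sharding(axis_names):
  # Build a NamedSharding from axis names; unmapped axes are replicated.
  spec = P(*(AXIS_TO_MESH.get(name) for name in axis_names))
  return NamedSharding(mesh, spec)


# Computing shardings by name: no arrays are created or moved here.
shardings = {k: names_to_sharding(v) for k, v in PARAM_AXES.items()}


def init_params(key):
  # Hypothetical parameters; the "vocab" axis is a multiple of the device count.
  return {
      "embed": jax.random.normal(key, (4 * n_dev, 8)),
      "bias": jnp.zeros((8,)),
  }


# Sharded initialization: outputs are materialized with the requested shardings.
params = jax.jit(init_params, out_shardings=shardings)(jax.random.PRNGKey(0))

# Placing an existing host array onto the mesh according to its sharding.
host_embed = np.ones((4 * n_dev, 8), dtype=np.float32)
placed = jax.device_put(host_embed, shardings["embed"])
```

Initializing through jit with out_shardings avoids building the full parameter tree on one device before splitting it, which matters once a model no longer fits in a single device's memory.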