random_split

penzai.core.named_axes.random_split(key: jax.Array | NamedArrayBase, named_shape: Mapping[AxisName, int] | Sequence[tuple[AxisName, int]]) -> NamedArray | NamedArrayView

Splits a PRNG key into a NamedArray of PRNG keys with the given names.

Parameters:
  • key – PRNG key to split. Can also be a NamedArray of keys whose axis names are disjoint from the names in named_shape.

  • named_shape – Named shape for the result. If it is an unordered mapping, its keys (the axis names) are sorted before splitting. To avoid this (e.g. for names that cannot be sorted), pass a collections.OrderedDict or a sequence of (name, size) tuples.
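The ordering rule for named_shape can be illustrated with a small standard-library sketch. `normalize_named_shape` is a hypothetical helper written for this page: it only mirrors the documented behavior (sort a plain mapping by axis name, keep the given order for an OrderedDict or a sequence of pairs) and is not penzai's implementation.

```python
from __future__ import annotations

import collections
from collections.abc import Hashable, Mapping, Sequence


def normalize_named_shape(
    named_shape: Mapping[Hashable, int] | Sequence[tuple[Hashable, int]],
) -> list[tuple[Hashable, int]]:
    """Hypothetical helper: returns (name, size) pairs in splitting order."""
    # OrderedDict is checked first because it is also a Mapping:
    # its insertion order is treated as meaningful and kept.
    if isinstance(named_shape, collections.OrderedDict):
        return list(named_shape.items())
    # Any other mapping is treated as unordered and sorted by axis name.
    # This raises TypeError if the names cannot be compared to each other.
    if isinstance(named_shape, Mapping):
        return sorted(named_shape.items(), key=lambda item: item[0])
    # A sequence of (name, size) pairs keeps the order it was given in.
    return list(named_shape)


print(normalize_named_shape({"heads": 2, "batch": 4}))
print(normalize_named_shape(collections.OrderedDict([("heads", 2), ("batch", 4)])))
print(normalize_named_shape([("heads", 2), (0, 3)]))

try:
    # Mixing str and int names makes a plain dict unsortable.
    normalize_named_shape({"heads": 2, 0: 3})
except TypeError:
    print("unsortable names: pass an OrderedDict or a sequence of pairs instead")
```

Because a plain dict is sorted before splitting, its insertion order does not matter: `{"heads": 2, "batch": 4}` and `{"batch": 4, "heads": 2}` describe the same split.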

Returns:
  NamedArray or NamedArrayView with the given named shape, filled with unique PRNG keys.