random_split
- penzai.core.named_axes.random_split(key: jax.Array | NamedArrayBase, named_shape: Mapping[AxisName, int] | Sequence[tuple[AxisName, int]]) → NamedArray | NamedArrayView
Splits a PRNG key into a NamedArray of PRNG keys with the given names.
- Parameters:
key – PRNG key to split. Can also be a NamedArray of keys whose axis names are disjoint from those in named_shape.
named_shape – Named shape for the result. If an unordered mapping, its keys (the axis names) will be sorted before splitting. To avoid this (e.g. for unsortable keys), pass a collections.OrderedDict or a sequence of (name, size) tuples.
- Returns:
NamedArray or NamedArrayView with the given named shape, filled with unique PRNG keys.