derive_param_key

penzai.nn.parameters.derive_param_key(base_rng: jax.Array, name: str) -> jax.Array

Derives a PRNG key for a parameter from a base key and a name.

Parameters:
  • base_rng – The base PRNG key.

  • name – The name of the parameter.

Returns:

A unique PRNG key for the parameter with this name.
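This page does not specify how the key is derived. The sketch below is one plausible approach, not penzai's implementation. It assumes the name is hashed to a 32-bit integer with SHA-1 and folded into the base key with `jax.random.fold_in`. `derive_param_key_sketch` is a hypothetical stand-in for illustration.

```python
import hashlib

import jax
import jax.numpy as jnp
import numpy as np


def derive_param_key_sketch(base_rng: jax.Array, name: str) -> jax.Array:
  """Hypothetical stand-in for derive_param_key (not penzai's source).

  Assumption: the parameter name is hashed and folded into the base key.
  """
  # SHA-1 gives the same digest on every run, unlike Python's built-in
  # hash() for strings, which is salted per process.
  digest = hashlib.sha1(name.encode("utf-8")).digest()
  # Keep 32 bits, because jax.random.fold_in folds in a 32-bit integer.
  name_hash = np.uint32(int.from_bytes(digest[:4], "little"))
  return jax.random.fold_in(base_rng, name_hash)


# Usage: derive per-parameter keys from one base key, then initialize.
base_key = jax.random.key(0)
weights_key = derive_param_key_sketch(base_key, "linear.weights")
bias_key = derive_param_key_sketch(base_key, "linear.bias")
weights = jax.random.normal(weights_key, (3, 4))
bias = jax.random.normal(bias_key, (4,))
```

Under this scheme the key depends only on `base_rng` and `name`. A parameter therefore gets the same initial value no matter how many other parameters are created or in what order, which is not the case when keys are handed out by splitting a single key sequentially.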