RandomStream

class penzai.core.random_stream.RandomStream

Bases: object

Helper object to construct a stream of random numbers.

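Because next_key() mutates the stream's Python-side state in place, and JAX transformations do not track such mutations, this object is unsafe to pass across a JAX transformation boundary (for example, into a function transformed by jax.jit) and should only be used locally. To ensure that a RandomStream does not escape the scope in which it was defined, it must be used as a context manager, e.g.: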

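import jax
from penzai.core.random_stream import RandomStream

key = jax.random.key(0)
with RandomStream(key) as stream:
  x = jax.random.normal(stream.next_key(), (3,))  # each call yields a fresh key

# stream can no longer be used here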

Variables:
  • base_key (jax.Array) – Base key used to generate the stream.

  • next_offset (int | jax.Array) – Offset to use when generating the next key.

  • state (Literal['pending', 'active', 'expired']) – Lifecycle state of this random stream: 'pending' until it is activated, 'active' while keys may be drawn from it, and 'expired' once it should no longer be used. Random streams expire once their context manager is exited.
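
To make the three fields concrete, here is a minimal, hypothetical stand-in (SketchRandomStream is not part of penzai and is not its source code). It assumes that each key is derived by folding the current offset into the base key with jax.random.fold_in and that the offset then advances by one; penzai's actual derivation may differ.

```python
import dataclasses
from typing import Literal

import jax


@dataclasses.dataclass
class SketchRandomStream:
  """Hypothetical stand-in mirroring the documented fields; not penzai's code."""

  base_key: jax.Array
  next_offset: int = 0  # simplified: the documented type is int | jax.Array
  state: Literal["pending", "active", "expired"] = "pending"

  def __enter__(self) -> "SketchRandomStream":
    # Only a fresh (pending) stream may be activated.
    if self.state != "pending":
      raise RuntimeError(f"cannot activate a stream in state {self.state!r}")
    self.state = "active"
    return self

  def __exit__(self, exc_type, exc_value, traceback) -> None:
    # Leaving the `with` block expires the stream permanently.
    self.state = "expired"

  def next_key(self) -> jax.Array:
    if self.state != "active":
      raise RuntimeError(f"cannot draw keys from a {self.state!r} stream")
    # Assumption: fold the current offset into the base key, then advance
    # the offset so that the next call produces a different key.
    key = jax.random.fold_in(self.base_key, self.next_offset)
    self.next_offset += 1
    return key
```

In this sketch every key is a pure function of (base_key, offset), so two streams built from the same base key reproduce the same sequence of keys.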

Methods

__init__(base_key[, next_offset])
next_key() – Gets the next key from this stream, mutating the stream in place.
unsafe_mark_active() – Activates the random stream, returning itself for convenience.
__enter__() – Activates the random stream in a context.
__exit__(exc_type, exc_value, traceback) – Deactivates the random stream.

Attributes

next_offset

state

base_key

__enter__() → RandomStream

Activates the random stream on entry to a with block, returning the stream.

__exit__(exc_type, exc_value, traceback)

Deactivates the random stream, marking it as expired so that it can no longer be used.

next_key() → jax.Array

Gets the next key from this stream, mutating the stream in place by advancing next_offset.

unsafe_mark_active() → RandomStream

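Activates the random stream without a context manager, returning itself for convenience. As the unsafe_ prefix indicates, this bypasses the scoping protection that the with statement provides, so the caller is responsible for making sure the stream does not escape its local scope.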