RandomStream

class penzai.core.random_stream.RandomStream

Bases: Struct

A stateful random stream object.

This object can be used to generate a sequence of random keys inside a Penzai model. It tracks its state in a StateVariable (the offset attribute described below).

Variables:
  • base_key (jax.Array) – The base key to use for this random stream. This does not change, and determines which sequence of random numbers will be generated.

  • offset (variables.StateVariable[int | jax.Array]) – The number of random keys that have been generated from this stream so far.
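The roles of the two fields can be illustrated with plain JAX. This sketch assumes, for illustration only, that the key at a given offset is derived by folding the offset into base_key with jax.random.fold_in; the helper key_at is hypothetical and is not part of Penzai.

```python
import jax
import jax.numpy as jnp


def key_at(base_key: jax.Array, offset: int) -> jax.Array:
    """Hypothetical helper: derives the key at position `offset`.

    Assumption (illustrative only): the i-th key of a stream is obtained by
    folding the counter value i into the fixed base key.
    """
    return jax.random.fold_in(base_key, offset)


base_key = jax.random.key(0)

# The same (base_key, offset) pair always yields the same key...
k0_a = key_at(base_key, 0)
k0_b = key_at(base_key, 0)
# ...while a different offset yields a different key.
k1 = key_at(base_key, 1)

# Typed key arrays are compared through their underlying integer data.
same = bool(jnp.array_equal(jax.random.key_data(k0_a), jax.random.key_data(k0_b)))
different = not bool(jnp.array_equal(jax.random.key_data(k0_a), jax.random.key_data(k1)))
print(same, different)  # True True
```

Under this assumption, base_key fixes which sequence is produced and offset selects the position within it.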

Methods

__init__(base_key, offset)

from_base_key(base_key[, offset_label])

Returns a new random stream with the given base key.

next_key()

Returns the next key in the sequence, and advances the stream.

Attributes

base_key

offset

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.
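Because RandomStream is a Struct, it is a JAX PyTree; the tree_flatten and tree_unflatten methods listed above are what make that work. The sketch below shows the same flatten/unflatten round trip on a hypothetical two-field class registered with jax.tree_util.register_pytree_node_class. TwoFieldStruct is an illustrative stand-in, not Penzai's Struct machinery, and it stores offset as a plain array rather than a StateVariable.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class TwoFieldStruct:
    """Hypothetical stand-in for a two-field Struct (not Penzai code)."""

    base_key: jax.Array
    offset: jax.Array

    def tree_flatten(self):
        # Children are the field values in field order; no static aux data.
        return (self.base_key, self.offset), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        del aux_data  # Unused in this sketch.
        return cls(*children)


s = TwoFieldStruct(base_key=jax.random.key(0), offset=jnp.array(3))
leaves, treedef = jax.tree_util.tree_flatten(s)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(len(leaves), int(rebuilt.offset))  # 2 3
```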

classmethod from_base_key(base_key: jax.Array, offset_label: variables.VariableLabel = 'random_stream_offset') → RandomStream

Returns a new random stream with the given base key.

next_key() → jax.Array

Returns the next key in the sequence, and advances the stream.
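To make the from_base_key / next_key contract concrete, here is a minimal sketch of a stream with the same shape of interface. MiniRandomStream and Counter are hypothetical names: Counter is a plain mutable stand-in for a labeled StateVariable, and deriving each key with jax.random.fold_in(base_key, offset) is an assumption about the scheme, not a description of Penzai's code.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Counter:
    """Hypothetical mutable cell standing in for a labeled StateVariable."""

    value: int
    label: str


@dataclasses.dataclass
class MiniRandomStream:
    """Hypothetical analogue of RandomStream; not Penzai's implementation."""

    base_key: jax.Array
    offset: Counter

    @classmethod
    def from_base_key(cls, base_key, offset_label="random_stream_offset"):
        # Start counting at zero; the label mirrors the offset_label parameter.
        return cls(base_key=base_key, offset=Counter(value=0, label=offset_label))

    def next_key(self) -> jax.Array:
        # Assumption: derive the key from the current counter, then advance it.
        key = jax.random.fold_in(self.base_key, self.offset.value)
        self.offset.value += 1
        return key


stream = MiniRandomStream.from_base_key(jax.random.key(0))
noise_a = jax.random.normal(stream.next_key(), (3,))
noise_b = jax.random.normal(stream.next_key(), (3,))
print(stream.offset.value)  # 2

# A fresh stream with the same base key replays the same sequence.
replay = MiniRandomStream.from_base_key(jax.random.key(0))
print(bool(jnp.allclose(noise_a, jax.random.normal(replay.next_key(), (3,)))))  # True
```

In this sketch the keys depend only on base_key and the offset, so two streams built from the same base key produce identical sequences, while successive calls on one stream yield distinct keys.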