RandomStream
- class penzai.core.random_stream.RandomStream
Bases: Struct
A stateful random stream object.
This object can be used to generate a sequence of random numbers inside a Penzai model. It uses a Variable to track its state.
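The idea of a stateful stream can be illustrated with a minimal sketch: a fixed base key plus a counter that advances on every draw. The class below, `ToyRandomStream`, is hypothetical and is not penzai's implementation. It also assumes that the n-th key is derived with `jax.random.fold_in(base_key, n)`. It uses a plain Python attribute where penzai uses a Variable.

```python
import jax


class ToyRandomStream:
  """Hypothetical stand-in for a stateful random stream (not penzai's code)."""

  def __init__(self, base_key: jax.Array, offset: int = 0):
    self.base_key = base_key  # fixed; selects which sequence is produced
    self.offset = offset      # mutable counter: number of keys handed out so far

  def next_key(self) -> jax.Array:
    # Assumption: the n-th key is derived by folding n into the base key.
    key = jax.random.fold_in(self.base_key, self.offset)
    self.offset += 1  # advance the stream so the next call yields a new key
    return key


stream = ToyRandomStream(jax.random.PRNGKey(0))
noise = jax.random.normal(stream.next_key(), (3,))
mask = jax.random.bernoulli(stream.next_key(), 0.5, (3,))
print(stream.offset)  # prints 2
```

Each call consumes one fresh key, so callers never reuse a key by accident.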
- Variables:
base_key (jax.Array) – The base key to use for this random stream. This does not change, and determines which sequence of random numbers will be generated.
offset (variables.StateVariable[int | jax.Array]) – The number of random numbers that have been generated so far.
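Because `base_key` never changes, the pair (`base_key`, `offset`) fully determines the next key. The `int | jax.Array` annotation on `offset` is consistent with the counter being a Python int when run eagerly and a traced array inside `jax.jit`. The function below is a hypothetical, purely functional sketch that threads the offset explicitly. It assumes the same `fold_in` derivation rule as before and does not reproduce penzai's actual code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def next_key_functional(base_key: jax.Array, offset: jax.Array):
  """Pure version of next_key: returns the key and the advanced offset.

  Under jit, `offset` is a traced jax.Array rather than a Python int.
  """
  key = jax.random.fold_in(base_key, offset)  # assumed derivation rule
  return key, offset + 1


base_key = jax.random.PRNGKey(42)
offset = jnp.asarray(0, dtype=jnp.uint32)
keys = []
for _ in range(3):
  key, offset = next_key_functional(base_key, offset)
  keys.append(key)

# Restarting from a saved offset of 2 reproduces the third key exactly.
replayed, _ = next_key_functional(base_key, jnp.asarray(2, dtype=jnp.uint32))
```

Under this scheme, saving only the offset is enough to resume the stream later.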
Methods
- __init__(base_key, offset)
- from_base_key(base_key[, offset_label]) – Returns a new random stream with the given base key.
- next_key() – Returns the next key in the sequence, and advances the stream.
Attributes
- base_key
- offset

Inherited Methods
- attributes_dict() – Constructs a dictionary with all of the fields in the class.
- from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
- key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
- select() – Wraps this struct in a selection, enabling functional-style mutations.
- tree_flatten() – Flattens this tree node.
- tree_flatten_with_keys() – Flattens this tree node with keys.
- tree_unflatten(aux_data, children) – Unflattens this tree node.
- treescope_color() – Computes a CSS color to display for this object in treescope.
- classmethod from_base_key(base_key: jax.Array, offset_label: variables.VariableLabel = 'random_stream_offset') → RandomStream
Returns a new random stream with the given base key.
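The consequence of building a stream from a base key can be sketched with a hypothetical toy class that mirrors only the documented method names. It assumes a new stream starts at offset 0 and derives keys with `jax.random.fold_in`, and it omits `offset_label` entirely. Under those assumptions, two streams built from the same base key replay the same sequence, while streams built from split keys are independent.

```python
import jax


class ToyRandomStream:
  """Hypothetical stand-in mirroring the documented from_base_key/next_key names."""

  def __init__(self, base_key, offset=0):
    self.base_key = base_key
    self.offset = offset

  @classmethod
  def from_base_key(cls, base_key):
    # Assumption: a freshly constructed stream starts counting at zero.
    return cls(base_key, offset=0)

  def next_key(self):
    key = jax.random.fold_in(self.base_key, self.offset)  # assumed derivation rule
    self.offset += 1
    return key


root = jax.random.PRNGKey(0)
# Two streams from the same base key produce identical sequences.
a = ToyRandomStream.from_base_key(root)
b = ToyRandomStream.from_base_key(root)
# Streams from split keys produce independent sequences.
k1, k2 = jax.random.split(root)
c = ToyRandomStream.from_base_key(k1)
d = ToyRandomStream.from_base_key(k2)
```

To give separate model components non-overlapping randomness, split the root key first rather than reusing one base key.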