RandomStream
- class penzai.core.random_stream.RandomStream
Bases: Struct
A stateful random stream object.
This object can be used to generate a sequence of random keys inside a Penzai model. It tracks its state with a StateVariable, stored in its offset field.
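A minimal sketch of the stateful-stream idea, written in plain JAX rather than with Penzai's own classes. `ToyRandomStream` and `dropout` are hypothetical names, a plain Python attribute stands in for the StateVariable, and deriving each key as `jax.random.fold_in(base_key, offset)` is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp


class ToyRandomStream:
    """Hypothetical stand-in for a stateful random stream (not Penzai's class)."""

    def __init__(self, base_key: jax.Array, offset: int = 0):
        self.base_key = base_key  # fixed; selects which sequence is produced
        self.offset = offset      # mutable counter: number of keys drawn so far

    def next_key(self) -> jax.Array:
        # Assumption: key n is derived by folding the counter into the base key.
        key = jax.random.fold_in(self.base_key, self.offset)
        self.offset += 1  # advance the stream
        return key


def dropout(x: jax.Array, rate: float, stream: ToyRandomStream) -> jax.Array:
    # Each call consumes a fresh key, so callers never thread keys by hand.
    keep = jax.random.bernoulli(stream.next_key(), 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


stream = ToyRandomStream(jax.random.PRNGKey(0))
x = jnp.ones((4, 8))
h1 = dropout(x, 0.5, stream)
h2 = dropout(x, 0.5, stream)  # a different mask: the stream has advanced
print(stream.offset)  # 2
```

Because the stream updates its own state, code that needs randomness only holds a reference to the stream instead of receiving and returning keys explicitly.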
- Variables:
base_key (jax.Array) – The base key to use for this random stream. This does not change, and determines which sequence of random numbers will be generated.
offset (variables.StateVariable[int | jax.Array]) – The number of random keys that have been generated so far.
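The roles of the two fields can be illustrated with a small sketch. It assumes that key number i of a stream is `jax.random.fold_in(base_key, i)`, which is an illustrative assumption rather than a statement about Penzai's internals; `keys_from` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def keys_from(base_key: jax.Array, offset: int, n: int) -> list:
    """Hypothetical helper: the next n keys of a stream starting at `offset`.

    Assumes key number i is jax.random.fold_in(base_key, i), so the
    sequence depends only on base_key and the position within it.
    """
    return [jax.random.fold_in(base_key, offset + i) for i in range(n)]


base = jax.random.PRNGKey(42)

# Same base key, same sequence: draws are reproducible.
run_a = keys_from(base, 0, 5)
run_b = keys_from(base, 0, 5)
assert all(bool(jnp.array_equal(a, b)) for a, b in zip(run_a, run_b))

# Saving the offset after 3 draws and resuming continues the same sequence.
resumed = keys_from(base, 3, 2)
assert all(bool(jnp.array_equal(a, b)) for a, b in zip(run_a[3:], resumed))

# A different base key gives an unrelated sequence.
other = keys_from(jax.random.PRNGKey(7), 0, 5)
assert not bool(jnp.array_equal(run_a[0], other[0]))
```

Under this scheme the base key alone fixes the sequence and the offset is just a position in it, which matches the field descriptions above.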
Methods
__init__(base_key, offset)
from_base_key(base_key[, offset_label]) – Returns a new random stream with the given base key.
next_key() – Returns the next key in the sequence, and advances the stream.
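The offset is typed `int | jax.Array`, so it may hold either a Python integer or an array. One plausible reason, sketched below under the same `fold_in` assumption, is that inside `jax.jit` or `jax.lax.scan` a counter becomes a traced array. `next_key` and `draw_normals` here are hypothetical pure functions that return the advanced offset instead of mutating a variable.

```python
import jax
import jax.numpy as jnp

N_STEPS = 3  # number of keys drawn per call (illustrative constant)


def next_key(base_key: jax.Array, offset: jax.Array):
    """Functional analogue of next_key: returns (key, advanced offset)."""
    return jax.random.fold_in(base_key, offset), offset + 1


@jax.jit
def draw_normals(base_key: jax.Array, offset: jax.Array):
    def step(off, _):
        # Inside scan, `off` is a traced int32 array, not a Python int.
        key, off = next_key(base_key, off)
        return off, jax.random.normal(key, ())

    final_offset, samples = jax.lax.scan(step, offset, xs=None, length=N_STEPS)
    return samples, final_offset


samples, offset = draw_normals(jax.random.PRNGKey(0), jnp.asarray(0, dtype=jnp.int32))
print(samples.shape, int(offset))  # (3,) 3
```

Carrying the offset explicitly keeps the traced function pure; a stateful stream hides this bookkeeping behind a single `next_key()` call.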
Attributes
base_key
offset
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
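The inherited `tree_flatten` / `tree_unflatten` pair follows JAX's PyTree protocol. The sketch below registers a hypothetical two-field class by hand with `jax.tree_util.register_pytree_node_class` to show what those two methods do; it is not how `Struct` itself implements them.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Pair:
    """Hypothetical two-field node; a hand-written analogue of a Struct's PyTree methods."""

    def __init__(self, base_key, offset):
        self.base_key = base_key
        self.offset = offset

    def tree_flatten(self):
        # children: the leaves JAX transforms; aux_data: static metadata (none here)
        return (self.base_key, self.offset), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the node from its (possibly transformed) children.
        return cls(*children)


p = Pair(jnp.arange(2), jnp.asarray(5))
leaves, treedef = jax.tree_util.tree_flatten(p)
doubled = jax.tree_util.tree_map(lambda x: x * 2, p)  # maps over both fields
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Because both fields are leaves, utilities such as `tree_map` and transformations such as `jax.jit` can see through the object to its arrays.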
- classmethod from_base_key(base_key: jax.Array, offset_label: variables.VariableLabel = 'random_stream_offset') → RandomStream
Returns a new random stream with the given base key.