WithFrozenRandomState

class penzai.data_effects.random.WithFrozenRandomState

Bases: EffectHandler

RandomEffect handler that uses a fixed random state.

WithFrozenRandomState can be used to freeze a model's random state at a given point, making the model's behavior deterministic and reproducible: every call starts from the same random_key and starting_offset. It is most useful for debugging the behavior of a stochastic model while holding the random key constant.
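The determinism described above can be illustrated without penzai. The following is a minimal pure-JAX sketch, not penzai's implementation: the class `FrozenKeyStream` and the function `noisy_layer` are hypothetical, and the sketch assumes that the n-th key is derived as `jax.random.fold_in(random_key, starting_offset + n)`. Because each call rebuilds its stream from the same frozen key, repeated calls draw identical numbers, while successive draws within one call still use different keys.

```python
import jax
import jax.numpy as jnp


class FrozenKeyStream:
    """Hypothetical stand-in for the per-call random state of a frozen handler.

    Assumption (not penzai's confirmed derivation): the n-th key is
    jax.random.fold_in(random_key, starting_offset + n).
    """

    def __init__(self, random_key, starting_offset=0):
        self.random_key = random_key
        self.offset = starting_offset

    def next_key(self):
        # Derive a key from the fixed base key and the current offset, then advance.
        key = jax.random.fold_in(self.random_key, self.offset)
        self.offset += 1
        return key


def noisy_layer(x, random_key):
    # Each call starts a fresh stream from the same frozen key.
    stream = FrozenKeyStream(random_key)
    first = x + jax.random.normal(stream.next_key(), x.shape)   # offset 0
    second = x + jax.random.normal(stream.next_key(), x.shape)  # offset 1
    return first, second


frozen_key = jax.random.PRNGKey(0)
x = jnp.zeros((3,))
a1, b1 = noisy_layer(x, frozen_key)
a2, b2 = noisy_layer(x, frozen_key)
```

Both calls return the same pair of arrays, which is the property that makes a frozen handler convenient for debugging.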

Variables:
  • handler_id (effect_base.HandlerId) – The ID of this handler.

  • body (layer_base.LayerLike) – The layer that this handler wraps.

  • random_key (jax.Array) – The constant random key to use.

  • starting_offset (int | jax.Array) – The starting offset at which to generate random numbers using the random key. This can be used to advance the random stream as if there were previous calls to next_key.
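The effect of starting_offset can be pictured the same way. This sketch again assumes, purely for illustration, that keys are produced by `jax.random.fold_in(random_key, offset)` with the offset incremented after each `next_key` call; `key_stream` is a hypothetical helper, not part of penzai. Under that assumption, starting at offset 2 yields the same key as the third `next_key` call of a stream that starts at offset 0.

```python
import jax


def key_stream(random_key, starting_offset=0):
    """Hypothetical key generator; assumes keys are fold_in(random_key, offset)."""
    offset = starting_offset
    while True:
        yield jax.random.fold_in(random_key, offset)
        offset += 1


base_key = jax.random.PRNGKey(42)

# A stream starting at offset 0, after two earlier next_key calls.
fresh = key_stream(base_key)
next(fresh)
next(fresh)
third_key = next(fresh)

# A stream advanced up front with starting_offset=2.
advanced = key_stream(base_key, starting_offset=2)
first_advanced_key = next(advanced)
```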

Methods

__init__(handler_id, body, random_key, ...)

effect_protocol()

handling(body, random_key[, ...])

Builds a WithFrozenRandomState that handles effects in this layer.

__call__(argument)

Attributes

handler_id

body

random_key

starting_offset

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

classmethod handling(body: layer_base.LayerLike, random_key: jax.Array, starting_offset: int = 0, hole_predicate: Callable[[RandomRequest | TaggedRandomRequest], bool] = _is_untagged_hole, handler_id: str | None = None) → WithFrozenRandomState

Builds a WithFrozenRandomState that handles effects in this layer.

Parameters:
  • body – The layer to wrap. Usually contains random effects in the form of RandomRequest or TaggedRandomRequest.

  • random_key – The constant random key that the handler uses.

  • starting_offset – Offset at which to start generating random numbers from random_key. A nonzero value advances the random stream as if there had been previous calls to next_key.

  • hole_predicate – Callable that determines whether this handler should handle a given random effect hole. By default, handles all instances of RandomRequest but no instances of TaggedRandomRequest.

  • handler_id – ID to use for the handler. If None, an ID is inferred automatically.

Returns:

A WithFrozenRandomState handler wrapping body, with its random effect holes replaced with references to this handler (whenever allowed by the predicate).
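The default hole_predicate behavior can be sketched with plain Python stand-ins. The classes below are hypothetical placeholders for penzai's RandomRequest and TaggedRandomRequest (the `tag` field on the stand-in is an assumption), and `only_tag` is a hypothetical custom predicate. The default mirrors the documented rule: handle untagged requests and leave tagged ones alone.

```python
import dataclasses
from typing import Callable


@dataclasses.dataclass
class RandomRequest:
    """Hypothetical stand-in for penzai's RandomRequest hole."""


@dataclasses.dataclass
class TaggedRandomRequest:
    """Hypothetical stand-in; the `tag` field is an assumption."""
    tag: str


def is_untagged_hole(hole) -> bool:
    # Mirrors the documented default: handle RandomRequest, skip TaggedRandomRequest.
    return isinstance(hole, RandomRequest)


def only_tag(tag: str) -> Callable[[object], bool]:
    # Hypothetical custom predicate: handle only tagged requests with this tag.
    return lambda hole: isinstance(hole, TaggedRandomRequest) and hole.tag == tag


holes = [RandomRequest(), TaggedRandomRequest("dropout"), TaggedRandomRequest("noise")]
default_handled = [h for h in holes if is_untagged_hole(h)]
dropout_handled = [h for h in holes if only_tag("dropout")(h)]
```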