WithStatefulRandomKey

class penzai.data_effects.random.WithStatefulRandomKey

Bases: EffectHandler

RandomEffect handler that tracks a random seed as a local state.

WithStatefulRandomKey transforms a RandomEffect into a LocalStateEffect, so that the random key can be updated statefully using the existing state manipulation features. It does not change the input or output types of the model, since the random state is managed externally.
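The idea of carrying a random key as state can be illustrated in plain JAX, without penzai. The sketch below is only an analogy: `noisy_layer` and `call_with_stateful_key` are hypothetical names, and the split-then-store scheme is an assumption, since this page does not specify how the handler derives per-call keys.

```python
import jax
import jax.numpy as jnp


def noisy_layer(key, x):
    """Hypothetical layer body that consumes a random key."""
    return x + jax.random.normal(key, x.shape)


def call_with_stateful_key(key_state, x):
    """Splits the stored key, uses one half, and returns the other as new state.

    The input and output of the wrapped layer are unchanged; only the
    externally managed key state is threaded alongside them.
    """
    next_state, use_key = jax.random.split(key_state)
    return noisy_layer(use_key, x), next_state


key_state = jax.random.PRNGKey(0)
x = jnp.zeros((3,))

# Each call consumes the current state and produces an updated one.
y1, key_state_1 = call_with_stateful_key(key_state, x)
y2, key_state_2 = call_with_stateful_key(key_state_1, x)
```

Because the updated key is fed back in, the two calls draw different noise even though the caller never splits keys by hand.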

Variables:
  • handler_id (effect_base.HandlerId) – The ID of this handler.

  • body (layer_base.LayerLike) – The layer that this handler wraps.

  • random_state (local_state.LocalStateEffect[jax.Array]) – The local state holding the current random key.

Methods

__init__(handler_id, body, random_state)

effect_protocol()

handling(body, initial_key[, ...])

Builds a WithStatefulRandomKey that handles effects in this layer.

__call__(argument)

Attributes

handler_id

body

random_state

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

classmethod handling(body: layer_base.LayerLike, initial_key: jax.Array, hole_predicate: Callable[[RandomRequest | TaggedRandomRequest], bool] = <function _is_untagged_hole>, state_category: Any = 'random', handler_id: str | None = None) → WithStatefulRandomKey

Builds a WithStatefulRandomKey that handles effects in this layer.

Parameters:
  • body – The layer to wrap. Usually will contain random effects in the form of RandomRequest or TaggedRandomRequest.

  • initial_key – Initial key to use for the state.

  • hole_predicate – Callable that determines whether we should handle a given random effect hole. By default, handles all instances of RandomRequest but no instances of TaggedRandomRequest.

  • state_category – Category to use when configuring the state effect. Defaults to 'random'.

  • handler_id – ID to use for the handler. If None, will be inferred.

Returns:

A WithStatefulRandomKey handler wrapping body, with its random effect holes replaced with references to this handler (whenever allowed by the predicate).
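The role of hole_predicate can be sketched with stand-in classes. In the example below, `RandomRequest` and `TaggedRandomRequest` are hypothetical dataclasses that only borrow the names from the signature above, and the `tag` field is an assumption; the real penzai classes are not imported. `is_untagged` mimics the documented default behaviour, and `make_tag_predicate` is a hypothetical helper that shows what a custom predicate might look like.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class RandomRequest:
    """Hypothetical stand-in for an untagged random effect hole."""


@dataclass(frozen=True)
class TaggedRandomRequest:
    """Hypothetical stand-in for a tagged hole; the `tag` field is assumed."""
    tag: Any


def is_untagged(hole) -> bool:
    # Mirrors the documented default: handle every RandomRequest
    # and no TaggedRandomRequest.
    return isinstance(hole, RandomRequest)


def make_tag_predicate(tag):
    """Builds a predicate that accepts only tagged holes carrying `tag`."""
    def predicate(hole) -> bool:
        return isinstance(hole, TaggedRandomRequest) and hole.tag == tag
    return predicate


holes = [
    RandomRequest(),
    TaggedRandomRequest("dropout"),
    TaggedRandomRequest("noise"),
]
dropout_only = make_tag_predicate("dropout")

# Which holes each predicate would claim for its handler.
default_handled = [is_untagged(h) for h in holes]
dropout_handled = [dropout_only(h) for h in holes]
```

A predicate of this shape would let separate handlers claim separate groups of holes, for example one handler for untagged requests and another for requests tagged "dropout".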