WithStatefulRandomKey
class penzai.deprecated.v1.data_effects.random.WithStatefulRandomKey
Bases: EffectHandler

RandomEffect handler that tracks a random seed as local state.

WithStatefulRandomKey transforms RandomEffect effects into a LocalStateEffect, which allows the key to be updated statefully using the existing state-manipulation features. It does not change the input or output types of the model, since the random state is managed externally.

Variables:
handler_id (effect_base.HandlerId) – The ID of this handler.
body (layer_base.LayerLike) – The layer that this handler wraps.
random_state (local_state.LocalStateEffect[jax.Array]) – The local state holding the current random key.
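To make the "stateful key" idea concrete, here is a minimal sketch in plain JAX. It is not Penzai's implementation: the helpers `next_key` and `noisy_layer` are hypothetical, and the key is threaded through explicitly as a function argument and return value so the mechanism is visible. In Penzai the same threading is handled by the local-state machinery, which is why the wrapped model's own input and output types stay unchanged. The assumed pattern is: each random draw splits the stored key, uses one half, and keeps the other half as the new state.

```python
import jax
import jax.numpy as jnp


def next_key(state: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Hypothetical helper: split the stored key into (fresh subkey, new state)."""
    new_state, subkey = jax.random.split(state)
    return subkey, new_state


def noisy_layer(x: jax.Array, state: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Hypothetical layer that adds Gaussian noise using a key drawn from state.

    The state is threaded explicitly here; a stateful handler would instead
    keep it outside the model's visible inputs and outputs.
    """
    subkey, state = next_key(state)
    noise = jax.random.normal(subkey, x.shape)
    return x + noise, state


# The initial key plays the role of `initial_key` in `handling`.
state = jax.random.PRNGKey(0)
x = jnp.zeros((3,))
y1, state = noisy_layer(x, state)  # first call consumes one subkey
y2, state = noisy_layer(x, state)  # second call sees the updated key
```

Because the key is advanced on every call, `y1` and `y2` receive different noise, while restarting from the same initial key reproduces the same sequence of draws.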
Methods
__init__(handler_id, body, random_state)
effect_protocol()
handling(body, initial_key[, ...]) – Builds a WithStatefulRandomKey that handles effects in this layer.
__call__(argument)

Attributes
handler_id
body
random_state
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
input_structure() – Returns the input structure of this layer.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
output_structure() – Returns the output structure of this layer.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color()

classmethod handling(body: layer_base.LayerLike, initial_key: jax.Array, hole_predicate: Callable[[RandomRequest | TaggedRandomRequest], bool] = <function _is_untagged_hole>, state_category: Any = 'random', handler_id: str | None = None) -> WithStatefulRandomKey
Builds a WithStatefulRandomKey that handles effects in this layer.

Parameters:
body – The layer to wrap. It will usually contain random effects in the form of RandomRequest or TaggedRandomRequest.
initial_key – Initial key to use for the state.
hole_predicate – Callable that determines whether a given random effect hole should be handled. By default, it handles all instances of RandomRequest but no instances of TaggedRandomRequest.
state_category – Category to use when configuring the state effect (defaults to 'random').
handler_id – ID to use for the handler. If None, it will be inferred.
Returns:
A WithStatefulRandomKey handler wrapping body, with its random effect holes replaced by references to this handler (wherever the predicate allows).
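The hole_predicate parameter decides which holes the handler claims. The sketch below illustrates that selection logic in isolation. The `RandomRequest` and `TaggedRandomRequest` classes here are hypothetical stand-ins defined only for this example (not the real Penzai classes), and the assumption that a tagged request carries a `tag` field is made purely for illustration. `untagged_only` mirrors the documented default; `untagged_or_dropout` is a hypothetical custom predicate that additionally claims holes tagged "dropout".

```python
from dataclasses import dataclass
from typing import Hashable, Union


# Hypothetical stand-ins for the real hole types; only class identity
# and an assumed `tag` field matter for this illustration.
@dataclass(frozen=True)
class RandomRequest:
    pass


@dataclass(frozen=True)
class TaggedRandomRequest:
    tag: Hashable


Request = Union[RandomRequest, TaggedRandomRequest]


def untagged_only(hole: Request) -> bool:
    """Mimics the documented default: handle untagged requests only."""
    return isinstance(hole, RandomRequest)


def untagged_or_dropout(hole: Request) -> bool:
    """Hypothetical custom predicate: also claim holes tagged 'dropout'."""
    if isinstance(hole, RandomRequest):
        return True
    return isinstance(hole, TaggedRandomRequest) and hole.tag == "dropout"


holes = [RandomRequest(), TaggedRandomRequest("dropout"), TaggedRandomRequest("noise")]
print([untagged_only(h) for h in holes])        # [True, False, False]
print([untagged_or_dropout(h) for h in holes])  # [True, True, False]
```

With the real classes, a predicate of this shape would be passed as the hole_predicate argument to handling; holes it rejects are left in place for another handler to resolve.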