WithStatefulRandomKey
- class penzai.deprecated.v1.data_effects.random.WithStatefulRandomKey
Bases: EffectHandler
RandomEffect handler that tracks a random key as local state. WithStatefulRandomKey transforms RandomEffect effects into a LocalStateEffect, allowing the key to be updated statefully using the existing state-manipulation features. It does not change the input or output types of the model, since the random state is managed externally.
- Variables:
handler_id (effect_base.HandlerId) – The ID of this handler.
body (layer_base.LayerLike) – The layer that this handler wraps.
random_state (local_state.LocalStateEffect[jax.Array]) – The local state holding the current random key.
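The state-threading idea behind this handler can be sketched in plain JAX, independently of penzai: the key lives in an external state object, each random request splits it, and the advanced key is written back. This is a minimal sketch under stated assumptions. `take_subkey`, `noisy_layer`, `run_twice`, and the dict-based state layout are hypothetical and do not reproduce penzai's internal LocalStateEffect machinery.

```python
import jax
import jax.numpy as jnp

# Hypothetical state layout: a dict keyed by a category label, loosely mirroring
# the `state_category='random'` default of `handling`. Not penzai's real format.
STATE_CATEGORY = "random"


def take_subkey(state):
    """Splits the stored key, returning a fresh subkey and the advanced state."""
    next_key, subkey = jax.random.split(state[STATE_CATEGORY])
    # Return a new dict instead of mutating, so earlier states stay replayable.
    return subkey, {**state, STATE_CATEGORY: next_key}


def noisy_layer(x, state):
    """A toy layer that makes one random request per call."""
    subkey, state = take_subkey(state)
    return x + jax.random.normal(subkey, x.shape), state


def run_twice(x, state):
    """Calls the layer twice, threading the key state through both calls."""
    y1, state = noisy_layer(x, state)
    y2, state = noisy_layer(x, state)
    return y1, y2, state


initial_state = {STATE_CATEGORY: jax.random.key(0)}
x = jnp.zeros((3,))
y1, y2, final_state = run_twice(x, initial_state)
```

Here the threading is written out explicitly, so `noisy_layer` takes and returns the state. In penzai, the handler routes the key through the local-state effect instead, which is why the model's own input and output types stay unchanged. Either way, the two calls draw different noise, and restarting from the same initial state reproduces the same draws.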
Methods
__init__(handler_id, body, random_state)
effect_protocol()
handling(body, initial_key[, ...]) – Builds a WithStatefulRandomKey that handles effects in this layer.
__call__(argument)
Attributes
handler_id
body
random_state
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
input_structure() – Returns the input structure of this layer.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
output_structure() – Returns the output structure of this layer.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color()
- classmethod handling(body: layer_base.LayerLike, initial_key: jax.Array, hole_predicate: Callable[[RandomRequest | TaggedRandomRequest], bool] = <function _is_untagged_hole>, state_category: Any = 'random', handler_id: str | None = None) -> WithStatefulRandomKey
Builds a WithStatefulRandomKey that handles effects in this layer.
- Parameters:
body – The layer to wrap. It will usually contain random effects in the form of RandomRequest or TaggedRandomRequest.
initial_key – Initial key to use for the state.
hole_predicate – Callable that determines whether a given random effect hole should be handled. By default, handles all instances of RandomRequest but no instances of TaggedRandomRequest.
state_category – Category label to use when configuring the local state effect.
handler_id – ID to use for the handler. If None, an ID will be inferred.
- Returns:
A WithStatefulRandomKey handler wrapping body, with its random effect holes replaced with references to this handler (whenever allowed by the predicate).
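The default hole_predicate rule can be illustrated with plain-Python stand-ins. `RandomRequest`, `TaggedRandomRequest`, `is_untagged_hole`, `handles_tag`, and `select_holes` below are hypothetical simplifications written for this sketch, not penzai's classes or its `_is_untagged_hole` function. They only mirror the documented rule that untagged requests are handled by default while tagged ones are left alone.

```python
from dataclasses import dataclass
from typing import Any, Callable


# Hypothetical stand-ins for penzai's random effect holes; the real classes
# carry additional fields and live in penzai's data_effects module.
@dataclass(frozen=True)
class RandomRequest:
    pass


@dataclass(frozen=True)
class TaggedRandomRequest:
    tag: Any


def is_untagged_hole(hole: Any) -> bool:
    """Mirrors the documented default: handle RandomRequest, skip tagged ones."""
    return isinstance(hole, RandomRequest)


def handles_tag(tag: Any) -> Callable[[Any], bool]:
    """Builds a predicate that accepts only tagged requests with a matching tag."""
    return lambda hole: isinstance(hole, TaggedRandomRequest) and hole.tag == tag


def select_holes(holes: list, predicate: Callable[[Any], bool]) -> list:
    """Returns the holes that a handler built with `predicate` would take over."""
    return [hole for hole in holes if predicate(hole)]


holes = [RandomRequest(), TaggedRandomRequest("dropout"), TaggedRandomRequest("init")]
default_selection = select_holes(holes, is_untagged_hole)
dropout_selection = select_holes(holes, handles_tag("dropout"))
```

With the default predicate only the untagged request is selected. A custom predicate such as the hypothetical `handles_tag("dropout")` would instead let a handler claim one tagged stream while leaving the other holes for different handlers.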