EffectHandler

class penzai.data_effects.effect_base.EffectHandler

Bases: Layer, ABC

A handler for a particular effect.

Subclasses of EffectHandler are responsible for replacing effect requests with effect references when constructed, and then replacing those references with concrete implementations when called. By convention, subclasses are named “With{X}”, where X describes the effect that occurs when the handler is called (e.g. “WithRandomStream”, “WithMutableLocalState”, “WithSideInputs”).

Handlers must define the attributes handler_id and body but may also define additional attributes as needed.

Implementers of new handlers should provide a constructor class method (usually not __init__, but an explicit class method with a name like handling or from_submodel). This class method should generate a new handler ID using infer_or_check_handler_id, set it as the handler's ID, and replace every EffectRequest instance that it knows how to handle with a HandledEffectRef instance that references this handler ID. Implementers should also override __call__ so that it creates a temporary copy of the submodel body in which each HandledEffectRef instance is replaced with an instance of EffectRuntimeImpl, and then calls that copy.
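The construction-then-call recipe can be illustrated with a toy, dependency-free sketch. Everything in it is hypothetical: CounterRequest, CounterRef, and CounterImpl play the roles of EffectRequest, HandledEffectRef, and EffectRuntimeImpl; make_handler_id stands in for infer_or_check_handler_id; and tree_replace is a simple recursive rewriter over tuples and dataclasses, not penzai's own tree-manipulation machinery. The handler, WithCounter, follows the “With{X}” naming convention and supplies a counter whose state lives only for the duration of one call.

```python
import dataclasses
import itertools
from typing import Any


# Hypothetical stand-ins; these are NOT penzai's real classes.
@dataclasses.dataclass(frozen=True)
class CounterRequest:
    """Unhandled request for a counter effect (EffectRequest role)."""


@dataclasses.dataclass(frozen=True)
class CounterRef:
    """Points at the handler that will supply the counter (HandledEffectRef role)."""
    handler_id: str


@dataclasses.dataclass
class CounterImpl:
    """Concrete implementation that exists only during a call (EffectRuntimeImpl role)."""
    state: list  # one-element list shared by every impl created in the same call

    def next(self) -> int:
        self.state[0] += 1
        return self.state[0]


@dataclasses.dataclass(frozen=True)
class AddCount:
    """Toy layer: adds the next counter value to its input."""
    counter: Any  # CounterRequest, CounterRef, or CounterImpl

    def __call__(self, x):
        return x + self.counter.next()


@dataclasses.dataclass(frozen=True)
class Sequential:
    """Toy combinator: applies each layer in order."""
    layers: tuple

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def tree_replace(node, fn):
    """Applies fn to node, then rebuilds tuples and dataclasses recursively."""
    node = fn(node)
    if isinstance(node, tuple):
        return tuple(tree_replace(child, fn) for child in node)
    if dataclasses.is_dataclass(node) and not isinstance(node, type):
        updates = {f.name: tree_replace(getattr(node, f.name), fn)
                   for f in dataclasses.fields(node)}
        return dataclasses.replace(node, **updates)
    return node


_id_counter = itertools.count()


def make_handler_id(prefix: str) -> str:
    """Hypothetical stand-in for infer_or_check_handler_id."""
    return f"{prefix}_{next(_id_counter)}"


@dataclasses.dataclass(frozen=True)
class WithCounter:
    """Toy handler following the construct-then-call recipe."""
    handler_id: str
    body: Any

    @classmethod
    def handling(cls, body):
        # Construction: swap every CounterRequest for a ref to this handler.
        handler_id = make_handler_id("WithCounter")

        def to_ref(node):
            if isinstance(node, CounterRequest):
                return CounterRef(handler_id)
            return node

        return cls(handler_id=handler_id, body=tree_replace(body, to_ref))

    def __call__(self, argument):
        # Call time: build a temporary copy where this handler's refs become impls.
        state = [0]  # fresh counter state for this call only

        def to_impl(node):
            if isinstance(node, CounterRef) and node.handler_id == self.handler_id:
                return CounterImpl(state)
            return node

        return tree_replace(self.body, to_impl)(argument)


model = Sequential((AddCount(CounterRequest()), AddCount(CounterRequest())))
handled = WithCounter.handling(model)
print(handled(10))  # → 13, i.e. 10 + 1 + 2
```

Calling model(10) directly raises AttributeError, because CounterRequest has no next method; this loosely mirrors the idea that an unhandled request is not meant to run. Because the refs are only swapped out in a temporary copy, handled.body still holds CounterRef instances after each call, and every call starts from a fresh counter.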

Handlers are free to modify the inputs or outputs of the submodel as needed, e.g. by accepting extra information as an input or returning additional information as output. However, the inputs should be in the form of a single Python argument (e.g. a tuple, dictionary, or dataclass) and similarly the output should be a single Python value. This ensures that multiple handlers can be nested together without conflicts.
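The single-argument, single-value convention is what lets wrappers stack in either order. In the minimal sketch below, the two wrappers are hypothetical and omit handler_id and all effect machinery: WithOffsetInput accepts extra input by taking one tuple (x, offset), and WithInputEcho returns extra output as one tuple (result, argument).

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class WithOffsetInput:
    """Hypothetical handler that takes extra input: argument is (x, offset)."""
    body: Callable[[Any], Any]

    def __call__(self, argument):
        x, offset = argument          # one Python argument, unpacked here
        return self.body(x + offset)  # the body still sees one value


@dataclass(frozen=True)
class WithInputEcho:
    """Hypothetical handler that returns extra output: (result, argument)."""
    body: Callable[[Any], Any]

    def __call__(self, argument):
        return self.body(argument), argument  # one value: a tuple


def double(x):
    return 2 * x


# Either nesting order works because each layer maps one value to one value.
a = WithInputEcho(WithOffsetInput(double))
b = WithOffsetInput(WithInputEcho(double))
print(a((3, 1)))  # → (8, (3, 1))
print(b((3, 1)))  # → (8, 4)
```

If WithOffsetInput instead declared __call__(self, x, offset) with two positional parameters, an outer wrapper that forwards a single value via self.body(argument) could no longer call it, which is the kind of conflict the convention avoids.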

Variables:
  • handler_id (HandlerId) – The ID of this handler.

  • body (layer_base.LayerLike) – The layer that this handler wraps.
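The required attributes and the abstract effect_protocol class method can be sketched as a frozen dataclass over abc.ABC. This is an assumption-laden stand-in, not penzai's class: the real EffectHandler also inherits from Layer, and HandlerId is replaced here by a plain str. HandlerBase, ToyRandomEffect, and WithToyRandom are hypothetical names.

```python
from __future__ import annotations

import abc
import dataclasses
from typing import Any, Collection


@dataclasses.dataclass(frozen=True)
class HandlerBase(abc.ABC):
    """Hypothetical stand-in for EffectHandler; not penzai's class."""
    handler_id: str  # penzai types this as HandlerId; a plain str here
    body: Any        # the wrapped layer

    @classmethod
    @abc.abstractmethod
    def effect_protocol(cls) -> type[Any] | Collection[type[Any]] | None:
        """Returns the effect protocol(s) that this handler handles."""


class ToyRandomEffect:
    """Hypothetical effect protocol used only for this example."""


@dataclasses.dataclass(frozen=True)
class WithToyRandom(HandlerBase):
    """Concrete handler: inherits the handler_id and body fields."""

    @classmethod
    def effect_protocol(cls):
        return ToyRandomEffect


handler = WithToyRandom(handler_id="WithToyRandom_0", body=None)
print(handler.effect_protocol().__name__)  # → ToyRandomEffect

try:
    HandlerBase(handler_id="h", body=None)  # abstract: cannot be instantiated
except TypeError as err:
    print("abstract:", err)
```

Stacking classmethod over abstractmethod keeps the base class uninstantiable until a subclass supplies effect_protocol, while subclasses may still add further fields alongside handler_id and body.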

Methods

__init__(handler_id, body)

effect_protocol()

Returns the effect protocol(s) that this handler handles.

treescope_color()

Attributes

handler_id

body

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

__call__(argument, /)

Abstract call method for a layer.

abstract classmethod effect_protocol() → type[Any] | Collection[type[Any]] | None

Returns the effect protocol(s) that this handler handles.

Advanced handlers are allowed to handle multiple effects, and the specific effect interfaces are determined by the references inside body. This method is used primarily to aid debugging and visualization.

Returns:

A single effect protocol if applicable. Can also return a collection of protocols or None if the handler is not associated with a specific effect.
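Because effect_protocol may return a single protocol, a collection of protocols, or None, debugging or visualization code that consumes it has to accept all three shapes. The helpers below (protocols_as_tuple and describe, both hypothetical and not part of penzai) are a minimal sketch of normalizing the result; RandomEffect and StateEffect are placeholder protocols.

```python
from collections.abc import Collection
from typing import Any, Tuple


def protocols_as_tuple(result: Any) -> Tuple[type, ...]:
    """Normalizes an effect_protocol() result into a tuple of protocol types.

    Hypothetical debugging helper; not part of penzai.
    """
    if result is None:  # handler is not tied to a specific effect
        return ()
    if isinstance(result, type):  # a single protocol
        return (result,)
    if isinstance(result, Collection):  # several protocols
        return tuple(result)
    raise TypeError(f"unexpected effect_protocol() result: {result!r}")


def describe(result: Any) -> str:
    """Formats the result for a log line or debugging display."""
    names = [p.__name__ for p in protocols_as_tuple(result)]
    return ", ".join(names) if names else "<no specific effect>"


class RandomEffect:  # hypothetical protocol
    pass


class StateEffect:  # hypothetical protocol
    pass


print(describe(RandomEffect))                 # → RandomEffect
print(describe([RandomEffect, StateEffect]))  # → RandomEffect, StateEffect
print(describe(None))                         # → <no specific effect>
```

If a handler returns a set, the order of the resulting tuple is not guaranteed, which is acceptable for display but not for comparisons.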