CollectingSideOutputs

class penzai.data_effects.side_output.CollectingSideOutputs

Bases: EffectHandler

SideOutput handler that collects all side outputs into a list.

CollectingSideOutputs takes the same argument as the wrapped layer and returns the same output, but also returns a list of all side outputs written by the side output effect objects it handles. Each side output is returned as a tuple containing its keypath, tag, and value.
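As a rough mental model only, the collecting behavior can be mimicked in plain Python without Penzai. This is a hypothetical toy, not Penzai's implementation: `collecting_side_outputs`, `toy_body`, and the `tell` callback are illustrative names, and the sketch assumes the wrapped layer's output and the collected list come back together as a pair.

```python
from typing import Any, Callable

# Hypothetical stand-in for a layer body: it receives the layer argument plus a
# `tell(keypath, tag, value)` callback in place of side output effect objects.
BodyFn = Callable[[Any, Callable[[tuple, Any, Any], None]], Any]


def collecting_side_outputs(
    body: BodyFn, argument: Any
) -> tuple[Any, list[tuple[tuple, Any, Any]]]:
    """Runs `body` and returns (output, [(keypath, tag, value), ...])."""
    collected: list[tuple[tuple, Any, Any]] = []

    def tell(keypath: tuple, tag: Any, value: Any) -> None:
        # Each side output is recorded as a (keypath, tag, value) tuple.
        collected.append((keypath, tag, value))

    output = body(argument, tell)
    # Same output as the wrapped body, plus the list of side outputs.
    return output, collected


def toy_body(x, tell):
    h = x * 2
    tell(("layers", 0), "hidden", h)
    y = h + 1
    tell(("layers", 1), "logits", y)
    return y


out, side = collecting_side_outputs(toy_body, 3)
print(out)   # 7
print(side)  # [(('layers', 0), 'hidden', 6), (('layers', 1), 'logits', 7)]
```

The body's own return value is untouched; the handler only adds the second result, which is why callers of the wrapped layer need to unpack both parts.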

Methods

__init__(handler_id, body)

effect_protocol()

handling(body[, tag, tag_predicate, handler_id])

Builds a CollectingSideOutputs that handles effects in this layer.

input_structure()

output_structure()

__call__(argument)

Attributes

handler_id

body

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

classmethod handling(body: layer_base.LayerLike, tag: SideOutputTag | None = None, tag_predicate: Callable[[SideOutputTag], bool] | None = None, handler_id: str | None = None) → CollectingSideOutputs

Builds a CollectingSideOutputs that handles effects in this layer.

Parameters:
  • body – The layer to wrap. Usually contains SideOutputRequest nodes.

  • tag – A tag to collect. If None, defers to tag_predicate.

  • tag_predicate – A predicate to use to select which side outputs to collect. Should return True for tags to collect. If neither tag nor tag_predicate is specified, all side outputs are collected.

  • handler_id – ID to use for the handler. If None, will be inferred.
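The tag-selection rule described by these parameters can be sketched as a small filter factory. `make_tag_filter` is a hypothetical helper, not part of Penzai. Because the documentation does not say what happens when both `tag` and `tag_predicate` are given, this sketch simply rejects that combination; treat that as an assumption.

```python
from typing import Any, Callable, Optional


def make_tag_filter(
    tag: Optional[Any] = None,
    tag_predicate: Optional[Callable[[Any], bool]] = None,
) -> Callable[[Any], bool]:
    """Returns a function that is True for side output tags to collect."""
    if tag is not None and tag_predicate is not None:
        # Assumption of this sketch: the two options are mutually exclusive.
        raise ValueError("Specify at most one of tag and tag_predicate.")
    if tag is not None:
        # A specific tag: collect only side outputs with exactly that tag.
        return lambda t: t == tag
    if tag_predicate is not None:
        # tag is None, so defer to the predicate.
        return tag_predicate
    # Neither given: collect every side output.
    return lambda t: True


only_hidden = make_tag_filter(tag="hidden")
attn_only = make_tag_filter(tag_predicate=lambda t: t.startswith("attn"))
collect_all = make_tag_filter()
print(only_hidden("hidden"), only_hidden("logits"))    # True False
print(attn_only("attn_weights"), attn_only("hidden"))  # True False
print(collect_all("anything"))                         # True
```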

Returns:

A CollectingSideOutputs handler wrapping body, in which the side output holes with the selected tags are replaced with references to this handler.
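To illustrate what replacing side output holes with references to the handler means, the toy below walks a nested list/dict standing in for body and swaps each selected placeholder for a reference carrying the handler ID. `SideOutputHole`, `HandlerRef`, `toy_handling`, and the fresh-ID scheme are all hypothetical; Penzai's real layers are PyTree structs, and its handler ID inference may work differently.

```python
import dataclasses
import itertools
from typing import Any, Callable, Optional


@dataclasses.dataclass(frozen=True)
class SideOutputHole:
    """Hypothetical placeholder for a side output that no handler owns yet."""
    tag: Any


@dataclasses.dataclass(frozen=True)
class HandlerRef:
    """Hypothetical reference pointing a side output at a specific handler."""
    handler_id: str
    tag: Any


_fresh_ids = itertools.count()


def toy_handling(body: Any, select: Callable[[Any], bool],
                 handler_id: Optional[str] = None) -> dict:
    """Wraps `body`, rebinding selected holes to a handler ID."""
    if handler_id is None:
        # Stand-in for ID inference: generate a fresh ID.
        handler_id = f"collect_{next(_fresh_ids)}"

    def replace(node: Any) -> Any:
        if isinstance(node, SideOutputHole) and select(node.tag):
            return HandlerRef(handler_id, node.tag)
        if isinstance(node, list):
            return [replace(n) for n in node]
        if isinstance(node, dict):
            return {k: replace(v) for k, v in node.items()}
        # Anything else (including unselected holes) is left unchanged.
        return node

    # Mirrors the handler_id and body attributes of the real handler.
    return {"handler_id": handler_id, "body": replace(body)}


body = [SideOutputHole("hidden"), {"proj": "dense"}, SideOutputHole("logits")]
wrapped = toy_handling(body, lambda t: t == "hidden", handler_id="h0")
print(wrapped["body"])
# [HandlerRef(handler_id='h0', tag='hidden'), {'proj': 'dense'}, SideOutputHole(tag='logits')]
```

Holes whose tags are not selected stay in place, so an outer handler could still claim them later.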