Residual

class penzai.nn.combinators.Residual

Bases: Layer

A residual block, which runs its sublayers then adds the input.

Residual blocks add additional data-flow paths called “skip connections”, wherein the input to the residual block is saved, and the output of the sublayers is treated as a “residual” to add back to the input. When many residual blocks are run in order, this produces a “residual stream”, with each block reading from the stream and then making an additive write to it.

Residual blocks have non-linear data flow, but in a fairly straightforward way. This pattern can be factored out into a block so that it can be expressed consistently in more complex models.
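The residual-stream picture can be illustrated with a minimal sketch in plain Python. It does not use penzai: `residual_step` and `run_stream` are hypothetical names chosen for illustration, and scalars stand in for arrays.

```python
from typing import Callable, Sequence


# Hypothetical helper, not part of penzai: one residual step.
def residual_step(delta: Callable[[float], float], value: float) -> float:
    """Saves the input, runs `delta`, and adds the result back to the input."""
    return value + delta(value)


def run_stream(
    deltas: Sequence[Callable[[float], float]], value: float
) -> list[float]:
    """Applies residual steps in order and records the stream after each one."""
    stream = [value]
    for delta in deltas:
        # Each block reads the current stream value and makes an additive write.
        stream.append(residual_step(delta, stream[-1]))
    return stream


# Two toy "sublayers": one writes twice what it reads, one writes a constant.
states = run_stream([lambda x: 2.0 * x, lambda x: 1.0], 1.0)
print(states)
```

Each entry in `states` is the stream after one more additive write, so the final value equals the initial input plus the sum of all writes.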

Variables:

delta (layer_base.LayerLike) – The child block to run; its output is added to the block's input.

Methods

__init__(delta)

__call__(value)

Runs each of the sublayers in sequence, then adds back the original input.

Attributes

delta

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(value: Any) -> Any

Runs each of the sublayers in sequence, then adds back the original input.

Parameters:

value – The input to the block.

Returns:

The sum of the input to the residual block and the output of the child.
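As a rough illustration of this contract, the sketch below defines a stand-in class with the same shape as `Residual`: a single `delta` field and a `__call__` that returns the input plus the child's output. `ResidualSketch` is a hypothetical name and is not penzai's implementation; it assumes the input supports `+` and that the child returns a value of matching structure.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class ResidualSketch:
    """Hypothetical stand-in for a residual block (not penzai's class)."""

    delta: Callable[[Any], Any]  # the child whose output is added to the input

    def __call__(self, value: Any) -> Any:
        # Assumption: `value` and `self.delta(value)` can be added with `+`.
        return value + self.delta(value)


block = ResidualSketch(delta=lambda x: 0.5 * x)
result = block(4.0)  # computes 4.0 + 0.5 * 4.0
```

If `delta` returns zero, the block passes its input through unchanged.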