Sequential


class penzai.nn.grouping.Sequential

Bases: Layer

A group of layers to call sequentially.

Sequential is one of the most common layer types to use in a penzai.nn model, since many networks can be written as the composition of a number of layers. However, you may prefer to use CheckedSequential if you can define in advance the structure of inputs and outputs your layer will accept.

A common pattern in penzai is:

  • subclass Sequential with a different layer name,

  • inherit __init__ and __call__ from Sequential,

  • define a classmethod (often called from_config) that constructs an instance of the subclass with its contents.

This allows the configuration and initialization logic for parts of a network (such as a self-attention layer) to be grouped in a single place, without affecting the later ability to interactively modify the resulting network.
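The subclassing pattern described above can be sketched in plain Python. This is an illustration only: MiniSequential is a hypothetical stand-in that mimics the documented "call each sublayer in order" behavior, and ScaleThenShift and its from_config arguments are made-up names. It is not penzai's implementation, and unlike real penzai layers it is not a JAX PyTree.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass
class MiniSequential:
    """Hypothetical stand-in for Sequential: calls each sublayer in order."""

    sublayers: list[Callable[[Any], Any]]

    def __call__(self, value: Any) -> Any:
        for layer in self.sublayers:
            value = layer(value)
        return value


@dataclasses.dataclass
class ScaleThenShift(MiniSequential):
    """Named subclass: inherits __init__ and __call__ unchanged."""

    @classmethod
    def from_config(cls, scale: float, shift: float) -> "ScaleThenShift":
        # All construction logic lives here. The result is still just a
        # list of sublayers that can be inspected or replaced afterwards.
        return cls(sublayers=[lambda x: x * scale, lambda x: x + shift])


block = ScaleThenShift.from_config(scale=2.0, shift=1.0)
result = block(3.0)  # (3.0 * 2.0) + 1.0

# Later modification, done by building a new instance rather than
# editing from_config: append a negation step after the two built-in steps.
modified_block = dataclasses.replace(
    block, sublayers=block.sublayers + [lambda x: -x]
)
modified = modified_block(3.0)
```

Because the subclass adds no behavior of its own, anything that works on the base class's list of sublayers keeps working on the named subclass; the name only records where the block came from.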

Subclasses of Sequential are NOT allowed to override __call__. If a user has a subclass of Sequential, they should be able to assume it just calls each child in order. (If you need finer control, consider having a Sequential as a child attribute instead, or just duplicate the relevant logic for your own class.)
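When finer control is needed, the "Sequential as a child attribute" alternative might look like the sketch below. As before, MiniSequential is a hypothetical plain-Python stand-in rather than the penzai class, and AddInputBack is an invented example layer.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass
class MiniSequential:
    """Hypothetical stand-in for Sequential: calls each sublayer in order."""

    sublayers: list[Callable[[Any], Any]]

    def __call__(self, value: Any) -> Any:
        for layer in self.sublayers:
            value = layer(value)
        return value


@dataclasses.dataclass
class AddInputBack:
    """Not a Sequential subclass, so it may define its own __call__."""

    body: MiniSequential

    def __call__(self, value: Any) -> Any:
        # Custom control flow: add the original input to the body's output.
        return value + self.body(value)


block = AddInputBack(body=MiniSequential([lambda x: x * 2, lambda x: x + 1]))
out = block(3)  # 3 + ((3 * 2) + 1)
```

The wrapper owns the custom logic, while the inner sequence still behaves exactly as a reader would expect: each child is called in order.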

Variables:

sublayers (list[Callable[[Any], Any]]) – A sequence of layers to call in order. These are usually pz.Layer instances, but are allowed to be other types of callable PyTree as well.

Methods

__init__(sublayers)

treescope_color()

__call__(value)

Runs each of the sublayers in sequence.

Attributes

sublayers

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

final __call__(value: Any) -> Any

Runs each of the sublayers in sequence.

Parameters:

value – The input to the layer.

Returns:

The output of the final sublayer.
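Put another way, running the sublayers in sequence amounts to a left fold of function application over the list. The helper below, run_in_sequence, is a hypothetical name used only to illustrate that equivalence with ordinary Python callables; it is not part of penzai.

```python
from functools import reduce
from typing import Any, Callable, Sequence


def run_in_sequence(sublayers: Sequence[Callable[[Any], Any]], value: Any) -> Any:
    """Feeds value through each callable in order and returns the last output."""
    return reduce(lambda acc, layer: layer(acc), sublayers, value)


layers = [str.strip, str.upper, len]
output = run_in_sequence(layers, "  penzai ")

# The same computation written as explicit nested calls.
nested = len(str.upper(str.strip("  penzai ")))
```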