Sequential

class penzai.nn.grouping.Sequential

Bases: Layer

A group of layers to call sequentially.

Sequential is one of the most common layer types to use in a penzai.nn model, since many networks can be written as the composition of a number of layers. However, you may prefer to use CheckedSequential if you can define in advance the structure of inputs and outputs your layer will accept.
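As a rough mental model, Sequential composes its sublayers so that each one receives the previous one's output. The sketch below is a hypothetical plain-Python stand-in: `ToySequential`, `Scale`, and `add_one` are illustrative names, not penzai code, and the sketch omits everything that makes real penzai layers PyTrees.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Scale:
  """Hypothetical layer: multiplies its input by a fixed factor."""
  factor: float

  def __call__(self, value: Any, **side_inputs) -> Any:
    return value * self.factor


@dataclass(frozen=True)
class ToySequential:
  """Hypothetical stand-in for Sequential (not penzai's actual code)."""
  sublayers: list[Callable[..., Any]]

  def __call__(self, value: Any, **side_inputs) -> Any:
    # Feed the output of each sublayer into the next one, in list order.
    for layer in self.sublayers:
      value = layer(value, **side_inputs)
    return value


def add_one(value, **side_inputs):
  # A plain function can also serve as a sublayer, since it is callable.
  return value + 1


model = ToySequential([Scale(2.0), add_one, Scale(10.0)])
print(model(3.0))  # (3 * 2 + 1) * 10 = 70.0
```

Order matters: `ToySequential([add_one, Scale(2.0), Scale(10.0)])(3.0)` computes `(3 + 1) * 2 * 10` and returns `80.0` instead.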

A common pattern in penzai is:

  • subclass Sequential with a different layer name,

  • inherit __init__ and __call__ from Sequential,

  • define a classmethod (often called from_config) that constructs an instance of the subclass with its contents.
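The three-step pattern above can be sketched with the same kind of toy stand-in. In this hypothetical example, `ToyBlock` subclasses `ToySequential`, inherits `__init__` and `__call__` unchanged, and adds a `from_config` classmethod that assembles its sublayers. `Affine`, `relu`, and the config parameters are illustrative assumptions, not penzai APIs.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToySequential:
  """Hypothetical stand-in for Sequential (not penzai's actual code)."""
  sublayers: list[Callable[..., Any]]

  def __call__(self, value: Any, **side_inputs) -> Any:
    for layer in self.sublayers:
      value = layer(value, **side_inputs)
    return value


@dataclass(frozen=True)
class Affine:
  """Hypothetical elementwise layer computing scale * x + shift."""
  scale: float
  shift: float

  def __call__(self, value: Any, **side_inputs) -> Any:
    return self.scale * value + self.shift


def relu(value, **side_inputs):
  return max(value, 0.0)


class ToyBlock(ToySequential):
  """Hypothetical subclass: only the name and the builder differ.

  __init__ and __call__ are inherited unchanged from ToySequential.
  """

  @classmethod
  def from_config(cls, hidden_scale: float, out_scale: float) -> "ToyBlock":
    # All configuration logic lives here; the result is still just a
    # list of sublayers that can be inspected or edited later.
    return cls(sublayers=[
        Affine(scale=hidden_scale, shift=-1.0),
        relu,
        Affine(scale=out_scale, shift=0.0),
    ])


block = ToyBlock.from_config(hidden_scale=2.0, out_scale=3.0)
print(type(block).__name__, len(block.sublayers))  # ToyBlock 3
print(block(2.0))  # relu(2 * 2 - 1) * 3 = 9.0
```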

This allows the configuration and initialization logic for parts of a network (such as a self-attention layer) to be grouped in a single place, without affecting the later ability to interactively modify the resulting network.
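Because such a group is only data (its list of sublayers), it can be edited after construction without rerunning the builder. penzai offers `select()` (listed under the inherited methods) for functional-style edits; the sketch below uses `dataclasses.replace` on a hypothetical toy class as a plain-Python analogue, not as penzai's mechanism.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToySequential:
  """Hypothetical stand-in for Sequential (not penzai's actual code)."""
  sublayers: list[Callable[..., Any]]

  def __call__(self, value: Any, **side_inputs) -> Any:
    for layer in self.sublayers:
      value = layer(value, **side_inputs)
    return value


class ToyBlock(ToySequential):
  """Hypothetical named group built by a classmethod."""

  @classmethod
  def from_config(cls, scale: float) -> "ToyBlock":
    return cls(sublayers=[lambda v, **_: v * scale, lambda v, **_: v + 1.0])


block = ToyBlock.from_config(scale=2.0)

# Functional-style edit: build a new block with an extra sublayer appended,
# leaving the original untouched. The copy keeps its subclass type.
edited = dataclasses.replace(block, sublayers=[*block.sublayers, lambda v, **_: -v])

print(block(1.0), edited(1.0))  # 3.0 -3.0
print(type(edited).__name__)    # ToyBlock
```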

Subclasses of Sequential are NOT allowed to override __call__. If a user has a subclass of Sequential, they should be able to assume it just calls each child in order. (If you need finer control, consider having a Sequential as a child attribute instead, or just duplicate the relevant logic for your own class.)
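The `final` marker on `__call__` in the signature documented on this page suggests the rule is expressed as a `typing.final` annotation. Static type checkers enforce such annotations, but Python does not check them at runtime. The sketch below is a hypothetical illustration of two things: a runtime guard via `__init_subclass__` (an assumption, not penzai's behavior), and the recommended alternative of holding a Sequential as a child attribute. `ToySequential` and `Residual` are illustrative names.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToySequential:
  """Hypothetical stand-in for Sequential (not penzai's actual code)."""
  sublayers: list[Callable[..., Any]]

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    # Hypothetical runtime guard: reject subclasses that redefine __call__.
    if "__call__" in cls.__dict__:
      raise TypeError(f"{cls.__name__} may not override __call__")

  def __call__(self, value: Any, **side_inputs) -> Any:
    for layer in self.sublayers:
      value = layer(value, **side_inputs)
    return value


try:
  class Bad(ToySequential):
    def __call__(self, value, **side_inputs):
      return value
  override_rejected = False
except TypeError as err:
  override_rejected = True
  print(err)  # Bad may not override __call__


@dataclass(frozen=True)
class Residual:
  """Alternative: hold a ToySequential as a child and add custom logic."""
  body: ToySequential

  def __call__(self, value: Any, **side_inputs) -> Any:
    # Custom control flow lives here, not in a Sequential subclass.
    return value + self.body(value, **side_inputs)


res = Residual(body=ToySequential([lambda v, **_: v * 3.0]))
print(res(2.0))  # 2 + 2 * 3 = 8.0
```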

Variables:

sublayers (list[penzai.nn.layer.Layer]) – A sequence of layers to call in order. These are usually pz.nn.Layer instances, but are allowed to be other types of callable PyTree as well.

Methods

__init__(sublayers)

treescope_color()

__call__(value, **side_inputs) – Runs each of the sublayers in sequence.

Attributes

sublayers

Inherited Methods

attributes_dict() – Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.

from_attributes(**field_values) – Directly instantiates a struct given all of its fields.

key_for_field(field_name) – Generates a JAX PyTree key for a given field name.

select() – Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.

tree_flatten() – Flattens this tree node.

tree_flatten_with_keys() – Flattens this tree node with keys.

tree_unflatten(aux_data, children) – Unflattens this tree node.

final __call__(value: Any, **side_inputs) → Any

Runs each of the sublayers in sequence.

Parameters:
  • value – The input to the first sublayer.

  • **side_inputs – The side inputs for all sublayers.

Returns:

The output of the final sublayer.
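Because the same `**side_inputs` are passed to every sublayer, each sublayer must accept keyword arguments it does not use. The sketch below shows this forwarding with a hypothetical toy stand-in: `ToySequential`, `add_offset`, `double`, and the `offset` side input are illustrative assumptions, not penzai APIs.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToySequential:
  """Hypothetical stand-in for Sequential (not penzai's actual code)."""
  sublayers: list[Callable[..., Any]]

  def __call__(self, value: Any, **side_inputs) -> Any:
    for layer in self.sublayers:
      # Every sublayer receives the same keyword side inputs.
      value = layer(value, **side_inputs)
    # The return value is the final sublayer's output.
    return value


def add_offset(value, *, offset=0.0, **side_inputs):
  """Uses the hypothetical 'offset' side input."""
  return value + offset


def double(value, **side_inputs):
  """Accepts and ignores side inputs it does not need."""
  return value * 2.0


model = ToySequential([add_offset, double, add_offset])
print(model(1.0))              # (1 + 0) * 2 + 0 = 2.0
print(model(1.0, offset=0.5))  # (1 + 0.5) * 2 + 0.5 = 3.5
```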