CheckedSequential

class penzai.nn.grouping.CheckedSequential

Bases: Layer

A group of layers to call sequentially, with known input/output types.

CheckedSequential is a “typed” variant of Sequential that is annotated with input and output structures. The input and output structures share state variables, which can be used to make assertions about how the shapes of the inputs relate to the shapes of the outputs.
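The idea of variables shared between the input and output structures can be illustrated without penzai itself. The sketch below is a hypothetical, stdlib-only model: `match_shape` and the pattern format (a tuple mixing fixed `int` sizes and `str` dimension names) are assumptions for illustration, not the `pz.chk` API. The first time a named dimension is seen its size is recorded, and every later occurrence, such as one in the output pattern, must agree with it.

```python
from typing import Dict, Tuple, Union

# Hypothetical pattern format: an int is a fixed size, a str names a dimension.
Dim = Union[int, str]


def match_shape(pattern: Tuple[Dim, ...], shape: Tuple[int, ...],
                bindings: Dict[str, int]) -> Dict[str, int]:
    """Checks `shape` against `pattern`, returning an extended copy of `bindings`.

    Named dimensions already present in `bindings` (for example, bound while
    checking the input) must have the same size here; otherwise ValueError.
    """
    if len(pattern) != len(shape):
        raise ValueError(f"rank mismatch: expected {pattern}, got {shape}")
    new_bindings = dict(bindings)
    for expected, actual in zip(pattern, shape):
        if isinstance(expected, int):
            # Fixed size: must match exactly.
            if expected != actual:
                raise ValueError(f"expected size {expected}, got {actual} in {shape}")
        elif expected in new_bindings:
            # Named dimension seen before: sizes must agree.
            if new_bindings[expected] != actual:
                raise ValueError(
                    f"dimension {expected!r} was {new_bindings[expected]}, now {actual}")
        else:
            # First occurrence of a named dimension: record its size.
            new_bindings[expected] = actual
    return new_bindings


# Usage: the input binds "batch", and the output must reuse the same size.
bindings = match_shape(("batch", 8), (32, 8), {})
match_shape(("batch", 4), (32, 4), bindings)  # OK: batch is still 32
```

Checking `(16, 4)` against the output pattern with these bindings would raise `ValueError`, because `batch` was bound to 32 by the input.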

Variables:
  • input_like (Any) – An input structure, represented as a PyTree of pz.chk.ArraySpec nodes. This defines the type of input this layer expects to receive. Passing anything else will raise an error.

  • sublayers (list[Callable[[Any], Any]]) – A sequence of layers to call in order. These are usually pz.Layer instances, but are allowed to be other types of callable PyTree as well.

  • output_like (Any) – An output structure, represented as a PyTree of pz.chk.ArraySpec nodes. This defines the type of output this layer will produce. Returning anything else will raise an error.

Methods

__init__(input_like, output_like, sublayers)

input_structure()

output_structure()

treescope_color()

__call__(value)

Runs each of the sublayers in sequence.

Attributes

input_like

output_like

sublayers

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

final __call__(value: Any) -> Any

Runs each of the sublayers in sequence.

Parameters:

value – The input to the layer.

Returns:

The output of the final sublayer.
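The calling order can be modeled as follows. This is a hypothetical sketch, not penzai's implementation: `run_checked_sequential`, `check_in`, and `check_out` are illustrative names, and the checks are passed in as plain functions rather than derived from `pz.chk.ArraySpec` structures. It shows the input being validated before any sublayer runs, the output being validated after the last one, and sizes recorded from the input being reused to validate the output.

```python
from typing import Any, Callable, Dict, Sequence


def run_checked_sequential(
    value: Any,
    check_input: Callable[[Any], Dict[str, int]],
    sublayers: Sequence[Callable[[Any], Any]],
    check_output: Callable[[Any, Dict[str, int]], None],
) -> Any:
    """Hypothetical model of a checked sequential call (not penzai's code)."""
    # Validate the input first; the check returns the dimension sizes it saw.
    bindings = check_input(value)
    # Run each sublayer in order, feeding each output into the next layer.
    for layer in sublayers:
        value = layer(value)
    # Validate the final output against the sizes recorded from the input.
    check_output(value, bindings)
    return value


# Usage: tuples stand in for array shapes, and the "layers" map shapes to shapes.
def check_in(shape):
    if len(shape) != 2 or shape[1] != 8:
        raise ValueError(f"expected (batch, 8), got {shape}")
    return {"batch": shape[0]}


def check_out(shape, bindings):
    if shape != (bindings["batch"], 4):
        raise ValueError(f"expected ({bindings['batch']}, 4), got {shape}")


layers = [lambda s: (s[0], 16), lambda s: (s[0], 4)]
print(run_checked_sequential((32, 8), check_in, layers, check_out))  # (32, 4)
```

A sublayer that changed the batch size would make `check_out` raise, even though each sublayer on its own might look valid.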