CheckedSequential
class penzai.nn.grouping.CheckedSequential
Bases: Layer
A group of layers to call sequentially, with known input/output types.
CheckedSequential is a “typed” variant of Sequential, which is annotated with input and output structures. The input and output structures share dimension variables, which can be used to make assertions about the relationship between the shapes of the inputs and the shapes of the outputs.
Variables:
input_like (Any) – An input structure, represented as a PyTree of pz.chk.ArraySpec nodes. This defines the type of input this layer expects to receive. Passing anything else will raise an error.
sublayers (list[penzai.nn.layer.Layer]) – A sequence of layers to call in order. These are usually pz.nn.Layer instances, but they may also be other kinds of callable PyTree.
output_like (Any) – An output structure, represented as a PyTree of pz.chk.ArraySpec nodes. This defines the type of output this layer will produce. Returning anything else will raise an error.
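The shared-variable idea can be illustrated with a minimal pure-Python sketch. It is not penzai's implementation: shapes are plain tuples rather than PyTrees of pz.chk.ArraySpec, string entries stand in for dimension variables, and `match_shape` is a hypothetical helper. A variable gets bound while the input is checked, and the output must then agree with that binding. This is what lets the two structures express a relationship such as "the batch size is preserved".

```python
def match_shape(spec, shape, bindings):
    """Checks `shape` against `spec`; str entries in `spec` are dimension variables.

    Returns an updated copy of `bindings` and raises ValueError on a mismatch.
    A variable that is already bound must keep the same size, which is what ties
    a later spec (the output) to an earlier one (the input).
    """
    if len(spec) != len(shape):
        raise ValueError(f"expected {len(spec)} dimensions, got shape {shape}")
    bindings = dict(bindings)  # copy so a failed check leaves the caller's dict alone
    for expected, actual in zip(spec, shape):
        if isinstance(expected, str):
            bound = bindings.setdefault(expected, actual)
            if bound != actual:
                raise ValueError(f"{expected!r} is {bound} in one spec but {actual} here")
        elif expected != actual:
            raise ValueError(f"expected size {expected}, got {actual}")
    return bindings


# Input: (batch, 16 features). Output: same batch, 4 features.
input_spec = ("batch", 16)
output_spec = ("batch", 4)

bindings = match_shape(input_spec, (8, 16), {})
bindings = match_shape(output_spec, (8, 4), bindings)
print(bindings)  # {'batch': 8}
```

An output of shape (7, 4) would fail the second check, because `batch` is already bound to 8 by the input.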
Methods
__init__(input_like, output_like, sublayers)
treescope_color()
__call__(value, **side_inputs) – Runs each of the sublayers in sequence.
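The call behavior can be sketched as three steps. The input is checked against input_like, the value is fed through each sublayer in order, and the result is checked against output_like using the same variable bindings. This is a hypothetical illustration, not penzai code. `CheckedSequentialSketch`, `bind_dims`, `shape_of`, `transpose`, and `row_sums` are invented names. Values are nested Python lists, and string entries in a spec act as dimension variables. Side inputs are omitted, although the real `__call__` accepts them as `**side_inputs`.

```python
def shape_of(x):
    """Shape of a rectangular nested list, e.g. [[1, 2, 3]] -> (1, 3)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def bind_dims(spec, shape, env):
    """Checks `shape` against `spec`, recording str-named dimensions in `env`."""
    if len(spec) != len(shape):
        raise ValueError(f"expected {len(spec)} dimensions, got shape {shape}")
    for expected, actual in zip(spec, shape):
        if isinstance(expected, str):
            if env.setdefault(expected, actual) != actual:
                raise ValueError(f"{expected!r} is {env[expected]}, got {actual}")
        elif expected != actual:
            raise ValueError(f"expected size {expected}, got {actual}")


class CheckedSequentialSketch:
    """Hypothetical stand-in for the checked sequential call logic."""

    def __init__(self, input_like, output_like, sublayers):
        self.input_like = input_like
        self.output_like = output_like
        self.sublayers = list(sublayers)

    def __call__(self, value):
        env = {}  # dimension bindings shared by the input and output checks
        bind_dims(self.input_like, shape_of(value), env)
        for layer in self.sublayers:  # run each sublayer in sequence
            value = layer(value)
        bind_dims(self.output_like, shape_of(value), env)
        return value


def transpose(m):
    return [list(col) for col in zip(*m)]


def row_sums(m):
    return [[sum(row)] for row in m]


# (n, m) in, (m, 1) out: the output's first axis must equal the input's second.
model = CheckedSequentialSketch(
    input_like=("n", "m"),
    output_like=("m", 1),
    sublayers=[transpose, row_sums],
)
result = model([[1, 2, 3], [4, 5, 6]])  # [[5], [7], [9]]
```

Changing output_like to ("n", 1) makes the same call raise ValueError. The input binds `n` to 2, but the output has 3 rows.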
Attributes
input_like
output_like
sublayers
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.