Layer
- class penzai.deprecated.v1.core.layer.Layer
Abstract base class for neural network layers and other 1-arg callables.
Methods
input_structure(): Returns the input structure of this layer.
output_structure(): Returns the output structure of this layer.
__call__(argument, /): Abstract call method for a layer.
Inherited Methods
__init__()
attributes_dict(): Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values): Directly instantiates a struct given all of its fields.
key_for_field(field_name): Generates a JAX PyTree key for a given field name.
select(): Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten(): Flattens this tree node.
tree_flatten_with_keys(): Flattens this tree node with keys.
tree_unflatten(aux_data, children): Unflattens this tree node.
treescope_color(): Computes a CSS color to display for this object in treescope.
- abstract __call__(argument: Any, /) -> Any
Abstract call method for a layer.
Layers are submodels that take a single input and produce a single output. By convention, almost all model components in a Penzai model are instances of Layer, making it possible to easily compose them with other layers and wrappers. If a layer needs to take multiple input arrays, its input can be a nested data structure.
Most subclasses of Layer are encouraged to decorate __call__ with checked_layer_call, which runs automatic shape checking and adds name scopes to aid debugging.
- Parameters:
argument – An input value, or possibly a nested structure. Should be passed positionally by any caller; caller should not assume this is called “argument” exactly. Subclasses of Layer are free to rename this.
- Returns:
An output value, or possibly a nested structure.
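The single-argument convention described above can be sketched in plain Python. This is only an illustration of the calling convention, not Penzai code: SketchLayer, Scale, AddPair, and Chain are hypothetical names, and no shape checking or PyTree registration is performed.

```python
import abc
import dataclasses
from typing import Any


class SketchLayer(abc.ABC):
    """Hypothetical stand-in for Layer: a callable taking one positional argument."""

    @abc.abstractmethod
    def __call__(self, argument: Any, /) -> Any:
        ...


@dataclasses.dataclass
class Scale(SketchLayer):
    factor: float

    # Subclasses may rename the parameter, so callers must pass it positionally.
    def __call__(self, value: float, /) -> float:
        return value * self.factor


@dataclasses.dataclass
class AddPair(SketchLayer):
    # Two inputs are delivered as one nested structure (here, a tuple).
    def __call__(self, pair: tuple, /) -> float:
        left, right = pair
        return left + right


@dataclasses.dataclass
class Chain(SketchLayer):
    sublayers: list

    # One input and one output per layer makes composition a simple loop.
    def __call__(self, value: Any, /) -> Any:
        for layer in self.sublayers:
            value = layer(value)
        return value


model = Chain([AddPair(), Scale(factor=2.0)])
result = model((1.0, 3.0))
```

Because every component has the same one-in, one-out shape, Chain does not need to know anything about the layers it holds.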
- classmethod __init_subclass__(**kwargs)
Checks that new subclasses of Layer have wrapped __call__ if needed.
- input_structure() -> shapecheck.StructureAnnotation
Returns the input structure of this layer.
The input structure of a layer is a PyTree describing the structure the layer expects to be called with, using the types from penzai.core.shapecheck. In particular, it will usually be a PyTree whose leaves are either shapecheck.ArraySpec nodes or unchecked shapecheck.Wildcard nodes.
Subclasses of Layer that have complex or configuration-dependent logic in __call__ are encouraged to override input_structure. This information will be used in two ways:
It can give more informative error messages to users when they try to call a layer with the wrong input structure.
It will be visible in Treescope when the layer is pretty-printed.
If any ArraySpec contains dimension variables, these dimension variables are shared between input_structure and output_structure. This means that the output structure and input structure must have consistent sizes if they are annotated with consistent variable names.
A general suggestion is that each layer should check only the parts of its input that it needs to make assumptions about in order to do its job well. For instance, low-level operations may want to check the shapes of their inputs, but should use general dtypes (like jnp.floating) unless they require a specific input dtype. Higher-level combinators that contain other layers should check only the parts of their input that they use (e.g. a layer that unpacks a length-3 tuple should have an input_structure that is a length-3 tuple), not the parts that are passed through to their child layers.
Note that, if you override this method, you must decorate __call__ with checked_layer_call to ensure that the input structure is checked (or unchecked_layer_call to opt out).
If the attributes of this layer are set incorrectly, you can raise MisconfiguredLayerError to indicate that the layer cannot be successfully called with any input structure.
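To make the sharing of dimension variables concrete, here is a minimal plain-Python sketch. It does not use penzai.core.shapecheck and is not how Penzai implements its checks; the tuple encoding of specs and the match_shape helper are assumptions made purely for illustration.

```python
from typing import Dict, Tuple, Union

# Hypothetical encoding: an int is a fixed size, a str is a dimension variable.
Dim = Union[int, str]


def match_shape(
    spec: Tuple[Dim, ...], shape: Tuple[int, ...], bindings: Dict[str, int]
) -> None:
    """Checks `shape` against `spec`, recording variable sizes in `bindings`."""
    if len(spec) != len(shape):
        raise ValueError(f"rank mismatch: expected {spec}, got {shape}")
    for dim, size in zip(spec, shape):
        if isinstance(dim, str):
            # The first use binds the variable; every later use must agree.
            if bindings.setdefault(dim, size) != size:
                raise ValueError(f"{dim!r} was {bindings[dim]}, now {size}")
        elif dim != size:
            raise ValueError(f"expected size {dim}, got {size}")


input_spec = ("batch", "features")
output_spec = ("batch", 4)

# One bindings dict is reused for both checks, so "batch" must be consistent.
bindings: Dict[str, int] = {}
match_shape(input_spec, (8, 16), bindings)
match_shape(output_spec, (8, 4), bindings)
```

Reusing a single bindings dictionary across the input and output checks is what ties the two structures together: an output of shape (7, 4) would be rejected here because "batch" was already bound to 8.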
- output_structure() -> shapecheck.StructureAnnotation
Returns the output structure of this layer.
The output structure of a layer is a PyTree describing the structure of the layer’s return value, using the types from penzai.core.shapecheck. In particular, it will usually be a PyTree whose leaves are either shapecheck.ArraySpec nodes or unchecked shapecheck.Wildcard nodes.
Subclasses of Layer that have complex or configuration-dependent logic in __call__ are encouraged to override output_structure. This information will be used in two ways:
It serves as an assertion that the output matches the expectations of the layer, and guards against e.g. accidentally clobbering an axis name.
It will be visible in Treescope when the layer is pretty-printed.
If the attributes of this layer are set incorrectly, you can raise MisconfiguredLayerError to indicate that calling this layer will not succeed at runtime.
See the documentation for input_structure for more details.
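The role of output_structure as an assertion on the return value, and the use of a configuration error, can be sketched without Penzai as follows. Everything here is hypothetical: sketch_checked_call only imitates the idea behind checked_layer_call, SketchMisconfiguredError stands in for MisconfiguredLayerError, and structures are reduced to plain sequence lengths rather than PyTrees of ArraySpec nodes.

```python
import dataclasses
import functools


class SketchMisconfiguredError(Exception):
    """Hypothetical stand-in for MisconfiguredLayerError."""


def sketch_checked_call(fn):
    """Hypothetical wrapper: checks the argument before and the result after."""

    @functools.wraps(fn)
    def wrapper(self, argument, /):
        if len(argument) != self.input_structure():
            raise ValueError("input does not match input_structure")
        result = fn(self, argument)
        if len(result) != self.output_structure():
            raise ValueError("output does not match output_structure")
        return result

    return wrapper


@dataclasses.dataclass
class Repeat:
    length: int
    times: int

    def input_structure(self) -> int:
        # Simplification: the "structure" is just the expected length.
        return self.length

    def output_structure(self) -> int:
        if self.times < 1:
            # The attributes themselves are invalid, so no call can succeed.
            raise SketchMisconfiguredError("times must be at least 1")
        return self.length * self.times

    @sketch_checked_call
    def __call__(self, values, /):
        return list(values) * self.times


ok = Repeat(length=2, times=3)([1, 2])
```

Raising a dedicated error from the structure method separates "this layer is configured wrongly" from "this particular input is wrong", which the wrapper reports as a ValueError.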