Layer

class penzai.core.layer.Layer

Bases: Struct, ABC

Abstract base class for neural network layers and other 1-arg callables.

Methods

input_structure()

Returns the input structure of this layer.

output_structure()

Returns the output structure of this layer.

__call__(argument, /)

Abstract call method for a layer.

Inherited Methods


__init__()

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

abstract __call__(argument: Any, /) → Any

Abstract call method for a layer.

Layers are submodels that take a single input and produce a single output. By convention, almost all model components in a Penzai model are instances of Layer, making it possible to easily compose them with other layers and wrappers. If a layer needs to take multiple input arrays, its input can be a nested data structure.

Most subclasses of Layer are encouraged to decorate __call__ with checked_layer_call, which runs automatic shape checking and adds name scopes to aid debugging.

Parameters:

argument – An input value, or possibly a nested structure. Callers should pass it positionally and should not assume the parameter is named “argument”; subclasses of Layer are free to rename it.

Returns:

An output value, or possibly a nested structure.
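The single-argument contract can be illustrated with a minimal plain-Python sketch. Nothing here comes from Penzai: Layer, Scale, AddPair, and Sequential are hypothetical classes (Penzai's real Layer is a Struct-based pytree dataclass, and its real combinators differ). The sketch shows that a layer needing two inputs takes one tuple, and that a uniform one-in, one-out interface is what makes chaining layers trivial.

```python
import abc
from dataclasses import dataclass
from typing import Any


class Layer(abc.ABC):
  """Hypothetical stand-in for a Penzai-style layer: a 1-argument callable."""

  @abc.abstractmethod
  def __call__(self, argument: Any, /) -> Any:
    """Maps one (possibly nested) input value to one output value."""


@dataclass(frozen=True)
class Scale(Layer):
  """Multiplies its input by a fixed factor."""
  factor: float

  def __call__(self, x, /):  # Subclasses may rename `argument` (here, `x`).
    return x * self.factor


@dataclass(frozen=True)
class AddPair(Layer):
  """Needs two inputs, so it accepts one nested structure: a 2-tuple."""

  def __call__(self, pair, /):
    a, b = pair
    return a + b


@dataclass(frozen=True)
class Sequential(Layer):
  """Chains layers; this works because every layer takes exactly one input."""
  layers: tuple

  def __call__(self, x, /):
    for layer in self.layers:
      x = layer(x)
    return x


model = Sequential((AddPair(), Scale(2.0)))
print(model((1.0, 2.0)))  # 6.0
```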

classmethod __init_subclass__(**kwargs)

Checks that new subclasses of Layer have wrapped __call__ if needed.
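One plausible way to implement such a check uses Python's standard __init_subclass__ hook, as sketched below. This is not Penzai's source. The functions checked_call and unchecked_call are hypothetical stand-ins for checked_layer_call and unchecked_layer_call, and the is_wrapped_call marker attribute is invented for illustration. The rule enforced is the one stated under input_structure: a subclass that overrides a structure method must mark its __call__ as checked, or explicitly opt out.

```python
import abc
import functools
from typing import Any


def checked_call(fn):
  """Hypothetical stand-in for a checking decorator like checked_layer_call."""
  @functools.wraps(fn)
  def wrapper(self, argument, /):
    # A real implementation would validate `argument` against
    # self.input_structure() here, and the result against output_structure().
    return fn(self, argument)
  wrapper.is_wrapped_call = True
  return wrapper


def unchecked_call(fn):
  """Hypothetical opt-out marker: skips checks but records the choice."""
  fn.is_wrapped_call = True
  return fn


class Layer(abc.ABC):
  """Hypothetical base class whose subclass hook enforces the rule."""

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    overrides_structure = (
        "input_structure" in cls.__dict__ or "output_structure" in cls.__dict__
    )
    call = getattr(cls, "__call__", None)
    if overrides_structure and not getattr(call, "is_wrapped_call", False):
      raise TypeError(
          f"{cls.__name__} overrides input_structure/output_structure, so its "
          "__call__ must be decorated with checked_call or unchecked_call."
      )

  @abc.abstractmethod
  def __call__(self, argument: Any, /) -> Any:
    ...

  def input_structure(self):
    return None  # Hypothetical default meaning "not checked".

  def output_structure(self):
    return None


class Identity(Layer):
  """No structure override, so no decorator is required."""

  def __call__(self, x, /):
    return x


class CheckedIdentity(Layer):
  """Overrides input_structure and decorates __call__, so it is accepted."""

  @checked_call
  def __call__(self, x, /):
    return x

  def input_structure(self):
    return "any value"


try:
  class Forgetful(Layer):
    """Overrides input_structure but forgets the decorator, so it is rejected."""

    def __call__(self, x, /):
      return x

    def input_structure(self):
      return "any value"
  forgetful_rejected = False
except TypeError as err:
  forgetful_rejected = True
  print(err)
```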

input_structure() → shapecheck.StructureAnnotation

Returns the input structure of this layer.

The input structure of a layer is a PyTree describing the structure the layer expects to be called with, using the types from penzai.core.shapecheck. In particular, it will usually be a PyTree with leaves that are either shapecheck.ArraySpec nodes or that are unchecked shapecheck.Wildcard nodes.

Subclasses of Layer that have complex or configuration-dependent logic in __call__ are encouraged to override input_structure. This information will be used in two ways:

  • It can give more informative error messages to users when they try to call a layer with the wrong input structure.

  • It will be visible in Treescope when the layer is pretty-printed.
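The following sketch shows how a declared input structure can yield a more informative error than a failure deep inside __call__. ArraySpec, Wildcard, Arr, check_structure, and MaskedPassThrough are hypothetical plain-Python stand-ins, not the penzai.core.shapecheck types, which additionally handle dtypes, named axes, and dimension variables. Arrays are reduced to their shapes so the example runs without JAX.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Arr:
  """Hypothetical array stand-in that records only a shape."""
  shape: tuple


@dataclass(frozen=True)
class ArraySpec:
  """Hypothetical spec leaf: expected shape, where None matches any size."""
  shape: tuple


@dataclass(frozen=True)
class Wildcard:
  """Hypothetical spec leaf that accepts any value unchecked."""
  description: str = "any value"


def check_structure(value: Any, spec: Any, path: str = "argument") -> None:
  """Recursively matches value against spec, naming the first bad position."""
  if isinstance(spec, Wildcard):
    return
  if isinstance(spec, ArraySpec):
    shape = getattr(value, "shape", None)
    if shape is None:
      raise ValueError(f"{path}: expected an array, got {type(value).__name__}")
    if len(shape) != len(spec.shape) or any(
        want is not None and want != got
        for want, got in zip(spec.shape, shape)):
      raise ValueError(f"{path}: expected shape {spec.shape}, got {shape}")
    return
  if isinstance(spec, dict):
    if not isinstance(value, dict) or value.keys() != spec.keys():
      raise ValueError(f"{path}: expected a dict with keys {sorted(spec)}")
    for key, subspec in spec.items():
      check_structure(value[key], subspec, f"{path}[{key!r}]")
    return
  raise TypeError(f"unsupported spec node: {spec!r}")


class MaskedPassThrough:
  """Hypothetical layer: checks 'values', leaves 'mask' unchecked."""

  def input_structure(self):
    return {"values": ArraySpec((None, 4)), "mask": Wildcard("mask")}

  def __call__(self, argument, /):
    check_structure(argument, self.input_structure())
    return argument["values"]  # The real computation is elided.


layer = MaskedPassThrough()
print(layer({"values": Arr((2, 4)), "mask": None}))  # Arr(shape=(2, 4))
try:
  layer({"values": Arr((2, 5)), "mask": None})
except ValueError as err:
  message = str(err)
  print(message)  # argument['values']: expected shape (None, 4), got (2, 5)
```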

If any ArraySpec contains dimension variables, these dimension variables are shared between input_structure and output_structure. This means that a dimension variable used in both structures must refer to the same size in the layer's input and in its output.
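To make the sharing concrete, here is a minimal sketch. It assumes that string entries in a shape spec act as dimension variables and that a single binding table serves both the input check and the output check. The names Arr, match_shape, and Project are hypothetical, and the drop_last_row flag only simulates a bug.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Arr:
  """Hypothetical array stand-in that records only a shape."""
  shape: tuple


def match_shape(shape, spec, bindings, where):
  """Matches a shape against a spec whose str entries are dimension variables.

  `bindings` is shared across calls, so a variable bound while checking the
  input must take the same size when checking the output.
  """
  if len(shape) != len(spec):
    raise ValueError(f"{where}: expected rank {len(spec)}, got {len(shape)}")
  for got, want in zip(shape, spec):
    if isinstance(want, str):
      bound = bindings.setdefault(want, got)
      if bound != got:
        raise ValueError(
            f"{where}: {want!r} has size {got}, but was bound to {bound}")
    elif want != got:
      raise ValueError(f"{where}: expected size {want}, got {got}")


class Project:
  """Hypothetical projection from 4 to 8 features over a 'batch' axis."""

  def __init__(self, drop_last_row=False):
    self.drop_last_row = drop_last_row  # Simulates a bug when True.

  def input_structure(self):
    return ("batch", 4)

  def output_structure(self):
    return ("batch", 8)

  def __call__(self, x, /):
    bindings = {}  # One namespace for input and output variables.
    match_shape(x.shape, self.input_structure(), bindings, "input")
    rows = x.shape[0] - 1 if self.drop_last_row else x.shape[0]
    y = Arr((rows, 8))  # The actual matmul is elided in this sketch.
    match_shape(y.shape, self.output_structure(), bindings, "output")
    return y


print(Project()(Arr((3, 4))).shape)  # (3, 8)
try:
  Project(drop_last_row=True)(Arr((3, 4)))
except ValueError as err:
  message = str(err)
  print(message)  # output: 'batch' has size 2, but was bound to 3
```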

A general suggestion is that each layer should check only the parts of its input that it needs to make assumptions about in order to do its job well. For instance, low-level operations may want to check the shapes of their inputs, but should use general dtypes (like jnp.floating) unless they require a particular input dtype. Higher-level combinators that contain other layers should check only the parts of their input that they use (e.g. a layer that unpacks a length-3 tuple should have an input_structure that is a length-3 tuple), not the parts of their input that are passed through to their child layers.
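As an illustration of checking only what a combinator relies on, the hypothetical ApplyToTriple layer below declares a length-3 tuple of Wildcard leaves. It rejects a tuple of the wrong length but leaves each element for the corresponding child to validate. The Wildcard and check_structure helpers are plain-Python stand-ins rather than Penzai's, and plain functions play the role of child layers.

```python
from typing import Any


class Wildcard:
  """Hypothetical spec leaf: accepts any value without inspecting it."""


def check_structure(value: Any, spec: Any, path: str = "argument") -> None:
  """Checks tuple structure; Wildcard leaves are never inspected."""
  if isinstance(spec, Wildcard):
    return
  if isinstance(spec, tuple):
    if not isinstance(value, tuple) or len(value) != len(spec):
      got = len(value) if isinstance(value, tuple) else type(value).__name__
      raise ValueError(
          f"{path}: expected a tuple of length {len(spec)}, got {got}")
    for i, (v, s) in enumerate(zip(value, spec)):
      check_structure(v, s, f"{path}[{i}]")
    return
  raise TypeError(f"unsupported spec node: {spec!r}")


class ApplyToTriple:
  """Hypothetical combinator: runs child i on element i of a 3-tuple.

  It checks only what it relies on (a length-3 tuple). The elements are
  passed through to children, so they stay Wildcards here and each child
  is responsible for checking its own element.
  """

  def __init__(self, children: tuple):
    self.children = children

  def input_structure(self):
    return (Wildcard(), Wildcard(), Wildcard())

  def __call__(self, triple, /):
    check_structure(triple, self.input_structure())
    return tuple(child(x) for child, x in zip(self.children, triple))


layer = ApplyToTriple((str.upper, len, abs))
print(layer(("a", [1, 2], -3)))  # ('A', 2, 3)
try:
  layer(("a", [1, 2]))
except ValueError as err:
  message = str(err)
  print(message)  # argument: expected a tuple of length 3, got 2
```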

Note that, if you override this method, you must decorate __call__ with checked_layer_call to ensure that the input structure is checked (or unchecked_layer_call to opt out).

If the attributes of this layer are set incorrectly, you can raise MisconfiguredLayerError to indicate that the layer cannot be successfully called with any input structure.
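A minimal sketch of this pattern, assuming a hypothetical Linear layer whose weight_shape and in_features attributes must agree. MisconfiguredLayerError is redefined locally as a stand-in so the example runs without Penzai. Raising from input_structure reports the configuration problem without needing any input at all.

```python
class MisconfiguredLayerError(Exception):
  """Hypothetical local stand-in for Penzai's MisconfiguredLayerError."""


class Linear:
  """Hypothetical layer whose attributes must agree with each other."""

  def __init__(self, weight_shape: tuple, in_features: int):
    self.weight_shape = weight_shape  # (in_features, out_features)
    self.in_features = in_features

  def input_structure(self):
    if self.weight_shape[0] != self.in_features:
      # No input could make this layer succeed, so report the
      # misconfiguration instead of describing an input structure.
      raise MisconfiguredLayerError(
          f"weight expects {self.weight_shape[0]} input features, "
          f"but in_features={self.in_features}")
    return ("batch", self.in_features)


print(Linear((4, 8), in_features=4).input_structure())  # ('batch', 4)
try:
  Linear((4, 8), in_features=5).input_structure()
except MisconfiguredLayerError as err:
  message = str(err)
  print(message)  # weight expects 4 input features, but in_features=5
```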

output_structure() → shapecheck.StructureAnnotation

Returns the output structure of this layer.

The output structure of a layer is a PyTree describing the structure of the layer’s return value, using the types from penzai.core.shapecheck. In particular, it will usually be a PyTree with leaves that are either shapecheck.ArraySpec nodes or that are unchecked shapecheck.Wildcard nodes.

Subclasses of Layer that have complex or configuration-dependent logic in __call__ are encouraged to override output_structure. This information will be used in two ways:

  • It serves as an assertion that the output matches the expectations of the layer, and guards against e.g. accidentally clobbering an axis name.

  • It will be visible in Treescope when the layer is pretty-printed.
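The sketch below illustrates the first point with named axes, assuming arrays are reduced to a dict from axis name to size. Embed, BuggyEmbed, and check_named_shape are hypothetical. BuggyEmbed accidentally writes its new axis under the existing name "seq", and the declared output structure catches the clobbered axis before the result propagates.

```python
def check_named_shape(named_shape: dict, spec: dict, where: str) -> None:
  """Requires exactly the axis names in spec, with matching sizes (None = any)."""
  if named_shape.keys() != spec.keys():
    raise ValueError(
        f"{where}: expected axes {sorted(spec)}, got {sorted(named_shape)}")
  for name, want in spec.items():
    if want is not None and named_shape[name] != want:
      raise ValueError(f"{where}: axis {name!r} should have size {want}")


class Embed:
  """Hypothetical layer: adds an 'embedding' axis, keeping 'batch' and 'seq'."""

  size = 16

  def output_structure(self):
    # 'batch' and 'seq' pass through unchanged; None means "any size".
    return {"batch": None, "seq": None, "embedding": self.size}

  def compute(self, named_shape: dict) -> dict:
    return {**named_shape, "embedding": self.size}

  def __call__(self, named_shape: dict, /) -> dict:
    out = self.compute(named_shape)
    check_named_shape(out, self.output_structure(), "output")
    return out


class BuggyEmbed(Embed):
  """Accidentally writes the new axis under an existing name."""

  def compute(self, named_shape: dict) -> dict:
    return {**named_shape, "seq": self.size}  # Clobbers 'seq'.


print(Embed()({"batch": 2, "seq": 5}))  # {'batch': 2, 'seq': 5, 'embedding': 16}
try:
  BuggyEmbed()({"batch": 2, "seq": 5})
except ValueError as err:
  message = str(err)
  print(message)
```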

If the attributes of this layer are set incorrectly, you can raise MisconfiguredLayerError to indicate that calling this layer will not succeed at runtime.

See the documentation for input_structure for more details.