Layer
- class penzai.nn.layer.Layer
Abstract base class for neural network layers and other 1-arg callables.
Methods
bind_variables(variables[, allow_unused]): Convenience function to bind variables to a layer.
stateless_call(variable_values, argument, /, ...): Calls a layer with temporary variables, without modifying its state.
__call__(argument, /, **side_inputs): Abstract call method for a layer.
Inherited Methods
__init__()
attributes_dict(): Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values): Directly instantiates a struct given all of its fields.
key_for_field(field_name): Generates a JAX PyTree key for a given field name.
select(): Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten(): Flattens this tree node.
tree_flatten_with_keys(): Flattens this tree node with keys.
tree_unflatten(aux_data, children): Unflattens this tree node.
treescope_color(): Computes a CSS color to display for this object in treescope.
- abstract __call__(argument: Any, /, **side_inputs: Any) → Any
Abstract call method for a layer.
Layers are model components that take one main input and optional side inputs, and produce a single output. By convention, almost all model components in a Penzai model are instances of Layer, making it possible to easily compose them with other layers and wrappers. If a layer needs to take multiple input arrays, its input can be a nested data structure.
The primary argument is positional-only and should always be passed positionally; callers should not assume it has a particular name, and subclasses of Layer are free to rename it. Side inputs are always passed as keyword arguments.
- Parameters:
argument – The primary input to the layer. Usually either an array or a nested structure of arrays.
**side_inputs – Arbitrary side context available to the layer. Each should usually be an array, Variable, or structure of arrays or variables. Layers must accept arbitrary side input keyword arguments and should ignore side inputs that they do not use.
- Returns:
An output value, or possibly a nested structure.
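The calling convention above can be sketched in plain Python. This is a hypothetical model of the contract, not penzai code: ToyLayer, AddPair, Scale, MaybeMask, and Chain are illustrative names, and real Penzai layers are pytree dataclasses operating on arrays rather than Python floats.

```python
import abc
from typing import Any


class ToyLayer(abc.ABC):
    """Hypothetical stand-in for the Layer calling convention (not penzai's class)."""

    @abc.abstractmethod
    def __call__(self, argument: Any, /, **side_inputs: Any) -> Any:
        """One positional-only main input plus arbitrary keyword side inputs."""


class AddPair(ToyLayer):
    def __call__(self, pair, /, **side_inputs):
        # Several inputs arrive as one nested structure (here, a tuple).
        a, b = pair
        return a + b


class Scale(ToyLayer):
    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, x, /, **side_inputs):
        # Uses no side inputs, but must still accept (and ignore) them.
        return x * self.factor


class MaybeMask(ToyLayer):
    def __call__(self, x, /, **side_inputs):
        # Reads one optional side input by name and ignores the rest.
        mask = side_inputs.get("mask")
        return x if mask is None else x * mask


class Chain(ToyLayer):
    def __init__(self, *layers: ToyLayer):
        self.layers = layers

    def __call__(self, x, /, **side_inputs):
        # Passes the same side inputs to every sublayer.
        for layer in self.layers:
            x = layer(x, **side_inputs)
        return x


model = Chain(AddPair(), Scale(2.0), MaybeMask())
print(model((1.0, 2.0)))                    # 6.0
print(model((1.0, 2.0), mask=0.0, step=3))  # 0.0
```

Each subclass renames the primary argument (pair, x), which callers cannot observe because the argument is positional-only. Chain can forward every side input to every sublayer only because each sublayer ignores the side inputs it does not use.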
- bind_variables(variables: Iterable[vars_lib.AbstractVariable | vars_lib.AbstractVariableValue], allow_unused: bool = False) → Layer
Convenience function to bind variables to a layer.
layer.bind_variables(variables) is a simple alias for pz.bind_variables(layer, variables).
- Parameters:
variables – The collection of variables (or frozen variable values) to insert.
allow_unused – Whether to ignore variables that do not have any matching slot (in which case they will not be inserted).
- Returns:
A copy of this layer with variables (re-)inserted.
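A rough model of what binding does, assuming (as a simplification) that variables are matched to slots by a string label and that slots appear only as top-level fields of a dataclass layer. Slot, Var, Affine, and this bind_variables function are hypothetical stand-ins, not penzai's implementation, which works over the whole pytree.

```python
import dataclasses
from typing import Any, Iterable


@dataclasses.dataclass(frozen=True)
class Slot:
    """Hypothetical placeholder left where a variable used to be."""
    label: str


@dataclasses.dataclass
class Var:
    """Hypothetical mutable variable, matched to slots by its label."""
    label: str
    value: Any


@dataclasses.dataclass(frozen=True)
class Affine:
    """Hypothetical layer whose fields hold either a Slot or a Var."""
    weight: Any
    bias: Any

    def __call__(self, x, /, **side_inputs):
        return x * self.weight.value + self.bias.value


def bind_variables(layer, variables: Iterable[Var], allow_unused: bool = False):
    """Returns a copy of `layer` with each Slot replaced by the Var of the same label."""
    by_label = {var.label: var for var in variables}
    updates = {}
    for field in dataclasses.fields(layer):
        current = getattr(layer, field.name)
        if isinstance(current, Slot) and current.label in by_label:
            updates[field.name] = by_label[current.label]
    unused = set(by_label) - {var.label for var in updates.values()}
    if unused and not allow_unused:
        raise ValueError(f"No slot found for variables: {sorted(unused)}")
    # dataclasses.replace builds a new object; `layer` itself is left untouched.
    return dataclasses.replace(layer, **updates)


unbound = Affine(weight=Slot("w"), bias=Slot("b"))
w, b, extra = Var("w", 2.0), Var("b", 1.0), Var("extra", 0.0)

bound = bind_variables(unbound, [w, b])
print(bound(3.0))  # 7.0

try:
    bind_variables(unbound, [w, b, extra])
except ValueError as err:
    print(err)  # No slot found for variables: ['extra']

print(bind_variables(unbound, [w, b, extra], allow_unused=True)(3.0))  # 7.0
```

In this sketch the bound layer holds the same Var objects that were passed in, so setting w.value afterwards changes what bound computes, while unbound still contains its Slot placeholders.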
- final stateless_call(variable_values: Iterable[vars_lib.AbstractVariableValue], argument: Any, /, **side_inputs: dict[Any, Any]) → tuple[Any, tuple[vars_lib.AbstractVariableValue, ...]]
Calls a layer with temporary variables, without modifying its state.
This is a convenience method for:
1. freezing any variables currently inside the model, if there are any (so that we don't mutate them unexpectedly),
2. creating temporary unfrozen copies of each variable value passed as an argument,
3. binding the temporary unfrozen copies to the layer (allowing them to be modified while the layer runs),
4. calling the layer,
5. extracting and re-freezing the temporary variables, and
6. returning the result and the updated frozen variables.
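The six steps can be traced in a small pure-Python model. Everything here (Slot, Var, VarValue, Counter, and this stateless_call) is a hypothetical stand-in chosen for illustration; it binds only top-level dataclass fields and does not reproduce penzai's own variable classes or pytree traversal.

```python
import dataclasses
from typing import Any, Iterable


@dataclasses.dataclass(frozen=True)
class Slot:
    """Hypothetical placeholder for an unbound variable."""
    label: str


@dataclasses.dataclass(frozen=True)
class VarValue:
    """Hypothetical frozen (immutable) snapshot of a variable."""
    label: str
    value: Any

    def unfreeze(self) -> "Var":
        return Var(self.label, self.value)


@dataclasses.dataclass
class Var:
    """Hypothetical mutable variable."""
    label: str
    value: Any

    def freeze(self) -> VarValue:
        return VarValue(self.label, self.value)


@dataclasses.dataclass(frozen=True)
class Counter:
    """Hypothetical stateful layer: counts its calls and adds the count to x."""
    count: Any  # a Slot, Var, or VarValue depending on context

    def __call__(self, x, /, **side_inputs):
        self.count.value += 1  # succeeds only if `count` is a mutable Var
        return x + self.count.value


def stateless_call(layer, variable_values: Iterable[VarValue], argument, /, **side_inputs):
    fields = {f.name: getattr(layer, f.name) for f in dataclasses.fields(layer)}
    # Step 1: freeze variables already in the layer so this call cannot mutate them.
    layer = dataclasses.replace(
        layer, **{k: v.freeze() for k, v in fields.items() if isinstance(v, Var)})
    # Step 2: make temporary mutable copies of the given values.
    temps = {v.label: v.unfreeze() for v in variable_values}
    # Step 3: bind the temporaries into slots with matching labels.
    layer = dataclasses.replace(
        layer, **{k: temps[v.label] for k, v in fields.items()
                  if isinstance(v, Slot) and v.label in temps})
    # Step 4: run the layer, which may mutate the temporaries.
    result = layer(argument, **side_inputs)
    # Steps 5 and 6: re-freeze the temporaries and return them with the result.
    return result, tuple(t.freeze() for t in temps.values())


layer = Counter(count=Slot("n"))
state = (VarValue("n", 0),)
out1, state = stateless_call(layer, state, 10)
out2, state = stateless_call(layer, state, 10)
print(out1, out2, state)  # 11 12 (VarValue(label='n', value=2),)
```

Because the layer keeps only a slot, the state lives entirely in the tuple threaded from one call to the next; the layer object itself never changes.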
In combination with variables.unbind_variables, this can be used to call a stateful layer in a functional way. Matching variables are also bound inside the argument and side inputs, which simplifies calling a model with stateful arguments such as random streams.
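To illustrate binding inside the argument, the sketch below places a counter-like "stream" slot inside a dict argument rather than inside the layer. The helper names (Slot, Var, bind_nested, noisy_layer, stateless_call) are hypothetical, variable values are simplified to a plain label-to-value dict, and the counter is only a toy substitute for a real random stream.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Slot:
    """Hypothetical placeholder for an unbound variable."""
    label: str


@dataclasses.dataclass
class Var:
    """Hypothetical mutable variable."""
    label: str
    value: Any


def bind_nested(tree, temps):
    """Returns a copy of `tree` with each Slot replaced by the Var of the same label."""
    if isinstance(tree, Slot):
        return temps.get(tree.label, tree)
    if isinstance(tree, dict):
        return {k: bind_nested(v, temps) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(bind_nested(v, temps) for v in tree)
    return tree


def noisy_layer(argument, /, **side_inputs):
    """Hypothetical layer that draws from a counter-based 'stream' in its argument."""
    stream = argument["stream"]
    stream.value += 1
    return argument["x"] + stream.value


def stateless_call(layer, values: dict, argument, /, **side_inputs):
    # Temporary mutable copies of the values (simplified to a label -> value dict).
    temps = {label: Var(label, value) for label, value in values.items()}
    # Slots are filled in the argument and the side inputs, not only in the layer.
    result = layer(bind_nested(argument, temps), **bind_nested(side_inputs, temps))
    return result, {label: var.value for label, var in temps.items()}


arg = {"x": 100, "stream": Slot("rng")}
state = {"rng": 0}
r1, state = stateless_call(noisy_layer, state, arg)
r2, state = stateless_call(noisy_layer, state, arg)
print(r1, r2, state)  # 101 102 {'rng': 2}
```

Since bind_nested returns new containers, arg keeps its Slot and can be reused for every call, so the only thing that changes between calls is the state value passed in and returned.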
- Parameters:
variable_values – Initial values for each variable that should be mutable while the layer runs. These will be substituted for variable slots in the layer.
argument – The argument to pass to the layer. (May also contain variable slots that will be bound to variables from variable_values.)
**side_inputs – Arbitrary side context available to the layer. (May also contain variable slots that will be bound to variables from variable_values.)
- Returns:
A tuple (result, updated_vars), where result is the result of calling the layer and updated_vars is a tuple of the updated (frozen) variable values.