NamedGroup
final class penzai.nn.grouping.NamedGroup
Bases: Layer
A layer that names an activation or a sequence of layers.
This layer does not do anything interesting on its own, but exists primarily to facilitate manipulation and inspection of a complex network:
- The name will show up in treescope when inspecting the network interactively, giving context for the wrapped layers.
- NamedGroup layers can be selected with pz.select based on their name, using something like (...).at_instances_of(NamedGroup).where(lambda n: n.name == NAME).
- When traced in JAX, NamedGroup layers add their name to the name scope, which will be visible in the TensorBoard profiler and in JAXPRs.
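The selection pattern above runs on penzai's real PyTree structure. As a rough, dependency-free illustration of what such a query does, the sketch below uses a hypothetical `NamedGroupStub` dataclass and hypothetical `iter_instances` / `find_named_groups` helpers (none of these are part of penzai) to walk a nested model and collect every group whose name matches, analogous to `.at_instances_of(NamedGroup).where(lambda n: n.name == NAME)`:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, List


@dataclass
class NamedGroupStub:
    """Hypothetical stand-in for NamedGroup: a name plus a list of sublayers."""
    name: str
    sublayers: List[Any] = field(default_factory=list)


def iter_instances(node: Any, cls: type) -> Iterator[Any]:
    """Yield every instance of `cls` in a nested tree, in pre-order."""
    if isinstance(node, cls):
        yield node
    if isinstance(node, NamedGroupStub):
        children = node.sublayers
    elif isinstance(node, (list, tuple)):
        children = node
    else:
        children = []
    for child in children:
        yield from iter_instances(child, cls)


def find_named_groups(
    tree: Any, predicate: Callable[[NamedGroupStub], bool]
) -> List[NamedGroupStub]:
    """Rough analog of .at_instances_of(NamedGroup).where(predicate)."""
    return [g for g in iter_instances(tree, NamedGroupStub) if predicate(g)]


model = NamedGroupStub("model", [
    NamedGroupStub("block_0", [NamedGroupStub("attention"), NamedGroupStub("mlp")]),
    NamedGroupStub("block_1", [NamedGroupStub("attention"), NamedGroupStub("mlp")]),
])

# Collect both "attention" groups, one from each block.
attention_groups = find_named_groups(model, lambda n: n.name == "attention")
print(len(attention_groups))  # 2
```

Unlike this read-only walk, a pz.select selection can also be used to produce a modified copy of the model, which is what makes named groups convenient targets for functional-style edits.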
You can also omit the sublayers, in which case this serves as a lightweight way to assign a name to an activation (mostly useful in combination with pz.select).

Suggestion for when to use NamedGroup vs. subclassing Sequential: If you have a function that builds a particular collection of sublayers in a reusable way, consider subclassing Sequential and making that function a constructor classmethod. If you just need to group some sublayers together but want to name them for later reference, just use NamedGroup.

You shouldn't subclass NamedGroup; either subclass Sequential or define your own layer.

Variables:
- name (str) – The name for the layer.
- sublayers (Sequence[penzai.nn.layer.Layer]) – A sequence of layers to call in order. These are usually pz.nn.Layer instances, but are allowed to be other types of callable PyTree as well.
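To make the guidance above concrete, here is a plain-Python sketch built on simplified stand-ins rather than penzai's classes: `SequentialStub`, `NamedGroupStub`, `AffineBlock`, and `from_config` are all hypothetical names, and the layers are ordinary functions instead of pz.nn.Layer instances. A reusable builder becomes a Sequential-style subclass with a constructor classmethod, a one-off grouping is simply named, and, in this sketch, a group with an empty sublayer list passes its input through unchanged, so it only marks an activation:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List

Layer = Callable[..., Any]  # hypothetical: any callable taking (value, **side_inputs)


def run_in_order(layers: List[Layer], value: Any, side_inputs: dict) -> Any:
    """Thread `value` through each layer, passing the same side inputs to all."""
    for layer in layers:
        value = layer(value, **side_inputs)
    return value


@dataclass
class SequentialStub:
    """Hypothetical stand-in for Sequential."""
    sublayers: List[Layer] = field(default_factory=list)

    def __call__(self, value: Any, **side_inputs: Any) -> Any:
        return run_in_order(self.sublayers, value, side_inputs)


@dataclass
class NamedGroupStub:
    """Hypothetical stand-in for NamedGroup: same call behavior, plus a name."""
    name: str
    sublayers: List[Layer] = field(default_factory=list)

    def __call__(self, value: Any, **side_inputs: Any) -> Any:
        return run_in_order(self.sublayers, value, side_inputs)


class AffineBlock(SequentialStub):
    """Reusable collection of sublayers: subclass, then build via a classmethod."""

    @classmethod
    def from_config(cls, scale: float, shift: float) -> "AffineBlock":
        return cls(sublayers=[lambda x, **_: x * scale, lambda x, **_: x + shift])


# One-off grouping: no new class, just a name for later reference.
model = NamedGroupStub("preprocess", [
    AffineBlock.from_config(scale=2.0, shift=1.0),
    NamedGroupStub("after_affine", []),  # empty sublayers: only names this activation
    lambda x, **_: x - 3.0,
])

print(model(5.0))  # (5.0 * 2.0 + 1.0) - 3.0 = 8.0
```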
Methods
- __init__(name, sublayers)
- treescope_color()
- __call__(value, **side_inputs): Runs each of the sublayers in sequence.
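The following sketch models `__call__` in plain Python so it runs without JAX. A hypothetical `NamedGroupStub` threads the value through its sublayers in order, forwards the same keyword side inputs to each one, and pushes its name onto a hypothetical `SCOPE_STACK` while they run. The stack stands in for JAX's name stack, which JAX code normally manipulates with `jax.named_scope`; whether penzai uses exactly that call internally is an assumption here. The `offset` side input and the `record` helper are illustrative only:

```python
import contextlib
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, List

# Hypothetical global stack standing in for JAX's name stack.
SCOPE_STACK: List[str] = []


@contextlib.contextmanager
def named_scope(name: str) -> Iterator[None]:
    """Push `name` for the duration of the block (plain-Python analog of jax.named_scope)."""
    SCOPE_STACK.append(name)
    try:
        yield
    finally:
        SCOPE_STACK.pop()


@dataclass
class NamedGroupStub:
    """Hypothetical stand-in for NamedGroup."""
    name: str
    sublayers: List[Callable[..., Any]] = field(default_factory=list)

    def __call__(self, value: Any, **side_inputs: Any) -> Any:
        # Run each sublayer in sequence inside this group's scope,
        # forwarding the same keyword side inputs to every sublayer.
        with named_scope(self.name):
            for layer in self.sublayers:
                value = layer(value, **side_inputs)
        return value


trace: List[str] = []


def record(tag: str) -> Callable[..., Any]:
    """Build a layer that logs the current scope path and adds the `offset` side input."""
    def layer(x: Any, *, offset: float = 0.0, **_: Any) -> Any:
        trace.append("/".join(SCOPE_STACK) + "/" + tag)
        return x + offset
    return layer


model = NamedGroupStub("block", [
    record("first"),
    NamedGroupStub("inner", [record("second")]),
])

result = model(1.0, offset=10.0)
print(result)  # 21.0
print(trace)   # ['block/first', 'block/inner/second']
```

Nested groups produce nested scope paths such as block/inner, which mirrors the nesting you would expect to see when named layers are nested inside one another.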
Attributes
- name
- sublayers
Inherited Methods
- attributes_dict(): Constructs a dictionary with all of the fields in the class.
- bind_variables(variables[, allow_unused]): Convenience function to bind variables to a layer.
- from_attributes(**field_values): Directly instantiates a struct given all of its fields.
- key_for_field(field_name): Generates a JAX PyTree key for a given field name.
- select(): Wraps this struct in a selection, enabling functional-style mutations.
- stateless_call(variable_values, argument, /, ...): Calls a layer with temporary variables, without modifying its state.
- tree_flatten(): Flattens this tree node.
- tree_flatten_with_keys(): Flattens this tree node with keys.
- tree_unflatten(aux_data, children): Unflattens this tree node.
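In JAX's custom pytree protocol, tree_flatten returns a (children, aux_data) pair and tree_unflatten(aux_data, children) rebuilds the node. The plain-Python sketch below mimics that round trip for a hypothetical `NamedGroupStub`. It assumes the name is static auxiliary data and the sublayers are children; that split is an assumption about penzai's field metadata, not something stated on this page. In real JAX code such a class would be registered with `jax.tree_util.register_pytree_node_class` so that `jax.tree_util.tree_map` could do the mapping step shown here by hand:

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class NamedGroupStub:
    """Hypothetical stand-in showing the (children, aux_data) pytree protocol."""
    name: str
    sublayers: List[Any] = field(default_factory=list)

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
        # Assumption: `sublayers` is a pytree child, `name` is static aux data.
        children = (self.sublayers,)
        aux_data = (self.name,)
        return children, aux_data

    @classmethod
    def tree_unflatten(
        cls, aux_data: Tuple[str, ...], children: Tuple[Any, ...]
    ) -> "NamedGroupStub":
        (name,) = aux_data
        (sublayers,) = children
        return cls(name=name, sublayers=sublayers)


group = NamedGroupStub("mlp", [1.0, 2.0])
children, aux = group.tree_flatten()
# A transformation maps over the children and leaves aux_data untouched.
doubled = ([x * 2 for x in children[0]],)
rebuilt = NamedGroupStub.tree_unflatten(aux, doubled)
print(rebuilt)  # NamedGroupStub(name='mlp', sublayers=[2.0, 4.0])
```

Keeping the name out of the children means a transformation over the tree never sees it as a leaf, so the group's label survives the round trip unchanged.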