NamedGroup


final class penzai.nn.grouping.NamedGroup

Bases: Layer

A layer that names an activation or a sequence of layers.

This layer does not do anything interesting on its own, but exists primarily to facilitate manipulation and inspection of a complex network:

  • The name will show up in treescope when inspecting the network interactively, giving context for the wrapped layers.

  • NamedGroup layers can be selected with pz.select based on their name, using something like

    (...).at_instances_of(NamedGroup).where(lambda n: n.name == NAME)
    
  • When traced in JAX, NamedGroup layers add their name to the name scope, which makes it visible in the TensorBoard profiler and in jaxprs.
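The name-based query above can be illustrated without penzai installed. The sketch below uses a hypothetical ToyNamedGroup dataclass (not penzai's class) and a hypothetical find_named helper. The helper walks nested groups and yields the ones whose name matches, which is roughly what the at_instances_of(NamedGroup).where(...) query expresses. It is a simplified stand-in: it only descends through sublayers, lists, and tuples, and it does not try to reproduce pz.select's exact traversal rules.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Sequence


@dataclass(frozen=True)
class ToyNamedGroup:
    """Hypothetical stand-in for NamedGroup (not penzai's implementation)."""
    name: str
    sublayers: Sequence[Callable[[Any], Any]] = ()

    def __call__(self, value: Any) -> Any:
        # Run each sublayer in order, feeding each output into the next.
        for layer in self.sublayers:
            value = layer(value)
        return value


def find_named(tree: Any, name: str) -> Iterator[ToyNamedGroup]:
    """Yields every ToyNamedGroup in `tree` whose name equals `name`."""
    if isinstance(tree, ToyNamedGroup):
        if tree.name == name:
            yield tree
        children = tree.sublayers
    elif isinstance(tree, (list, tuple)):
        children = tree
    else:
        return  # Leaves (plain callables) contain no groups.
    for child in children:
        yield from find_named(child, name)


model = ToyNamedGroup("model", [
    ToyNamedGroup("block_0", [ToyNamedGroup("attn", [lambda x: x + 1])]),
    ToyNamedGroup("block_1", [ToyNamedGroup("attn", [lambda x: x * 2])]),
])

print([g.name for g in find_named(model, "attn")])  # ['attn', 'attn']
print(model(3))  # 8, i.e. (3 + 1) * 2
```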

You can also omit the sublayers, in which case this serves as a lightweight way to assign a name to an activation (mostly useful in combination with pz.select).
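The name-only use can be sketched with a hypothetical ToyNamedGroup stand-in that mirrors the documented behavior: it runs each sublayer in order and returns the last output. With no sublayers, the group returns its input unchanged, so it can sit in a pipeline purely as a marker. In the sketch, a list comprehension swaps the marker for a recording layer. This is only a plain-Python analogue of editing the selected node with pz.select, not penzai code. The helper names run and record are also hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Sequence


@dataclass(frozen=True)
class ToyNamedGroup:
    """Hypothetical stand-in for NamedGroup (not penzai's implementation)."""
    name: str
    sublayers: Sequence[Callable[[Any], Any]] = ()

    def __call__(self, value: Any) -> Any:
        # Run each sublayer in order; with no sublayers, return the input.
        for layer in self.sublayers:
            value = layer(value)
        return value


def double(x: float) -> float:
    return 2 * x


def run(layers: Sequence[Callable[[Any], Any]], value: Any) -> Any:
    # Apply a flat list of layers in order.
    for layer in layers:
        value = layer(value)
    return value


# A name-only marker between two layers; it does not change the activation.
tag = ToyNamedGroup("after_double")
pipeline = [double, tag, double]
print(run(pipeline, 5))  # 20

# Swap the marker for a layer that records the activation at that point.
captured: List[Any] = []


def record(x: Any) -> Any:
    captured.append(x)
    return x


probed = [
    record if isinstance(layer, ToyNamedGroup) and layer.name == "after_double"
    else layer
    for layer in pipeline
]
probed_value = run(probed, 5)
print(captured)  # [10]
```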

Suggestion for when to use NamedGroup vs. subclassing Sequential: If you have a function that builds a particular collection of sublayers in a reusable way, consider subclassing Sequential and making that function a constructor classmethod. If you just need to group some sublayers together and name them for later reference, use NamedGroup.

You shouldn’t subclass NamedGroup (it is declared final); either subclass Sequential or define your own layer.
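The guidance above can be sketched with hypothetical stand-ins, ToySequential and ToyNamedGroup. Neither is penzai's class. AffineBlock is a reusable pattern, so it subclasses ToySequential and exposes its builder as a classmethod named from_config (a hypothetical name). The one-off grouping creates no new class and just gets a name.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class ToySequential:
    """Hypothetical stand-in for Sequential: runs sublayers in order."""
    sublayers: Sequence[Callable[[Any], Any]]

    def __call__(self, value: Any) -> Any:
        for layer in self.sublayers:
            value = layer(value)
        return value


@dataclass(frozen=True)
class ToyNamedGroup:
    """Hypothetical stand-in for NamedGroup: a name plus sublayers."""
    name: str
    sublayers: Sequence[Callable[[Any], Any]] = ()

    def __call__(self, value: Any) -> Any:
        for layer in self.sublayers:
            value = layer(value)
        return value


@dataclass(frozen=True)
class AffineBlock(ToySequential):
    """Reusable pattern: a Sequential subclass with a constructor classmethod."""

    @classmethod
    def from_config(cls, scale: float, shift: float) -> "AffineBlock":
        # Hypothetical builder; every instance it returns shares one type.
        return cls(sublayers=(lambda x: x * scale, lambda x: x + shift))


# Reusable structure: build it through the classmethod.
block = AffineBlock.from_config(scale=3.0, shift=1.0)
print(block(2.0))  # 7.0

# One-off grouping: no new class, just a name for later reference.
group = ToyNamedGroup("preprocess", (abs, round))
print(group(-2.6))  # 3
```

One design consequence: every block built by from_config is an AffineBlock, so it can be located by type (for example with at_instances_of) without relying on a naming convention.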

Variables:
  • name (str) – The name for the layer.

  • sublayers (Sequence[Callable[[Any], Any]]) – A sequence of layers to call in order. These are usually pz.Layer instances, but are allowed to be other types of callable PyTree as well.

Methods

__init__(name, sublayers)

treescope_color()

__call__(value)

Runs each of the sublayers in sequence.

Attributes

name

sublayers

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

__call__(value: Any) → Any

Runs each of the sublayers in sequence.

Parameters:

value – The input to the layer.

Returns:

The output of the final sublayer.
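The documented __call__ behavior amounts to a left fold over the sublayers. The helper run_in_sequence below is hypothetical and is not taken from penzai's source. It makes two things concrete: the order in which the sublayers are called, and the case where there are no sublayers at all.

```python
import functools
from typing import Any, Callable, List, Sequence


def run_in_sequence(sublayers: Sequence[Callable[[Any], Any]], value: Any) -> Any:
    """Hypothetical helper: feeds each sublayer's output into the next one."""
    # Same as: for layer in sublayers: value = layer(value); return value
    return functools.reduce(lambda acc, layer: layer(acc), sublayers, value)


call_order: List[str] = []


def make_layer(tag: str, fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
    def layer(x: Any) -> Any:
        call_order.append(tag)  # Record when this sublayer runs.
        return fn(x)
    return layer


layers = [make_layer("a", lambda x: x + 1), make_layer("b", lambda x: x * 10)]
result = run_in_sequence(layers, 4)
print(result)      # 50
print(call_order)  # ['a', 'b']

# With no sublayers, the input comes back unchanged.
print(run_in_sequence([], "unchanged"))  # unchanged
```

Returning the input for an empty sequence is what lets a group with no sublayers act purely as a naming marker for an activation.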