Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


Open in Colab Open in Kaggle

How to Think in Penzai

Penzai prioritizes legibility, visualization, and easy editing of neural network models. It strives to follow a simple mental model, avoid magic wherever possible, and decompose into modular tools that can be combined without getting in your way. This means that Penzai models are often structured somewhat differently than models in other libraries like PyTorch, Flax, or Keras.

This document explains the key principles behind Penzai’s design, and should teach you all you need to know to start using Penzai.

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
import collections
import dataclasses
import jax
import jax.numpy as jnp
from typing import Any, Callable, Sequence
import penzai
from penzai import pz
from penzai.example_models import simple_mlp

Principles

1. What You See is What You Get

The first central principle of Penzai, which influences almost every aspect of its design, is that everything is visualizable by default, and nothing is hidden from the user.

Penzai includes a powerful interactive IPython pretty-printer with automatic embedded array visualizations (called Treescope), which can be used to look inside any JAX-compatible data structure. You can enable Treescope like this:

# Make Treescope the default IPython pretty-printer for this notebook session.
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
# Turn on automatic array visualization for interactive output, so expanded
# arrays in the printed models below render their shape and data.
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

Penzai goes out of its way to make sure that the pretty-printer representation of a model tells you everything you need to know about it:

  • Every sublayer of the model is directly contained in its parent, and can be viewed by expanding it.

  • Every parameter is an attribute of the layer that owns it.

  • Every model is immutable, and all state is explicitly tracked as a node in the model tree.

  • Any shared parameters are tracked structurally using the data-effect system, not using Python references.

For instance, here’s what a simple MLP looks like in Penzai:

# Configure a small MLP with feature sizes 8 -> 32 -> 32 -> 8, then initialize
# all of its parameters from a fixed PRNG key.
mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([8, 32, 32, 8]),
    jax.random.key(0),
)
# Display the model with Treescope (sublayers can be expanded interactively).
mlp

Try clicking to expand or collapse different sublayers! We’ve turned on automatic array visualization, so if you expand one of the parameters, you can immediately visualize its shape and array data.

Importantly, this isn’t just a pretty visualization of the model, it’s actually a fully-roundtrippable specification of the model structure. You can press r to enable roundtrip mode, and then directly copy and execute the pretty-printed output:

# Roundtrip-mode Treescope output for `mlp`, pasted verbatim. Array data cannot
# be roundtripped, so each array appears as a `NotRoundtrippable` placeholder
# that records the original array's repr, id, and type.
copied = penzai.example_models.simple_mlp.MLP( # Sequential
  sublayers=[
    penzai.nn.linear_and_affine.Affine( # Sequential
      sublayers=[
        penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8, 'features_out': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(8, 32) ≈-0.0019 ±0.22 [≥-0.38, ≤0.38] nonzero:256>', original_id=23148094748192, original_type=jax.Array)), name='Affine_0.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),
        penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),
        penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32,) ≈0.0 ±0.0 [≥0.0, ≤0.0] zero:32>', original_id=23148094743968, original_type=jax.Array)), name='Affine_0.AddBias.bias'), new_axis_names=()),
      ],
    ),
    penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.nn.linear_and_affine.Affine( # Sequential
      sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32, 32) ≈-0.0037 ±0.18 [≥-0.31, ≤0.31] nonzero:1_024>', original_id=23147983056352, original_type=jax.Array)), name='Affine_1.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32,) ≈0.0 ±0.0 [≥0.0, ≤0.0] zero:32>', original_id=23148002433952, original_type=jax.Array)), name='Affine_1.AddBias.bias'), new_axis_names=())],
    ),
    penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.nn.linear_and_affine.Affine( # Sequential
      sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 8}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32, 8) ≈-0.0052 ±0.23 [≥-0.38, ≤0.39] nonzero:256>', original_id=23148002427616, original_type=jax.Array)), name='Affine_2.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)', original_id=23147983059168, original_type=jax.Array)), name='Affine_2.AddBias.bias'), new_axis_names=())],
    ),
  ],
)
# Display the reconstructed model; it should look identical to `mlp`.
copied

And this produces a perfect copy of your model (including everything except the array data):

# Compare pytree structures (everything except the leaf array values) of the
# original model and the pasted copy; per the text above, this is True.
jax.tree_util.tree_structure(mlp) == jax.tree_util.tree_structure(copied)

As a user of a Penzai model, you should never have to worry about hidden state or Python object references. If it doesn’t show up in the pretty-printed output, it’s not part of the model.

2. Everything is Patchable

The second core principle is that Penzai models are designed to be freely modified after they are built, including isolating small parts of larger models, combining models together, or inserting arbitrary logic at arbitrary points in a model’s forward pass.

Penzai includes a structure-rewriting utility, pz.select, which lets you make arbitrary modifications to Penzai models using .at(...).set(...)-style syntax. For instance, you can pull out parameters:

# Find the parameters:
# (Displaying a selection shows the model with the selected nodes marked.)
pz.select(mlp).at_instances_of(pz.nn.Parameter)
# Extract them:
# (`get_sequence` returns the selected Parameter nodes themselves.)
pz.select(mlp).at_instances_of(pz.nn.Parameter).get_sequence()

Or insert new logic:

@pz.pytree_dataclass
class HelloWorld(pz.Layer):
  """Shows its argument with `pz.show`, then returns it unchanged."""

  def __call__(self, arg):
    pz.show("Hello world! My value:", arg)
    # Identity layer: passing the input through leaves the model's output
    # unchanged.
    return arg
# Insert a new layer after each nonlinearity:
patched = (
    pz.select(mlp).at_instances_of(pz.nn.Elementwise).insert_after(HelloWorld())
)
# Show where the inserted HelloWorld layers ended up in the patched model:
pz.select(patched).at_instances_of(HelloWorld)
# Run it:
patched(pz.nx.ones({"features": 8}))

You can even click the “copy” button next to any part of the pretty-printed output to copy a path to that node, allowing you to extract or modify it:

# Copied by clicking above:
path_fn = (lambda root: root.sublayers[2].sublayers[0])
path_fn(mlp)

To make all of this work, Penzai models are designed to be as permissive as possible about their contents after construction. For instance, the MLP class doesn’t specifically require its children to be Affine layers (a.k.a. Dense layers), and doesn’t run the activation functions directly. Instead, it is a subclass of Sequential, and it just runs its sublayers in sequence without caring about their types. This means we are free to insert new logic into an MLP at runtime to customize its behavior, without having to change its original code.

3. Models Are (Just) Callable, JIT-able PyTree Data Structures

Every Penzai model object is a frozen Python dataclass and a JAX Pytree. This means that all of the instance variables of Penzai models are explicitly typed and tracked, and that any Penzai model can be traversed using jax.tree_util.

# Flatten the model with ordinary JAX utilities: returns the list of leaves
# (the parameter arrays) together with the tree structure.
jax.tree_util.tree_flatten(mlp)

Models own their parameters, so you can just call them directly:

# Call the model directly on an all-ones input with a named "features" axis of
# size 8 (the MLP's input size).
mlp(pz.nx.ones({"features": 8}))

Or, just as easily, call one of their sublayers:

# The Linear layer of the second Affine block, which maps 32 features to 32:
mlp.sublayers[2].sublayers[0]
# It can be called on its own with a matching 32-element "features" input.
mlp.sublayers[2].sublayers[0](pz.nx.ones({"features": 32}))

The PyTree leaves of a Penzai model are always the parameters and other NDArray contents. This means they can be passed through JAX transformations like jax.jit directly, without having to use special Penzai-specific wrappers.

# The model is passed to the jitted function as an ordinary argument; its
# parameter arrays are the pytree leaves that JAX traces.
@jax.jit
def my_func(model, arg):
  return model(arg)

my_func(mlp, pz.nx.ones({"features": 8}))

Penzai models work seamlessly with utilities designed for Python dataclasses or JAX Pytrees. In fact, the Treescope pretty-printer will pretty-print any dataclass, and the pz.select rewriting system can modify any JAX Pytree, without requiring special support for Penzai models in particular! Penzai’s tools are designed to put you in control and let you mix-and-match components from different libraries.

Penzai’s approach of callable PyTree dataclasses is heavily inspired by Equinox, so if you’re familiar with Equinox, you should feel right at home with Penzai! The main differences from Equinox:

  • Penzai doesn’t use filtered transformations; instead, parameters are explicitly annotated (discussed later), and all non-array data should be marked as such using dataclasses.field(metadata={"pytree_node": False}).

  • Penzai layers use an explicit @pz.pytree_dataclass decorator, which makes it obvious to readers of your code that this class uses dataclass semantics. pz.pytree_dataclass also lets you customize the dataclass arguments, and catches common footguns of Python dataclasses (such as the ordering of attributes in dataclass inheritance).

  • By convention, Penzai layers do not override __init__. Instead, construction is done using a separate class method (often called MyLayer.from_config(...)). This is to ensure that you can always rebuild the layer from its pretty-printed representation, even if you’ve patched it.

  • Conventions for idiomatic Penzai layers differ somewhat from Equinox modules, as discussed in the remaining principles below.

4. Axes Are Referenced By Name, But Used Positionally

Axis ordering can make it harder to reason about what complex models are doing, especially when trying to visualize or intervene on internal activations, or when using models from an unfamiliar codebase. It’s often easier to refer to axes by name. But you shouldn’t have to learn a whole new array API just to use named axes; the existing Numpy and JAX APIs are pretty good!

Penzai strikes a middle ground using a lightweight locally-positional named-axis system, defined in a single file and with a minimal API surface. In short:

  • The pz.nx.NamedArray class wraps an ordinary array, and assigns each axis to either a position or a name (but not both).

  • You can convert positional axes to named ones using .tag(...), or convert named axes back to positional axes using .untag(...).

  • Any JAX function can be lifted using pz.nx.nmap. The lifted function will act normally over the positional axes but will be automatically vectorized over all of the named axes (using jax.vmap internally). Only NamedArray arguments are processed in this way; other arguments are just passed through.

  • Standard array methods and operators (e.g. .sum(), +, or slicing) are also lifted so that they operate over positional axes and vectorize over named axes.

  • By convention, Penzai layers use axis names to define their interface, but then use .untag, nmap, and .tag to implement their internal logic.

For instance, here’s how you might take a softmax over a vocabulary axis:

# Start with a JAX array:
array = jax.random.normal(jax.random.key(0), [8, 32])
# Wrap it as a named array:
# (At this point both axes are still positional.)
wrapped = pz.nx.wrap(array)
# Assign names:
# (Axis 0 becomes "batch" and axis 1 becomes "vocabulary".)
named = wrapped.tag("batch", "vocabulary")
# Visualize it:
named
# Un-tag the vocabulary axis:
untagged = named.untag("vocabulary")
# Map the ordinary JAX softmax function over the temporary positional axis:
# ("vocabulary" is now the only positional axis, so it is axis 0; the named
# "batch" axis is vectorized over automatically.)
softmaxed = pz.nx.nmap(jax.nn.softmax)(untagged, axis=0)
# Tag the positional axis with a name again:
softmaxed.tag("vocabulary")

And here’s how you might wrap that in an idiomatic layer:

@pz.pytree_dataclass
class Softmax(pz.Layer):
  """Softmax over one named axis, vectorized over all other named axes."""

  # Name of the axis to normalize over. This is static configuration, so it is
  # kept out of the pytree leaves.
  axis_name: str = dataclasses.field(metadata={"pytree_node": False})

  def __call__(self, arg):
    # Write the logic as if the argument is one dimensional:
    arr = arg.untag(self.axis_name)
    # Validate with an explicit exception rather than `assert`, which Python
    # skips entirely when run with -O.
    if len(arr.positional_shape) != 1:
      raise ValueError(
          "Softmax expected exactly one positional axis after untagging "
          f"{self.axis_name!r}, got positional shape {arr.positional_shape}."
      )
    result = pz.nx.nmap(jax.nn.softmax)(arr, axis=0)
    # Then re-bind names at the end:
    return result.tag(self.axis_name)
layer = Softmax("vocabulary")
layer
layer(named)

Because everything vectorizes over names by default, Penzai models can usually be used with arbitrary numbers of batch axes at runtime as long as you give them unique names. You can even insert new layers that manipulate specific batch axes by name (e.g. copying activations from one input to another), without interfering with any of the shapes in the rest of your model.

5. Parameters Are Tagged And Named

In Penzai, every learnable parameter is identified by being an instance of pz.nn.Parameter. Since models are freely patchable at runtime, the location of a parameter may change after it is built, so every parameter must have a unique string name.

This makes it easy to distinguish learnable parameters from frozen parameters or arraylike hyperparameters. For instance, you can get a parameter dictionary using pz.select:

# Map each learnable parameter's unique name to its value:
{
    param.name: param.value
    for param in pz.select(mlp).at_instances_of(pz.nn.Parameter).get_sequence()
}

Then freeze some parameters:

mlp_with_frozen_bias = (
    pz.select(mlp)
    # Narrow the selection to Parameters that live inside AddBias layers...
    .at_instances_of(pz.nn.AddBias)
    .at_instances_of(pz.nn.Parameter)
    # ...and replace each one with a FrozenParameter with the same name/value.
    .apply(lambda x: pz.nn.FrozenParameter(name=x.name, value=x.value))
)
mlp_with_frozen_bias

And get the new learnable parameters from the modified model:

# FrozenParameter nodes are not instances of pz.nn.Parameter, so only the
# remaining learnable (Linear) weights are collected here.
{
    param.name: param.value
    for param in (
        pz.select(mlp_with_frozen_bias)
        .at_instances_of(pz.nn.Parameter)
        .get_sequence()
    )
}

Or substitute values for them:

# Replace every learnable Parameter's value with zeros of the same structure,
# keeping its name; the frozen biases are untouched.
(
    pz.select(mlp_with_frozen_bias)
    .at_instances_of(pz.nn.Parameter)
    .apply(lambda param: pz.nn.Parameter(
        name=param.name,
        value=jax.tree_util.tree_map(jnp.zeros_like, param.value),
    ))
)

You can use this to take gradients with respect to a model’s parameters using jax.grad:

model = mlp_with_frozen_bias

def _collect_learnable_params(tree):
  """Returns a dict from each learnable Parameter's name to its value."""
  learnable = pz.select(tree).at_instances_of(pz.nn.Parameter).get_sequence()
  return {param.name: param.value for param in learnable}

def my_loss(diffble_params):
  """Sum of squared outputs of `model` with `diffble_params` swapped in.

  Args:
    diffble_params: Dict from parameter name to the value to substitute for
      the learnable Parameter of that name.

  Returns:
    Scalar loss: the sum of squares of the model's output on an all-ones
    input with an 8-element "features" axis.
  """
  def _swap_in(param):
    # Rebuild each learnable Parameter around the (possibly traced) value.
    return pz.nn.Parameter(name=param.name, value=diffble_params[param.name])

  model_with_params = (
      pz.select(model).at_instances_of(pz.nn.Parameter).apply(_swap_in)
  )
  output = model_with_params(pz.nx.ones({"features": 8}))
  result = output.untag("features").unwrap()
  return jnp.sum(jnp.square(result))

jax.grad(my_loss)(_collect_learnable_params(model))

There’s also a simple batteries-included training loop in penzai.toolshed.basic_training if you don’t need to do anything fancy!

6. Each Layer Takes One Argument And Does One Thing

Penzai models are built by composing layers, where each layer implements the following interface:

# Illustrative sketch of the interface every Penzai layer implements. It is
# shown for reference; running it standalone would also need `import abc`.
class Layer(pz.Struct, abc.ABC):
  @abc.abstractmethod
  def __call__(self, argument: Any, /) -> Any:
    # Exactly one positional-only argument (the `/`), per the rules below.
    ...

In short:

  • Each layer defines a method __call__, which enables it to be called directly like a function, and which contains all of the layer’s runtime logic.

  • __call__ always takes exactly one argument, which must be passed positionally. (If necessary, this argument can be a tuple, dictionary, pz.Struct, or other JAX Pytree.)

  • Whenever possible, idiomatic Penzai models should not contain Python conditional branches in their __call__. You should be able to JIT-compile the __call__ of any model, and there should generally be only a single control flow path through it.

Penzai uses this convention because it makes it straightforward to compose layers with each other. For instance, there’s an unambiguous way to pass the output of one layer as the input of another layer, since we know the other layer takes a single input.

What about situations where we need to pass extra information to a layer to determine its runtime behavior? The idiomatic way to do this in Penzai depends on the specific type of information:

  • Configuration metadata, such as the input or output axis names for pz.nn.Linear or the activation function for pz.nn.Elementwise, are stored as attributes on the layer, and set when the layer is initially built.

  • Arrays and array-like side inputs, such as attention masks, token positions, or key-value caches, are typically stored as special “effect” attributes and handled by Penzai’s data-effects system (discussed later).

  • Top-level model objects with multiple inputs can define a pz.Struct that contains all of the information they need, and take an instance of that struct as their positional argument.

  • What about configuration arguments used in other libraries, such as “whether or not we should enable dropout” or “whether we are doing scoring or autoregressive decoding”, which change the behavior of the layer in different modes? Trick question! In Penzai, these different modes should usually be represented using different classes. You can then swap out model components using pz.select to switch between different model behaviors.

The emphasis on “doing one thing” also extends to composite layers. In Penzai, composite layers are usually defined as direct compositions of simpler layers, by subclassing the pz.nn.Sequential combinator. Then, their responsibility at runtime is just to call their children in sequence, which means it’s easy to insert new logic without interfering with the model’s computation. We’ve already seen an example of this: the MLP model and Affine blocks in our mlp are both subclasses of pz.nn.Sequential.

More complex combinators also tend to adhere to this pattern. For instance, the core Attention block in Penzai is purely a dataflow combinator, defined as

# Excerpt of Penzai's Attention combinator, quoted from the library source. It
# refers to Penzai-internal module names (`struct`, `layer_base`,
# `named_axes`) that are not imported in this notebook.
@struct.pytree_dataclass
class Attention(layer_base.Layer):
  # Every field is a child layer; Attention itself only routes data between
  # them.
  input_to_query: layer_base.LayerLike
  input_to_key: layer_base.LayerLike
  input_to_value: layer_base.LayerLike
  query_key_to_attn: layer_base.LayerLike
  attn_value_to_output: layer_base.LayerLike

  def __call__(self, x: named_axes.NamedArray) -> named_axes.NamedArray:
    query = self.input_to_query(x)
    key = self.input_to_key(x)
    value = self.input_to_value(x)
    # Multi-input children receive a single tuple argument, following the
    # one-argument layer convention.
    attn = self.query_key_to_attn((query, key))
    output = self.attn_value_to_output((attn, value))
    return output

All of the specific logic of computing positional embeddings, applying attention masks, and computing the softmax weights is left to the child layers, which makes it easy to go in and capture intermediates or intervene on their behaviors at any point, without needing to change the attention implementation. Attention itself just does a single thing: manage the routing of data between the different components, during training or scoring mode.

If you want to do autoregressive decoding, you can swap out Attention blocks for KVCachingAttention blocks using something like

# Sketch only: `kwargs` is a placeholder for whatever keyword arguments
# `KVCachingAttention.from_uncached` requires, and is not defined here.
(
  pz.select(model)
  .at_instances_of(pz.nn.Attention)
  .apply(lambda attn: pz.nn.KVCachingAttention.from_uncached(attn, **kwargs))
)

This produces a copy of your model that additionally manages and updates KV caches, while still supporting arbitrary child layer logic and without changing any of the rest of your model. (See the “Gemma from Scratch” tutorial for more info on autoregressive decoding!)

7. Configuration Happens During Construction (Not __call__)

As discussed above, Penzai layers avoid passing configuration arguments at runtime, and avoid making assumptions about their child layers and parameters as much as possible. However, it’s still important for layers and models to be able to configure themselves and initialize their parameters. In Penzai, all of this happens when the layers are initially constructed.

By convention, Penzai layers configure themselves using a class method, often called from_config(cls, ...) (to avoid overriding __init__). from_config, in turn, takes all of the configuration arguments that are necessary to initialize the model, and uses them to set up their sublayers and parameter initializers. To separate the construction of a model from the initialization of parameters, parameters are initially configured as pz.nn.UninitializedParameter instances.

We can see this by calling the from_config method of simple_mlp.MLP:

# Configure (but do not initialize) an MLP. Its parameters show up as
# UninitializedParameter placeholders until they are initialized.
simple_mlp.MLP.from_config(
    feature_sizes=[8, 32, 32, 8],
    activation_fn=jax.nn.gelu,
)

Notice that the arguments to from_config aren’t actually stored on the MLP itself. Instead, they are simply used to configure and set up the list of sublayers. In general, the configuration arguments of complex models will often “vanish” in this way after the model is initially built.

In fact, all of the custom logic of MLP and Affine is defined in the from_config methods, not __call__. Once initialized, you are free to remove them entirely without affecting the behavior of the model. For instance, this model is equivalent:

# Wrap the MLP in a plain Sequential, then inline its nested groups. Both
# filters accept every node, so (per the text above) the MLP and Affine
# wrappers can be removed without changing the model's behavior.
pz.nn.inline_groups(
    pz.nn.Sequential([
        simple_mlp.MLP.from_config(
            feature_sizes=[8, 32, 32, 8],
            activation_fn=jax.nn.gelu,
        )
    ]),
    parent_filter=lambda _: True,
    child_filter=lambda _: True,
)

This pattern also applies to layers that are designed for hot-swapping. For instance, the KVCachingAttention block defines a classmethod .from_uncached that converts an Attention block into a KVCachingAttention, which takes ownership of the children of that Attention block and then discards the original block.

8. Effects Enable Complex Data Flow

Complex models often need to pass extra input context or random number generators around in addition to the primary activation stream, which isn’t natural to express directly using single-input layers. Additionally, models sometimes need to share parameters between individual layers, or update running state variables.

To support models with this kind of dataflow without interfering with the rest of Penzai’s design conventions, Penzai also includes a simple but powerful “data effect” system, built using pz.select and inspired by effect systems in functional programming. The key features of data effects:

  • Effects are explicitly represented as typed attributes in the model Pytree, and can be copied and manipulated just like ordinary layers. (Again, what you see is what you get, and everything is patchable!)

  • Effects are handled using explicit handlers, which use pz.select to replace the effect attributes with concrete temporary implementations.

  • All effects are handled at the Pytree level. The semantics of effects are determined by their handlers, and layers that don’t use effects can safely ignore them. In fact, the implementation of effects is totally separate from the definition of pz.Layer, and you are free to define your own effects if the existing ones don’t meet your needs.

Here’s an example of using data effects to implement randomness. We can initialize an MLP that uses dropout:

# Configure an MLP variant that includes StochasticDropout layers, then
# initialize its parameters.
dropout_mlp = pz.nn.initialize_parameters(
    simple_mlp.DropoutMLP.from_config(
        feature_sizes=[8, 32, 32, 8],
        drop_rate=0.1,
    ),
    # Use a typed key from `jax.random.key`, consistent with every other cell
    # in this document (the legacy `jax.random.PRNGKey` returns a raw uint32
    # key array instead).
    jax.random.key(0),
)
dropout_mlp

The StochasticDropout layers each contain a RandomRequest, which indicates that they need to receive random numbers in order to run. We can handle these requests using a handler:

# Wrap the model in a handler that replaces each RandomRequest with a
# reference to this handler; the wrapped model takes (input, key) pairs.
handled_dropout_mlp = pz.de.WithRandomKeyFromArg.handling(dropout_mlp)
handled_dropout_mlp

The RandomRequests have now been replaced with HandledRandomRef nodes, which explicitly refer to the ID of the new handler wrapper. We can then call the handled_dropout_mlp object, which will inject a random number generator into the two layers that need it:

# The single argument is an (input, random key) tuple; different keys give
# different dropout masks.
handled_dropout_mlp((pz.nx.ones({"features": 8}), jax.random.key(1)))
handled_dropout_mlp((pz.nx.ones({"features": 8}), jax.random.key(2)))

The StochasticDropout layers each interface with the handler by calling self.rng.next_key(), and the WithRandomKeyFromArg handler is responsible for ensuring that the rng attributes are instantiated with something implementing this method.

Penzai uses the data effects system to implement a variety of more complex features:

  • Side inputs (such as attention masks)

  • Side outputs (such as capturing intermediate activations)

  • Local state (such as key/value caching)

  • Random number generation (such as the dropout layers above)

  • Parameter sharing (using a custom ParameterLike subclass)

The data effect system makes it possible to add these features to any model without needing to manually thread arguments through layers that don’t use them. You can read more about the data effect system in the separate data effects tutorial.

Putting It All Together: A Basic Penzai Neural Network

To show how these principles interact, here’s how we might implement a neural network from scratch in Penzai. We’ll focus on re-implementing a basic MLP (like the running example above), and omit a few advanced features to keep things simple.

An MLP is composed of a sequence of steps, including linear operations, biases, and elementwise activations. We can implement each of these using a separate layer so that we can manipulate them after the model is built, and define each using named axes:

@pz.pytree_dataclass
class SimpleLinear(pz.Layer):
  """Linear map along a single named feature axis, using a learned kernel.

  The input and output share the axis name `features_axis`. The size of that
  axis goes from `in_features` to `out_features`.
  """

  # Typed as `ParameterLike` (not a concrete parameter class) so the kernel can
  # be swapped for another parameter-like object after initialization.
  kernel: pz.nn.ParameterLike[pz.nx.NamedArray]

  # Static string configuration. It is marked as a non-pytree field so JAX
  # does not try to treat it as an array leaf.
  features_axis: str = dataclasses.field(metadata={"pytree_node": False})

  def __call__(self, x: pz.nx.NamedArray, /) -> pz.nx.NamedArray:
    """Returns `kernel @ x` along `features_axis`, vectorized over other axes."""
    # Kernel as a positional (out_features, in_features) matrix.
    weight_matrix = self.kernel.value.untag("out_features", "in_features")
    # Input with only the feature axis made positional.
    feature_vector = x.untag(self.features_axis)
    product = pz.nx.nmap(jnp.dot)(weight_matrix, feature_vector)
    return product.tag(self.features_axis)

  @classmethod
  def from_config(
      cls, in_features: int, out_features: int, features_axis: str = "features",
  ) -> "SimpleLinear":
    """Creates a SimpleLinear layer whose kernel is not yet initialized.

    Args:
      in_features: Size of `features_axis` on the input.
      out_features: Size of `features_axis` on the output.
      features_axis: Name of the axis the layer consumes and produces.

    Returns:
      A layer whose `kernel` is an `UninitializedParameter` named "kernel".
      When initialized later (e.g. by `pz.nn.initialize_parameters`), the
      kernel gets Xavier-normal values of shape (out_features, in_features).
    """
    kernel_shape = (out_features, in_features)

    def _initializer(key):
      raw_kernel = jax.nn.initializers.xavier_normal()(key, kernel_shape)
      return pz.nx.wrap(raw_kernel).tag("out_features", "in_features")

    # Deferring to an UninitializedParameter avoids any device computation
    # until parameter values are actually needed.
    uninitialized_kernel = pz.nn.UninitializedParameter(
        initializer=_initializer, name="kernel"
    )
    return cls(kernel=uninitialized_kernel, features_axis=features_axis)
@pz.pytree_dataclass
class SimpleBias(pz.Layer):
  """Adds a learned bias array to its input."""

  # Only the bias parameter is stored. Its own named axes say which input axis
  # it lines up with, so no separate output axis name needs to be kept.
  bias: pz.nn.ParameterLike[pz.nx.NamedArray]

  def __call__(self, x: pz.nx.NamedArray, /) -> pz.nx.NamedArray:
    """Returns `x` plus the bias; named-axis arithmetic vectorizes the add."""
    bias_value = self.bias.value
    return x + bias_value

  @classmethod
  def from_config(
      cls, features: int, features_axis: str = "features",
  ) -> "SimpleBias":
    """Creates a SimpleBias layer with an uninitialized, zero-valued bias.

    Args:
      features: Size of the bias along `features_axis`.
      features_axis: Name of the axis the bias is laid out along.

    Returns:
      A layer whose `bias` is an `UninitializedParameter` named "bias". Its
      initializer ignores the key and produces zeros.
    """
    def _zeros_initializer(unused_key):
      return pz.nx.zeros({features_axis: features})

    uninitialized_bias = pz.nn.UninitializedParameter(
        initializer=_zeros_initializer,
        name="bias"
    )
    return cls(bias=uninitialized_bias)
@pz.pytree_dataclass
class SimpleElementwise(pz.Layer):
  """Applies `fn` to its input via `pz.nx.nmap` (meant for activations)."""

  # The function is static configuration rather than array data, so it is
  # kept out of the pytree leaves.
  fn: Callable[[jax.Array], jax.Array] = dataclasses.field(
      metadata={"pytree_node": False}
  )

  def __call__(self, x: pz.nx.NamedArray, /) -> pz.nx.NamedArray:
    """Returns `fn` lifted over the named axes of `x` and applied to `x`."""
    lifted_fn = pz.nx.nmap(self.fn)
    return lifted_fn(x)

  # The generated dataclass constructor already takes `fn` directly, so a
  # separate `from_config` would add nothing.

We can then define a top-level MLP layer as a subclass of Sequential:

@pz.pytree_dataclass
class SimpleMLP(pz.nn.Sequential):
  """A multilayer perceptron expressed as a flat sequence of simple layers."""

  # Restated from Sequential purely for readability.
  sublayers: list[pz.LayerLike]

  # No __call__ here: Sequential.__call__ (marked @typing.final) already runs
  # the sublayers in order.

  @classmethod
  def from_config(
      cls,
      feature_sizes: Sequence[int],
      activation: Callable[[jax.Array], jax.Array] = jax.nn.relu,
      features_axis: str = "features",
  ) -> "SimpleMLP":
    """Constructs an MLP whose parameters are still uninitialized.

    Args:
      feature_sizes: Feature-axis sizes from input through output. Block `i`
        maps `feature_sizes[i]` features to `feature_sizes[i + 1]`.
      activation: Function applied between consecutive blocks (not after the
        final block).
      features_axis: Name of the features axis used by every block.

    Returns:
      A SimpleMLP whose sublayers are, for each block, a SimpleLinear followed
      by a SimpleBias (parameter names prefixed with ``block_{i}``), with a
      SimpleElementwise activation between blocks.
    """
    num_blocks = len(feature_sizes) - 1
    layers = []
    for block_idx in range(num_blocks):
      # Parameter names must be unique across the model, and nothing enforces
      # that for us, so each block's parameters get their own prefix.
      prefix = f"block_{block_idx}"
      in_size = feature_sizes[block_idx]
      out_size = feature_sizes[block_idx + 1]
      linear = SimpleLinear.from_config(in_size, out_size, features_axis)
      bias = SimpleBias.from_config(out_size, features_axis)
      layers.append(pz.nn.add_parameter_prefix(prefix, linear))
      layers.append(pz.nn.add_parameter_prefix(prefix, bias))
      is_final_block = block_idx == num_blocks - 1
      if not is_final_block:
        layers.append(SimpleElementwise(fn=activation))
    return cls(sublayers=layers)

Next, we can configure it and print out what we got to make sure it matches what we expect:

# Configure the from-scratch MLP. Its parameters are still
# UninitializedParameter placeholders at this point.
model_def = SimpleMLP.from_config(
    feature_sizes=[8, 32, 32, 8],
    activation=jax.nn.relu,
    features_axis="features",
)
model_def

Then we can initialize the parameters, perhaps under JIT:

# Initialize every UninitializedParameter from a fixed key. The call is wrapped
# in jax.jit, so initialization runs as a compiled computation.
model_at_init = jax.jit(pz.nn.initialize_parameters)(
    model_def, jax.random.key(42)
)
model_at_init