Copyright 2024 The Penzai Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
How to Think in Penzai
Penzai prioritizes legibility, visualization, and easy editing of neural network models. It strives to follow a simple mental model, avoid magic wherever possible, and decompose into modular tools that can be combined without getting in your way. This means that Penzai models are often structured somewhat differently than models written with other libraries.
This document explains the key principles of Penzai’s neural network system.
(Note: The current “V2” neural network system differs from the “V1” neural network system in Penzai’s initial release. For a summary of the differences, see this page.)
```python
try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
```
```python
import collections
import dataclasses
import jax
import jax.numpy as jnp
from typing import Any, Callable, Sequence
import treescope
import penzai
from penzai import pz
from penzai.models import simple_mlp
```
Principles
1. What You See is What You Get
The first central principle of Penzai is that models are designed to be visualizable by default.
Penzai integrates with Treescope, a powerful interactive IPython pretty-printer with automatic embedded array visualizations. (In fact, Treescope was originally designed as the pretty-printer for Penzai!) You can enable Treescope like this:
```python
treescope.basic_interactive_setup(autovisualize_arrays=True)
```
Penzai goes out of its way to make sure that the pretty-printer representation of a model tells you everything you need to know about it:
- Every sublayer of the model is directly contained in its parent, and can be viewed by expanding it.
- Every parameter is an attribute of the layer that owns it, and all attributes are statically known and type-annotated.
- Most model objects are immutable, and all stateful modifications are constrained to explicit “Variable” objects.
For instance, here’s what a simple MLP looks like in Penzai:
```python
mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[8, 32, 32, 8]
)
mlp
```
Try clicking to expand or collapse different sublayers! We’ve turned on automatic array visualization, so if you expand one of the parameters, you can immediately visualize its shape and array data.
Importantly, this isn’t just a pretty visualization of the model; it’s a full specification of the model structure. In fact, if you remove the parameters first, you can copy and paste the pretty-printed output to rebuild the model structure! Every attribute of the model object appears in the pretty-printed output, so if something doesn’t show up there, it’s not part of the model.
(Tip: You can click on a pretty-printed output and press r to add qualified names to every type.)
```python
# Try clicking the output below and pressing `r`!
unbound_mlp, _ = pz.unbind_params(mlp)
unbound_mlp
```
```python
# Copying and pasting this pretty-printed output rebuilds the model:
copied = penzai.models.simple_mlp.MLP( # Sequential
  sublayers=[
    penzai.nn.linear_and_affine.Affine( # Sequential
      sublayers=[
        penzai.nn.linear_and_affine.Linear(weights=penzai.core.variables.ParameterSlot(label='mlp/Affine_0/Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),
        penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),
        penzai.nn.linear_and_affine.AddBias(bias=penzai.core.variables.ParameterSlot(label='mlp/Affine_0/AddBias.bias'), new_axis_names=()),
      ],
    ),
    penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.nn.linear_and_affine.Affine( # Sequential
      sublayers=[
        penzai.nn.linear_and_affine.Linear(weights=penzai.core.variables.ParameterSlot(label='mlp/Affine_1/Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),
        penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),
        penzai.nn.linear_and_affine.AddBias(bias=penzai.core.variables.ParameterSlot(label='mlp/Affine_1/AddBias.bias'), new_axis_names=()),
      ],
    ),
    penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.nn.linear_and_affine.Affine( # Sequential
      sublayers=[
        penzai.nn.linear_and_affine.Linear(weights=penzai.core.variables.ParameterSlot(label='mlp/Affine_2/Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),
        penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),
        penzai.nn.linear_and_affine.AddBias(bias=penzai.core.variables.ParameterSlot(label='mlp/Affine_2/AddBias.bias'), new_axis_names=()),
      ],
    ),
  ],
)
copied == unbound_mlp
```
2. Models Are Callable, Patchable Data Structures
To make it easier to inspect and modify models, Penzai prioritizes treating models as user-modifiable data structures, rather than as opaque objects. Every Penzai model object is a frozen Python dataclass, which means that all of the instance variables of Penzai models are explicitly type-annotated and tracked.
Models can be called with an input argument in order to run the model forward pass:
```python
mlp(pz.nx.ones({"features": 8}))
```
You can also just as easily call one of their sublayers:
```python
mlp.sublayers[2].sublayers[0]
```
```python
mlp.sublayers[2].sublayers[0](pz.nx.ones({"features": 32}))
```
Beyond calling models, you can also change a model's forward pass simply by modifying its data structure.
Penzai models are designed to be freely modified after they are built, including isolating small parts of larger models, combining models together, or inserting arbitrary logic at arbitrary points in a model’s forward pass. And Penzai includes a structure-rewriting utility, pz.select, which lets you make arbitrary modifications to Penzai models using .at(...).set(...)-style syntax. For instance, you can find and remove particular layers:
```python
# Find bias layers
pz.select(mlp).at_instances_of(pz.nn.AddBias)
```
```python
# Remove them:
pz.select(mlp).at_instances_of(pz.nn.AddBias).remove_from_parent()
```
Or insert new layers to run new logic:
```python
@pz.pytree_dataclass
class HelloWorld(pz.nn.Layer):

  def __call__(self, arg, **side_inputs):
    pz.show("Hello world! My value:", arg)
    return arg
```
```python
# Insert a new layer after each nonlinearity:
patched = (
    pz.select(mlp).at_instances_of(pz.nn.Elementwise).insert_after(HelloWorld())
)
pz.select(patched).at_instances_of(HelloWorld)
```
```python
# Run it:
patched(pz.nx.ones({"features": 8}))
```
Penzai models are registered as JAX Pytree nodes (similar to Equinox), so any Penzai model can be traversed using jax.tree_util. In fact, pz.select is a general-purpose utility for modifying any JAX Pytree! Modifications to Penzai models always produce a modified copy of the model instead of being stored as global state. For instance, the model patched above is a modified copy of mlp that behaves differently when it is run.
Penzai models are also designed to be as permissive as possible about their contents after construction. For instance, the MLP class doesn’t require its children to be Affine layers (a.k.a. Dense layers), and doesn’t run the activation functions directly. Instead, it is a subclass of Sequential and just runs its sublayers in sequence without caring about their types. This means we are free to insert new logic into an MLP after it is built to customize its behavior, without having to change its original code.
3. Parameters And State Are Tracked With Explicit Variable Nodes
As Pytree nodes, Penzai model objects are immutable, which simplifies working with JAX and allows you to safely make copies of your model that behave in different ways. However, models often require keeping track of mutable state:
- Parameters are often updated by gradient descent, and shared parameters need to stay in sync.
- Some model configurations, like key-value caching in Transformers, require keeping track of per-layer model states while the model runs.
- It can be useful to save intermediate model activations so that you can inspect or modify them later.
Penzai supports this using “variable” nodes, which are explicit “pockets of mutable state” inside Penzai models. Each Penzai model tree has two types of leaves:
- JAX arrays and scalars, which are immutable and often represent hyperparameters, and
- Variable objects, which can be modified, and come in two variants:
  - Parameters, which are usually modified by optimizers (not by the model), and
  - StateVariables, which are usually updated as the model runs.
For instance, the leaves of the mlp above are its parameters, each of which is an instance of Parameter:
```python
jax.tree_util.tree_leaves(mlp)
```
The same parameter can appear multiple times in a single model. As an example, here’s a model that repeats the same layer multiple times, along with a scaling factor:
```python
layer = pz.nn.Affine.from_config(
    name="shared_layer",
    init_base_rng=jax.random.key(0),
    input_axes={"features": 8},
    output_axes={"features": 8},
)
my_model_with_repeats = pz.nn.Sequential([
    layer,
    pz.nn.Elementwise(jax.nn.relu),
    pz.nn.ConstantRescale(0.5),
    layer,
])
my_model_with_repeats
```
In this case, the Pytree leaves of this model will include each parameter twice, along with the rescaling hyperparameter:
```python
jax.tree_util.tree_leaves(my_model_with_repeats)
```
To extract and deduplicate the parameters, you can use the helper function pz.unbind_params. This produces:
- A copy of the model with each Parameter replaced with a ParameterSlot placeholder, and
- A tuple of all unique parameters in the model.
```python
unbound_model, params = pz.unbind_params(my_model_with_repeats)
pz.show("unbound_model:", unbound_model)
pz.show("params:", params)
```
These parameters can then be substituted back into the model using pz.bind_variables.
You can implement stateful layers using a similar mechanism, but with StateVariable instead of Parameter. Here’s a layer that stores its intermediate activation into a list:
```python
@pz.pytree_dataclass
class SaveIntermediate(pz.nn.Layer):
  saved: pz.StateVariable[list[Any]]

  def __call__(self, x: Any, **unused_side_inputs) -> Any:
    self.saved.value = self.saved.value + [x]
    return x
```
We can insert two copies of it into our MLP, and then call it to retrieve the values:
```python
var = pz.StateVariable(value=[], label="my_intermediate")
saving_model = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(SaveIntermediate(var))
)
saving_model
```
```python
saving_model(pz.nx.ones({"features": 8}))
```
```python
var
```
You can similarly unbind state variables using pz.unbind_state_vars:
```python
unbound_saving_mlp, all_vars = pz.unbind_state_vars(saving_model)
pz.show("unbound_saving_mlp:", unbound_saving_mlp)
pz.show("all_vars:", all_vars)
```
Or unbind both parameters and states using pz.unbind_variables:
```python
unbound_saving_mlp, all_vars = pz.unbind_variables(saving_model)
pz.show("unbound_saving_mlp:", unbound_saving_mlp)
pz.show("all_vars:", all_vars)
```
To make it easier to manipulate variables with JAX, any variable can be “frozen” using the .freeze method or the pz.freeze_variables function:
```python
pz.freeze_variables(all_vars)
```
Frozen variables are JAX PyTrees, and can be safely passed through JAX transformations. Penzai models also support a “pure” interface that lets you pass frozen variables in and get new frozen variables out:
```python
# Freeze parameters:
frozen_param_model = pz.freeze_params(saving_model)
# Unbind and freeze state vars:
unbound_frozen_model, state_vars = pz.unbind_state_vars(
    frozen_param_model, freeze=True
)
state_var_values = pz.freeze_state_vars(state_vars)
# Call it in "pure" style, tracking modifications to the intermediates variable.
# The input and output variables are frozen, but the variable can be locally
# modified while the model runs:
output, updated_var_states = unbound_frozen_model.stateless_call(
    [pz.StateVariableValue(label='my_intermediate', value=[])],
    pz.nx.ones({"features": 8})
)
pz.show("output:", output)
pz.show("updated_var_states:", updated_var_states)
```
You may need to freeze variables in order to pass them through JAX’s function transformations. (For jit, Penzai includes a wrapped version called pz.variable_jit that handles this for you.)
4. Each Layer Has The Same Signature And Does A Single Thing
Penzai models are built by composing layers, where each layer implements the following interface:
```python
import abc
from typing import Any

from penzai import pz


class Layer(pz.Struct, abc.ABC):

  @abc.abstractmethod
  def __call__(self, argument: Any, /, **side_inputs) -> Any:
    ...
```
In short:
- Each layer defines a method __call__, which enables it to be called directly like a function, and which contains all of the layer’s runtime logic.
- __call__ always takes exactly one positional argument, which is its input from the previous layer. (If necessary, this argument can be a tuple, dictionary, pz.Struct, or other JAX Pytree.)
- __call__ also takes an arbitrary number of keyword arguments, which are side inputs. Side inputs can be used for information like attention masks or random number generators, and are usually shared across every layer in the model. Layers should ignore side inputs that they do not recognize.
- Whenever possible, idiomatic Penzai models should not contain Python conditional branches in their __call__. You should be able to JIT-compile the __call__ of any model, and there should generally be only a single control flow path through it.
Penzai uses this convention because it makes it straightforward to compose layers with each other. For instance, there’s an unambiguous way to run two layers in order: pass the output of the first layer as the positional input of the second, and pass the same side inputs to both layers.
This means that, instead of passing configuration data as arguments to the forward pass of each layer, most configuration is directly attached to the layer itself:
- Configuration metadata, such as the input or output axis names for pz.nn.Linear, the activation function for pz.nn.Elementwise, or the name of a dynamic side input, is stored as attributes on the layer and set when the layer is initially built.
- Different “modes” of computation, such as “whether or not we should enable dropout” or “whether we are doing scoring or autoregressive decoding”, are usually represented as different classes. This keeps the number of configuration attributes small and the implementation of each layer simple. You can then swap out model components using pz.select to switch between different model behaviors.
The emphasis on “doing one thing” also extends to composite layers. In Penzai, composite layers are usually defined as direct compositions of simpler layers, by subclassing the pz.nn.Sequential combinator. Then, their responsibility at runtime is just to call their children in sequence, which means it’s easy to insert new logic without interfering with the model’s computation. We’ve already seen an example of this: the MLP model and Affine blocks in our mlp are both subclasses of pz.nn.Sequential.
More complex combinators also tend to adhere to this pattern. For instance, the core Attention block in Penzai is purely a dataflow combinator, defined as
```python
# Excerpt from Penzai's source: `struct`, `Layer`, and `NamedArray` refer to
# Penzai internals (pz.pytree_dataclass, pz.nn.Layer, and pz.nx.NamedArray).
@struct.pytree_dataclass
class Attention(Layer):
  input_to_query: Layer
  input_to_key: Layer
  input_to_value: Layer
  query_key_to_attn: Layer
  attn_value_to_output: Layer

  def __call__(self, x: NamedArray, **side_inputs) -> NamedArray:
    query = self.input_to_query(x, **side_inputs)
    key = self.input_to_key(x, **side_inputs)
    value = self.input_to_value(x, **side_inputs)
    attn = self.query_key_to_attn((query, key), **side_inputs)
    output = self.attn_value_to_output((attn, value), **side_inputs)
    return output
```
All of the specific logic for computing positional embeddings, applying attention masks, and computing the softmax weights is left to the child layers, which makes it easy to capture intermediates or intervene on their behavior at any point without changing the attention implementation. Attention itself does a single thing: it manages the routing of data between the different components during training or scoring mode.
If you want to do autoregressive decoding, you can swap out Attention blocks for KVCachingAttention blocks using something like
```python
(
    pz.select(model)
    .at_instances_of(pz.nn.Attention)
    .apply(lambda attn: pz.nn.KVCachingAttention.from_uncached(attn, **kwargs))
)
```
This produces a copy of your model that additionally manages and updates KV caches, while still supporting arbitrary child layer logic and without changing any of the rest of your model.
5. Configuration Happens During Construction (Not __call__)
As discussed above, Penzai layers avoid passing configuration arguments at runtime, and avoid making assumptions about their child layers and parameters as much as possible. However, it’s still important for layers and models to be able to configure themselves and initialize their parameters. In Penzai, all of this happens when the layers are initially constructed.
By convention, Penzai layers configure themselves using a class method, often called from_config(cls, name: str, init_base_rng, ...). from_config takes all of the configuration arguments that are necessary to initialize the model and uses them to set up the layer’s sublayers and parameters. Penzai layers usually do NOT override __init__, so it’s easy to bypass the initialization logic and rebuild models with different attributes.
We can see this by calling the from_config method of simple_mlp.MLP:
```python
mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.PRNGKey(1),
    feature_sizes=[8, 32, 32, 8],
    activation_fn=jax.nn.gelu,
)
mlp
```
Notice that the arguments to from_config aren’t actually stored on the MLP itself. Instead, they are simply used to configure and set up the list of sublayers. In general, the configuration arguments of complex models will often “vanish” in this way after the model is initially built.
In fact, all of the custom logic of MLP and Affine is defined in their from_config methods, not in __call__. Once the model is built, you are free to remove these classes entirely without affecting its behavior. For instance, we can replace the MLP class with a basic Sequential and get the same behavior:
```python
pz.nn.inline_groups(
    pz.nn.Sequential([mlp]),
    parent_filter=lambda _: True,
    child_filter=lambda _: True,
)
```
This pattern also applies to layers that are designed for hot-swapping. For instance, the KVCachingAttention block defines a classmethod .from_uncached that converts an Attention block into a KVCachingAttention, which takes ownership of the children of that Attention block and then discards the original block.
In general, it may be useful to think of a Penzai model as a “declarative” list of steps in the model’s forward pass. If different configurations run different steps, they are usually represented as models with different structures.
By convention, layer builders like from_config follow the signature
```python
def from_config(cls, name: str, init_base_rng: jax.Array | None, ...):
  ...
```
The name argument is used to ensure that all parameters have unique names, and the init_base_rng determines how to initialize the parameters:
- If init_base_rng is a JAX PRNGKey, it is combined with the name argument to initialize the parameter randomly. The resulting model will contain a Parameter variable for each parameter.
- If init_base_rng is None, parameter initialization is skipped, and the resulting model will instead contain a ParameterSlot placeholder for each parameter. This can be useful for loading pretrained models from checkpoints instead of initializing them from scratch.
To make this work:
- Layers that contain other sublayers should give them unique names by adding a suffix to their own name, e.g. passing name=f"{name}/Linear_0" to their child. The init_base_rng should be forwarded to sublayers unchanged.
- Layers that directly initialize parameters should use the helper function pz.nn.make_parameter, which implements the above logic and ensures parameters with different names are initialized differently, even with the same init_base_rng.
6. Layers Use Named Axes (Via Lifted Positional Operations)
Axis ordering can make it harder to reason about what complex models are doing, especially when trying to visualize or intervene on internal activations, or when using models from an unfamiliar codebase. It’s often easier to refer to axes by name. But you shouldn’t have to learn a whole new array API just to use named axes; the existing NumPy and JAX APIs are pretty good!
Penzai strikes a middle ground using a lightweight locally-positional named-axis system, defined in a single file and with a minimal API surface. In short:
- The pz.nx.NamedArray class wraps an ordinary array, and assigns each axis to either a position or a name (but not both).
- You can convert positional axes to named ones using .tag(...), or convert named axes back to positional axes using .untag(...).
- Any JAX function can be lifted using pz.nx.nmap. The lifted function will act normally over the positional axes but will be automatically vectorized over all of the named axes (using jax.vmap internally). Only NamedArray arguments are processed in this way; other arguments are just passed through.
- Standard array methods and operators (e.g. .sum(), +, or slicing) are also lifted so that they operate over positional axes and vectorize over named axes.
- By convention, Penzai layers use axis names to define their interface, but then use .untag, nmap, and .tag to implement their internal logic.
For instance, here’s how you might take a softmax over a vocabulary axis:
```python
# Start with a JAX array:
array = jax.random.normal(jax.random.key(0), [8, 32])
# Wrap it as a named array:
wrapped = pz.nx.wrap(array)
# Assign names:
named = wrapped.tag("batch", "vocabulary")
# Visualize it:
named
```
```python
# Un-tag the vocabulary axis:
untagged = named.untag("vocabulary")
# Map the ordinary JAX softmax function over the temporary positional axis:
softmaxed = pz.nx.nmap(jax.nn.softmax)(untagged, axis=0)
# Tag the positional axis with a name again:
softmaxed.tag("vocabulary")
```
And here’s how you might wrap that in an idiomatic layer, which has a named-axis interface:
```python
@pz.pytree_dataclass
class Softmax(pz.nn.Layer):
  axis_name: str = dataclasses.field(metadata={"pytree_node": False})

  def __call__(self, arg, **unused_side_inputs):
    # Write the logic as if the argument is one dimensional:
    arr = arg.untag(self.axis_name)
    assert len(arr.positional_shape) == 1
    result = pz.nx.nmap(jax.nn.softmax)(arr, axis=0)
    # Then re-bind names at the end:
    return result.tag(self.axis_name)
```
```python
layer = Softmax("vocabulary")
layer
```
```python
layer(named)
```
Because everything vectorizes over names by default, Penzai models can usually be used with arbitrary numbers of batch axes at runtime as long as you give them unique names. You can even insert new layers that manipulate specific batch axes by name (e.g. copying activations from one input to another), without interfering with any of the shapes in the rest of your model.
Putting It All Together: A Basic Penzai Neural Network
To show how these principles interact, here’s how we might implement a neural network from scratch in Penzai. We’ll focus on re-implementing a basic MLP (like the running example above), and omit a few advanced features to keep things simple.
An MLP is composed of a sequence of steps, including linear operations, biases, and elementwise activations. We can implement each of these using a separate layer so that we can manipulate them after the model is built, and define each using named axes:
```python
@pz.pytree_dataclass
class SimpleLinear(pz.nn.Layer):
  """A simple linear layer with a single input/output axis."""
  # Parameters are annotated as `ParameterLike` to allow swapping them out after
  # initialization.
  kernel: pz.nn.ParameterLike[pz.nx.NamedArray]
  # Non-Pytree fields (which are not arraylike) should be annotated as such to
  # tell JAX not to try to convert them:
  features_axis: str = dataclasses.field(metadata={"pytree_node": False})

  def __call__(
      self, x: pz.nx.NamedArray, /, **unused_side_inputs
  ) -> pz.nx.NamedArray:
    """Multiplies the input by the learned kernel."""
    # pos_x has one positional axis
    pos_x = x.untag(self.features_axis)
    # pos_kernel has two positional axes
    pos_kernel = self.kernel.value.untag("out_features", "in_features")
    # We can combine them using ordinary positional semantics:
    pos_y = pz.nx.nmap(jnp.dot)(pos_kernel, pos_x)
    return pos_y.tag(self.features_axis)

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      in_features: int,
      out_features: int,
      features_axis: str = "features",
  ) -> "SimpleLinear":
    """Constructs a linear layer from configuration arguments."""
    def _initializer(key):
      arr = jax.nn.initializers.xavier_normal()(
          key, (out_features, in_features)
      )
      return pz.nx.wrap(arr).tag("out_features", "in_features")

    return cls(
        kernel=pz.nn.make_parameter(
            name=f"{name}.kernel",
            init_base_rng=init_base_rng,
            initializer=_initializer,
        ),
        features_axis=features_axis,
    )


@pz.pytree_dataclass
class SimpleBias(pz.nn.Layer):
  """A simple bias layer."""
  # The SimpleBias layer doesn't need to store its output axis name at all!
  bias: pz.nn.ParameterLike[pz.nx.NamedArray]

  def __call__(self, x: pz.nx.NamedArray, /, **unused_side_inputs) -> pz.nx.NamedArray:
    """Adds a bias to the input."""
    return x + self.bias.value  # Automatically vectorized!

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      features: int,
      features_axis: str = "features",
  ) -> "SimpleBias":
    """Constructs a bias layer from configuration arguments."""
    return cls(
        bias=pz.nn.make_parameter(
            name=f"{name}.bias",
            init_base_rng=init_base_rng,
            initializer=lambda _: pz.nx.zeros({features_axis: features}),
        ),
    )


@pz.pytree_dataclass
class SimpleElementwise(pz.nn.Layer):
  """A simple elementwise layer."""
  fn: Callable[[jax.Array], jax.Array] = dataclasses.field(
      metadata={"pytree_node": False}
  )

  def __call__(self, x: pz.nx.NamedArray, /, **unused_side_inputs) -> pz.nx.NamedArray:
    """Runs the activation function."""
    return pz.nx.nmap(self.fn)(x)

  # No need for `from_config`, since it would be the same as `__init__`.
```
We can then define a top-level MLP layer as a subclass of Sequential:
```python
@pz.pytree_dataclass
class SimpleMLP(pz.nn.Sequential):
  # sublayers is inherited from Sequential, but we restate it here for clarity.
  sublayers: list[pz.nn.Layer]

  # __call__ is inherited from Sequential, so no need to reimplement it! In
  # fact, Sequential.__call__ is marked with @typing.final so you don't
  # accidentally override it.

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      feature_sizes: Sequence[int],
      activation: Callable[[jax.Array], jax.Array] = jax.nn.relu,
      features_axis: str = "features",
  ) -> "SimpleMLP":
    """Constructs an MLP, initializing parameters if `init_base_rng` is given."""
    # We build the steps of the forward pass in from_config, and push all
    # configuration arguments down to the sublayers:
    sublayers = []
    for i in range(len(feature_sizes) - 1):
      # We need to ensure parameter name uniqueness ourselves:
      sublayers.append(SimpleLinear.from_config(
          name=f"{name}/block_{i}/linear",
          init_base_rng=init_base_rng,
          in_features=feature_sizes[i],
          out_features=feature_sizes[i + 1],
          features_axis=features_axis,
      ))
      sublayers.append(SimpleBias.from_config(
          name=f"{name}/block_{i}/bias",
          init_base_rng=init_base_rng,
          features=feature_sizes[i + 1],
          features_axis=features_axis,
      ))
      if i < len(feature_sizes) - 2:
        sublayers.append(SimpleElementwise(activation))
    return cls(sublayers)
```
Building our model without an initialization PRNGKey just builds the structure:
```python
SimpleMLP.from_config(
    name="mlp",
    init_base_rng=None,
    feature_sizes=[8, 32, 32, 8],
    activation=jax.nn.relu,
    features_axis="features",
)
```
If we pass init_base_rng, it will also initialize the parameters as mutable Variable objects:
```python
model = SimpleMLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(42),
    feature_sizes=[8, 32, 32, 8],
    activation=jax.nn.relu,
    features_axis="features",
)
model
```
We can call it with some example inputs to check that it works:
```python
model(pz.nx.ones({"features": 8}))
```
Or set up a simple training loop:
```python
from penzai.toolshed import basic_training
import optax

example_inputs = pz.nx.wrap(
    jax.random.normal(jax.random.key(100), (100, 8))
).tag("batch", "features")
example_targets = pz.nx.wrap(
    jax.random.normal(jax.random.key(101), (100, 8))
).tag("batch", "features")

def loss_fn(model, rng, state, current_input, current_target):
  del rng, state  # More complex training loops could use these if needed
  model_out = model(current_input)
  losses = pz.nx.nmap(jnp.square)(model_out - current_target)
  loss = losses.untag("batch", "features").unwrap().sum()
  return (loss, None, {"my_loss": loss})

trainer = basic_training.StatefulTrainer.build(
    root_rng=jax.random.key(42),
    model=model,
    optimizer_def=optax.adam(0.01),
    loss_fn=loss_fn,
)
```
```python
outputs = []
while trainer.state.value.step < 1000:
  out = trainer.step(
      current_input=example_inputs,
      current_target=example_targets,
  )
  if trainer.state.value.step % 20 == 0:
    print(f"At {trainer.state.value.step}: {out}")
```
At 20: {'my_loss': Array(562.4356, dtype=float32)}
At 40: {'my_loss': Array(354.94263, dtype=float32)}
At 60: {'my_loss': Array(208.88644, dtype=float32)}
At 80: {'my_loss': Array(116.63308, dtype=float32)}
At 100: {'my_loss': Array(69.82226, dtype=float32)}
At 120: {'my_loss': Array(48.01413, dtype=float32)}
At 140: {'my_loss': Array(36.271587, dtype=float32)}
At 160: {'my_loss': Array(27.9783, dtype=float32)}
At 180: {'my_loss': Array(22.618397, dtype=float32)}
At 200: {'my_loss': Array(19.75925, dtype=float32)}
At 220: {'my_loss': Array(15.713261, dtype=float32)}
At 240: {'my_loss': Array(14.278907, dtype=float32)}
At 260: {'my_loss': Array(12.619974, dtype=float32)}
At 280: {'my_loss': Array(11.935994, dtype=float32)}
At 300: {'my_loss': Array(8.933923, dtype=float32)}
At 320: {'my_loss': Array(7.9672227, dtype=float32)}
At 340: {'my_loss': Array(6.886015, dtype=float32)}
At 360: {'my_loss': Array(7.233006, dtype=float32)}
At 380: {'my_loss': Array(5.6690035, dtype=float32)}
At 400: {'my_loss': Array(6.6670713, dtype=float32)}
At 420: {'my_loss': Array(5.707902, dtype=float32)}
At 440: {'my_loss': Array(5.3238034, dtype=float32)}
At 460: {'my_loss': Array(5.2755117, dtype=float32)}
At 480: {'my_loss': Array(3.6372209, dtype=float32)}
At 500: {'my_loss': Array(4.060416, dtype=float32)}
At 520: {'my_loss': Array(3.7824037, dtype=float32)}
At 540: {'my_loss': Array(3.8496475, dtype=float32)}
At 560: {'my_loss': Array(3.2768314, dtype=float32)}
At 580: {'my_loss': Array(3.402127, dtype=float32)}
At 600: {'my_loss': Array(1.9185027, dtype=float32)}
At 620: {'my_loss': Array(2.0824428, dtype=float32)}
At 640: {'my_loss': Array(2.2999728, dtype=float32)}
At 660: {'my_loss': Array(3.2461543, dtype=float32)}
At 680: {'my_loss': Array(2.2447917, dtype=float32)}
At 700: {'my_loss': Array(2.2864742, dtype=float32)}
At 720: {'my_loss': Array(1.6309336, dtype=float32)}
At 740: {'my_loss': Array(3.654787, dtype=float32)}
At 760: {'my_loss': Array(1.7311825, dtype=float32)}
At 780: {'my_loss': Array(2.9659948, dtype=float32)}
At 800: {'my_loss': Array(4.086928, dtype=float32)}
At 820: {'my_loss': Array(2.963339, dtype=float32)}
At 840: {'my_loss': Array(1.3671762, dtype=float32)}
At 860: {'my_loss': Array(1.7845328, dtype=float32)}
At 880: {'my_loss': Array(1.1663316, dtype=float32)}
At 900: {'my_loss': Array(2.9988499, dtype=float32)}
At 920: {'my_loss': Array(2.2414417, dtype=float32)}
At 940: {'my_loss': Array(2.2406294, dtype=float32)}
At 960: {'my_loss': Array(0.8716761, dtype=float32)}
At 980: {'my_loss': Array(1.9427958, dtype=float32)}
At 1000: {'my_loss': Array(1.9259623, dtype=float32)}
Summary
You now know everything you need to get started with neural networks in Penzai!
Penzai strives to enable complex modifications and interventions on models either before or after training them, without getting in your way. Following the principles described here is a recommended starting point and a great way to take advantage of all of Penzai’s tooling, but it’s not strictly enforced! You’re free to use Penzai’s visualization and patching tools with non-Penzai models, or define your own callable PyTree components without conforming to the pz.nn.Layer interface, if that makes more sense for your use case.