Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Gemma From Scratch: Inspecting and Re-implementing Gemma with Penzai#

Penzai includes a number of general-purpose tools for analyzing JAX neural networks. It also includes a declarative neural-network library designed to take advantage of those tools. This notebook demonstrates how to apply this tooling to a real-world neural network: the Gemma pretrained transformer model, implemented in Flax.

You might benefit from reading this notebook if any of these apply:

  • You are interested in learning about Penzai’s design principles, especially if you are already familiar with Flax.

  • You want to reverse-engineer a model that is currently written in Flax using Penzai’s tools.

  • You want to implement a model in Penzai, and would like to learn about the best practices for model development.

  • You want to learn more about the Gemma implementation included in Penzai, which is used by the other tutorial notebooks.

This notebook is broken into three main sections plus a setup section.

  • In Section 0, we set up the environment and load the Gemma weights for further analysis.

  • In Section 1, we show how to apply Penzai’s analysis and visualization tooling to the official Flax implementation of Gemma, and how to convert it into a Penzai-compatible form. We also discuss the high-level differences between Flax and Penzai.

  • In Section 2, we break down Gemma’s training/scoring mode into its constituent pieces, and show how to re-implement each of those pieces in idiomatic Penzai style. We also discuss how Penzai enables easy inspection of intermediate values, and use intermediate values to test the implementation of each piece.

  • In Section 3, we apply the same decomposition to the stateful key-value-caching mode, and demonstrate how stateful operations are represented in Penzai models using Penzai’s “data effects” system. We use this to implement a simple JIT-compiled sampler that still supports patching intermediate activations.

Section 0: Setup#

In this section, we’ll start by setting up our environment and loading the Gemma model.

Setting up the environment#

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install penzai and the Gemma reference implementation using the following commands:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
try:
  import gemma
except ImportError:
  !pip install "gemma @ git+https://www.github.com/google-deepmind/gemma.git"
from __future__ import annotations
from typing import Any

import os
import dataclasses
import traceback
import functools
import gc

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import chex

import flax.linen
from jax.experimental import mesh_utils
import sentencepiece as spm
import gemma.params
import gemma.sampler
import gemma.transformer
import penzai
from penzai import pz
import penzai.toolshed.unflaxify
import penzai.toolshed.isolate_submodel

Loading Gemma#

Next we can load the Gemma model, using its official Flax reference implementation. We’ll use Gemma 2B for this notebook.

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don’t have an API key already, you can:

  1. Visit https://www.kaggle.com/ and create an account if needed.

  2. Go to your account settings, then the ‘API’ section.

  3. Click ‘Create new token’ to download your key.

Next, if you are running this notebook in Google Colab:

  1. Click the “key” symbol on the left toolbar to open the “Secrets” tab.

  2. Add two new secrets, named “KAGGLE_USERNAME” and “KAGGLE_KEY”, and set their values based on the API key you downloaded.

  3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

import kagglehub
try:
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  kagglehub.login()

If everything went well, you should see:

Kaggle credentials set.

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven’t done that yet, you can do so here:

https://www.kaggle.com/models/google/gemma/license/consent

(Make sure you choose to “Verify via Kaggle Account” with the same account you used to log in above!)

Once you’ve agreed to the terms, you can run the next cell to download the Gemma weights:

weights_dir = kagglehub.model_download('google/gemma/Flax/2b')
ckpt_path = os.path.join(weights_dir, '2b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using orbax.

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
structure = checkpointer.metadata(ckpt_path)

sharding = jax.sharding.SingleDeviceSharding(jax.local_devices()[0])
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array, sharding=sharding
    ),
    structure,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
params = gemma.params.nest_params(
    gemma.params.param_remapper(flat_params)
)
del flat_params
flax_gemma_config = gemma.transformer.TransformerConfig.from_params(
    params, cache_size=1024
)
flax_gemma = gemma.transformer.Transformer(flax_gemma_config)

Section 1: Analyzing the Flax model with Penzai#

In this section, we’ll give an overview of Penzai’s analysis and visualization tooling and demonstrate how to apply it to existing Flax models.

Looking at the model and its weights#

Penzai ships with a powerful pretty-printer and array visualizer designed to help you quickly navigate through and understand the structure of large trees. If you’ve used Colab or Jupyter before, you may be familiar with printouts that look like this:

print(repr(params))

Penzai provides an interactive alternative, treescope, which allows you to fold and unfold the children of deep trees like this one. Try clicking on the gray triangle markers to expand or collapse subtrees!

pz.show(params)

Let’s turn on penzai.treescope as the default Colab/IPython pretty-printer:

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()

Now everything we return from a Colab cell will be interactively pretty-printed:

params["transformer"]["layer_0"]

penzai.treescope also includes an n-dimensional array visualizer, which can help you understand the shape and content of arrays at a glance. You can hover or click on the cells of the visualization to inspect individual array elements.

Note that, with the truncate=True argument, treescope automatically cuts out the middle elements of each array to keep the visualization a reasonable size. This is similar to how printing a large array puts ... in the middle of the printout.

pz.ts.render_array(params["transformer"]["layer_0"]["attn"]["q_einsum"]["w"], truncate=True)

Since we ran register_autovisualize_magic above, we can also automatically visualize arrays whenever we return something from a Colab cell using the %%autovisualize magic command. Treescope will automatically insert these visualizations inside the rendered tree itself and let you expand them as desired. Try clicking the triangles to look at different weights!

%%autovisualize
params["transformer"]["layer_0"]

We can also use treescope to print out the model itself. However, Flax models don’t show much about themselves when you construct them. In this case, the transformer model is a Python dataclass whose attributes are just its configuration:

flax_gemma

(You get something similar if you print it out without Treescope:)

print(flax_gemma)

Unfortunately, you can’t directly see the individual layers inside the Gemma model here, because in Flax those layers aren’t actually built until you call apply on the model and bind it to parameters. But as we will see later, we can use Penzai to get a better look at the internals of the Gemma model.

Looking at model inputs and outputs#

Let’s run the model on some example text. We’ll start by tokenizing it:

example_input = "Penzai includes a number of general-purpose tools for analyzing JAX neural networks. It also includes a declarative neural-network library designed to take advantage of those tools. This notebook demonstrates how to apply this tooling to a real-world neural network: the Gemma pretrained transformer model."
print(example_input)
tokens = jnp.array([vocab.bos_id()] + vocab.EncodeAsIds(example_input))
tokens

We can apply treescope’s array visualizer to tokens too! Discrete data is shown using colors, with stripes used for numbers with a lot of digits. (Fun fact: since sentencepiece tokenizers tend to give lower IDs to more common tokens, more common tokens tend to have simpler-looking visualizations.)

pz.ts.render_array(tokens)

In fact, we can even pass the tokenizer to the autovisualizer, in which case hovering or clicking on array elements will tell you what token each ID is for. Try it below!

%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
tokens

Now let’s call the model on it (adapting the logic from this notebook):

def get_attention_mask_and_positions(
    example: jax.Array, pad_id: int
) -> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = gemma.transformer.build_positions_from_mask(pad_mask)
  attention_mask = gemma.transformer.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask
%%autovisualize
positions, attention_mask = get_attention_mask_and_positions(tokens[None, :], vocab.pad_id())

flax_gemma_output, new_vars = flax_gemma.apply(
    {'params': params['transformer']},
    tokens[None, :],
    positions,
    None, # Attention cache is None.
    attention_mask,
)
assert new_vars is None
flax_gemma_output

This visualization shows up mostly in red, because most of Gemma’s output logits are negative. We can map the logits to a probability distribution using softmax:

%%autovisualize
jax.nn.softmax(flax_gemma_output)

Let’s find the most-likely prediction at each position, and compare it to the actual tokens:

%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
predictions = jnp.argmax(flax_gemma_output, axis=-1)
jnp.stack([predictions[0, :-1], tokens[1:]])
list(zip([vocab.IdToPiece(tok) for tok in predictions[0].tolist()], [vocab.IdToPiece(tok) for tok in tokens[1:].tolist()]))

Looking inside the Flax model with Flax utilities#

How can we figure out what is happening inside Gemma? Of course, we can look at the code to see how it’s implemented, but what if we want to see intermediate activations, or inspect how data flows between Flax modules?

Flax does include a few utilities for this, which are described in the Flax guides and don’t require using Penzai. One option is to use “tabulate” to list out all of the submodule calls:

print(flax_gemma.tabulate(
    jax.random.key(42),
    tokens[None, :],
    positions,
    None, # Attention cache is None.
    attention_mask,
    console_kwargs={"width": 120}
))

Another option is to use capture_intermediates to return intermediate activations:

flax_gemma.apply(
    {'params': params['transformer']},
    tokens[None, :],
    positions,
    None, # Attention cache is None.
    attention_mask,
    capture_intermediates=True,
)

Flax also includes an advanced “intercept_methods” utility which allows you to intercept module calls and apply custom logic.
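
For reference, here is a minimal sketch of what using intercept_methods looks like, assuming the interceptor signature described in the Flax guides (an interceptor receives the wrapped method, its arguments, and a context describing the module being called). It simply logs each module method call before delegating to the original method:

def logging_interceptor(next_fun, args, kwargs, context):
  # `context` describes the intercepted call, including the module instance
  # and the name of the method being invoked.
  print(f"Calling {type(context.module).__name__}.{context.method_name}")
  return next_fun(*args, **kwargs)

with flax.linen.intercept_methods(logging_interceptor):
  flax_gemma.apply(
      {'params': params['transformer']},
      tokens[None, :],
      positions,
      None,  # Attention cache is None.
      attention_mask,
  )

This prints one line per intercepted module method call, which already gives a rough picture of the model’s call tree.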

Flax models vs. Penzai models#

What if we want to do more complex operations, like looking at the inputs passed to the submodules, changing the output of submodules, or extracting and running submodules individually? This is possible using flax.linen.intercept_methods, but it can be somewhat difficult to reason about. An alternative is to convert the Flax model to a Penzai model, and then use Penzai’s tree-rewriting tools to visualize things. To this end, Penzai provides a utility unflaxify which recursively intercepts every Flax method call and encapsulates it into an equivalent Penzai layer.

Before we show how this works, let’s briefly pause to discuss the differences between Penzai models and Flax models, and the overall differences between the Flax and Penzai conventions and design ideas.

From the “Flax philosophy” documentation, Flax aims to “offer an API familiar to those experienced with Keras/Sonnet/PyTorch” with “an implicit variable management API to save the user from having to manually thread thousands of variables through a complex tree of functions.” To this end, Flax modules are defined as if they own stateful variables and parameters, which they can modify imperatively, and Flax runs logic under-the-hood to transform this stateful view into a functional computation that works with JAX. Flax modules generally look something like this:

Initializer = jax.nn.initializers.Initializer

class SimpleFlaxDense(flax.linen.Module):
  features: int
  kernel_init: Initializer = flax.linen.initializers.lecun_normal()
  bias_init: Initializer = flax.linen.initializers.zeros_init()

  @flax.linen.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # Shape info.
    y = jnp.dot(inputs, kernel)
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

class SimpleFlaxMLP(flax.linen.Module):
  out_dims: int

  @flax.linen.compact
  def __call__(self, x):
    x = SimpleFlaxDense(128)(x)
    x = flax.linen.relu(x)
    x = SimpleFlaxDense(self.out_dims)(x)
    return x
SimpleFlaxMLP(out_dims=32)
flax_mlp_params = SimpleFlaxMLP(out_dims=32).init(jax.random.key(10), jnp.ones((4, 8)))
flax_mlp_params
SimpleFlaxMLP(out_dims=32).apply(flax_mlp_params, jnp.ones((4, 8)))

This approach makes Flax a great choice for quickly writing neural networks in JAX, especially if you are already familiar with stateful neural network libraries and object-oriented programming.

However, since this representation was designed for writing model architectures, it’s not necessarily the best choice for analyzing or patching those models. Indeed, any such analysis must be designed to work around the transformation from stateful object-oriented modules to functional JAX-compatible method calls.

Penzai, on the other hand, prioritizes analysis, visualization, and patchability. One of the primary design goals for Penzai’s neural net library penzai.nn is to be a declarative system where “what you see is what you get”: you should be able to immediately see what your model is going to do when you call it, and you should be able to “reach in” and change what it does. This leads to a number of concrete differences:

  • Parameters:

    • In Flax, you define parameters by calling self.param, and you can define other mutable variables using self.variable. These are implicitly inserted into the module’s parameter and variable dictionaries and retrieved when the module is functionalized.

    • In Penzai, parameters are simply stored as attributes on the layer that owns them. You can walk the tree of layers to extract the parameters and put them in a dictionary if you want, and Penzai makes it easy to do this, but Penzai itself doesn’t require it. This also means you can just look at the parameters by looking at your model object.

      • If you’re familiar with Equinox, Penzai layers are PyTrees in the same way that Equinox models are, so you can just pass them through JAX transformations without issues. But one difference from Equinox is that all parameters are explicitly marked with a Parameter class, and the best practice is to filter for Parameter instances instead of implicitly assuming all float-dtype arrays are parameters.

  • Submodules:

    • In Flax, when you instantiate a new module inside another module’s setup or @compact method, that module is implicitly attached to the containing module, and given its own sub-dictionary of parameters. Each module instance remembers which parameters belong to it in an internal scope attribute and looks them up as needed.

    • In Penzai, sublayers are simply stored as attributes on the layer that owns them. Each layer already owns its own parameters, so there’s no need to do a functionalization step.

  • Module construction and submodule configuration:

    • Flax explicitly supports a compact configuration style, where submodules are implicitly configured at the same place where they are called, directly in Python code. This makes it easy to write the model, but somewhat difficult to change a small part of an existing model, since the submodules may not even exist until they are called.

    • In Penzai, there is a strict separation between configuring a model and calling it, as in Equinox and PyTorch. In fact, Penzai goes even further than other frameworks and tries to be as permissive as possible about sublayers after models are configured. Most control flow in Penzai models is explicitly represented in the model’s structure using general-purpose combinators like pz.nn.Sequential or pz.nn.Residual, rather than being implemented as code. This makes it slightly more verbose to write and configure the model, but makes it much easier to visualize and patch it afterward.

  • Mutable state and random numbers:

    • Flax tries to provide a familiar object-oriented interface to stateful operations when the model runs, including mutable variables as part of the core abstraction. Every module has built-in support for variables and states, with a fixed API. Because of this, every module has to know about the variables and states of its submodules, and every top-level module has to be transformed using .apply in order to be called in a functional way.

    • Penzai intentionally avoids baking mutable variables or random numbers into the core system, and in fact doesn’t have any equivalent of apply; Penzai models are purely-functional by default. You can directly call methods on your model without doing any wrapping. (This is again a shared feature with Equinox.)

      • Penzai does allow you to opt in to mutable state or stateful random number generation, however, using a “data-effects” system heavily inspired by effect systems in functional programming languages. This system works by rewriting your model tree directly: effects are stored as attributes of your model, and effect handlers inject new values for those attributes using functional type-dependent tree traversals. This means you can freely modify how effects are interpreted, e.g. by “freezing” mutable states to specific values, or intercepting state updates.

      • Penzai tries very hard to avoid having any hidden or implicit state when your model runs. All mutable states and other side effects are accessed through your model’s explicitly-declared attributes, and you can always see them by printing out your model tree in treescope.

      • Penzai also tries very hard to avoid changing Python semantics, and doesn’t secretly override class construction, modify dataclass attributes, or wrap your instance methods (although in some cases it asks you to explicitly wrap them yourself). You shouldn’t have to learn a new dialect of Python to understand Penzai code.

A direct Penzai equivalent of the Flax dense layer and MLP defined above might look like this:

@pz.pytree_dataclass
class SimplePzDense(pz.Layer):
  kernel: pz.nn.ParameterLike[jax.Array]
  bias: pz.nn.ParameterLike[jax.Array]

  @pz.checked_layer_call
  def __call__(self, inputs):
    """Calls the dense layer."""
    y = jnp.dot(inputs, self.kernel.value)
    y = y + self.bias.value
    return y

  @classmethod
  def from_config(
      cls,
      in_features: int,  # <- Requires passing explicit input feature dimension
      out_features: int,
      kernel_init: Initializer = flax.linen.initializers.lecun_normal(),
      bias_init: Initializer = flax.linen.initializers.zeros_init(),
  ) -> SimplePzDense:
    """Builds the dense layer from configuration."""
    kernel = pz.nn.UninitializedParameter(
        initializer=lambda key: kernel_init(key, (in_features, out_features,)),
        name="kernel"
    )
    bias = pz.nn.UninitializedParameter(
        initializer=lambda key: bias_init(key, (out_features,)),
        name="bias"
    )
    return cls(kernel=kernel, bias=bias)

  # Optional shape-checking methods:
  def input_structure(self):
    in_features, _ = self.kernel.value_structure.shape
    return pz.chk.ArraySpec(
        (*pz.chk.var("batch"), in_features), dtype=jnp.floating)
  def output_structure(self):
    _, out_features = self.kernel.value_structure.shape
    return pz.chk.ArraySpec(
        (*pz.chk.var("batch"), out_features), dtype=jnp.floating)


@pz.pytree_dataclass
class SimplePzMLP(pz.nn.Sequential):
  sublayers: list[pz.LayerLike]

  # __call__ is inherited from pz.nn.Sequential and just runs the children in
  # sequence.

  @classmethod
  def from_config(cls, in_dims: int, out_dims: int) -> SimplePzMLP:
    # Penzai doesn't automatically make parameter names unique; you are in
    # charge of naming your parameters.
    return cls(sublayers=[
        pz.nn.add_parameter_prefix(
            "SimplePzDense_0", SimplePzDense.from_config(in_dims, 128)
        ),
        pz.nn.Elementwise(jax.nn.relu),
        pz.nn.add_parameter_prefix(
            "SimplePzDense_1", SimplePzDense.from_config(128, out_dims)
        ),
    ])
pz_mlp_def = SimplePzMLP.from_config(in_dims=8, out_dims=32)

Printing out this model shows you its structure, even before we initialize the parameters:

pz_mlp_def

We can initialize parameters by finding all of the UninitializedParameters and calling their initializers:

with pz.RandomStream(jax.random.key(42)) as stream:
  pz_mlp = (
      pz.select(pz_mlp_def)
      .at_instances_of(pz.nn.UninitializedParameter)
      .apply(lambda param: param.initialize(stream.next_key()))
  )
pz_mlp
pz_mlp(jnp.ones((4, 8)))

But Penzai provides utilities to do most common tasks for you:

pz_mlp = pz.nn.initialize_parameters(pz_mlp_def, jax.random.key(42))
pz_mlp

And Penzai provides powerful tools for traversing this structure to do arbitrary transformations. For instance, you can pull out a parameter dictionary:

param_dict = {
    param.name: param
    for param in pz.select(pz_mlp).at_instances_of(pz.nn.Parameter).get_sequence()
}
param_dict

Or substitute initialized parameters back into an uninitialized model:

pz.select(pz_mlp_def).at_instances_of(pz.nn.UninitializedParameter).apply(lambda param: param_dict[param.name])

Or even insert new logic to do arbitrary things while the model runs:

@pz.pytree_dataclass
class HelloWorld(pz.Layer):
  def __call__(self, inputs):
    pz.show("Hello world! My intermediate value is:", inputs)
    return inputs

hello_world_mlp = (
    pz.select(pz_mlp).at_instances_of(pz.nn.Elementwise)
    .insert_after(HelloWorld())
)
hello_world_mlp
hello_world_mlp(jnp.ones((4, 8)))

Note that none of these operations actually modify the original pz_mlp_def! Instead, each of these operations makes a copy with those modifications applied. This means you can safely make complex destructive modifications to your model, because you’re only modifying a copy. (If you’re familiar with JAX’s at[...].set(...) notation, this is basically the same idea extended to full model trees.)
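
As a quick sanity check (a small sketch, not part of the original walkthrough), we can confirm that pz_mlp_def itself still contains only uninitialized parameters after all of the manipulations above:

# The original definition is untouched: it still holds UninitializedParameter
# placeholders (a kernel and a bias for each of the two dense layers).
len(pz.select(pz_mlp_def).at_instances_of(pz.nn.UninitializedParameter).get_sequence())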

As a sidenote, this isn’t quite an idiomatic Penzai model yet. A more idiomatic version would split up SimplePzDense into separate Linear and AddBias sublayers, and use Penzai’s named axis system so that the meaning of the different axis indices is more obvious. Here’s an equivalent but more idiomatic example MLP from penzai.example_models:

%%autovisualize
import penzai.example_models.simple_mlp
idiomatic_mlp = pz.nn.initialize_parameters(
    penzai.example_models.simple_mlp.MLP.from_config(feature_sizes=[8, 128, 32]),
    jax.random.key(42),
)
idiomatic_mlp
%%autovisualize
idiomatic_mlp(pz.nx.wrap(jnp.ones((4, 8)), "my_batch_dim", "features"))

Most of the layers in this MLP are subclasses of pz.nn.Sequential. We can easily flatten them without changing the behavior of the model:

flat_mlp = pz.nn.inline_groups(
    pz.nn.Sequential([idiomatic_mlp]),
    parent_filter=lambda _: True, child_filter=lambda _: True)
flat_mlp

Looking inside Flax models with Penzai tools#

Luckily, we don’t need to manually rewrite every Flax model to take advantage of Penzai’s tooling. Penzai ships with a utility penzai.toolshed.unflaxify, which uses Flax’s intercept_methods hook to transform a Flax model into an equivalent Penzai one. This converted model isn’t usually very idiomatic, but since it fits with Penzai’s conventions, you can use Penzai’s tools to analyze and visualize it.

Let’s start by trying it with the Flax MLP:

from penzai.toolshed import unflaxify
intercepted_flax_mlp = unflaxify.unflaxify_apply(
    SimpleFlaxMLP(out_dims=32),
    flax_mlp_params,
    jnp.ones((4, 8))
)
intercepted_flax_mlp

unflaxify intercepts every module method call, and transforms it into a dataclass object that holds its own parameters and manages its own state. For instance, the two SimpleFlaxDense method blocks are each responsible for their own parameters.
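
We can verify this by listing the parameters that the intercepted blocks own (a quick sketch; the exact names depend on how Flax scoped the original modules):

# Each intercepted method block carries its own pz.nn.Parameter leaves, named
# after the Flax scopes they came from.
[param.name for param in
 pz.select(intercepted_flax_mlp).at_instances_of(pz.nn.Parameter).get_sequence()]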

Calling the intercepted methods still works as you’d expect, although we need to wrap the input argument, because by convention Penzai layers always take a single tree as input instead of taking multiple arguments.

%%autovisualize
intercepted_flax_mlp(unflaxify.ArgsAndKwargs.capture(jnp.ones((4, 8))))

But we are free to manipulate the tree structure to change this behavior. For instance, let’s insert layers that print out intermediates:

def add_intermediate_loggers(layer):
  """Recursively add loggers to this layer and all its sublayers."""
  return pz.nn.Sequential([
      HelloWorld(),
      (
          pz.select(layer).at_children().at_instances_of(pz.Layer)
          .apply(add_intermediate_loggers)
      ),
      HelloWorld(),
  ])

logging_flax_mlp = (
    pz.select(intercepted_flax_mlp).at_instances_of(pz.Layer)
    .apply(add_intermediate_loggers)
)
pz.select(logging_flax_mlp).at_instances_of(HelloWorld).show_value()
%%autovisualize
logging_flax_mlp(unflaxify.ArgsAndKwargs.capture(jnp.ones((4, 8))))

In fact, Penzai has a utility to quickly get an overview of the intermediate values in a computation by directly interleaving them into the model’s structure:

%%autovisualize
from penzai.toolshed import interleave_intermediates
interleaved = interleave_intermediates.run_and_interleave_intermediates(
    intercepted_flax_mlp,
    unflaxify.ArgsAndKwargs.capture(jnp.ones((4, 8)))
)

# Helper to expand the interesting parts of the visualization:
pz.select(interleaved).at_instances_of(
    interleave_intermediates.IdentityWithSavedActivations
).at(lambda x: x.saved_activations[0]).show_value()

We can apply this same tooling to the Gemma implementation to get an overview of the main components. For simplicity, we’ll focus on the scoring mode (no KV cache) for now. Try clicking around to explore the model’s structure!

%%autovisualize

intercepted_gemma = unflaxify.unflaxify_apply(
    flax_gemma,
    {'params': params['transformer']},
    tokens[None, :],
    positions,
    None,  # Attention cache is None.
    attention_mask,
)

intercepted_gemma

(You might notice the WithConstantSideInputs wrapper, which holds the embedding parameters. This is how Penzai handles parameters that need to be shared between multiple parts of a model. Layers that need access to those shared parameters use SharedParameterLookup markers to request access to them, and the WithConstantSideInputs substitutes those parameters wherever they are needed when the full model is called. See the data-effects tutorial for more information.)

Now that we’ve exposed the submodule method calls in our model’s structure, we can use Penzai tooling to inspect parts of the model. For instance, let’s capture the intermediate values before and after one of the attention layers. We start by clicking the “copy” symbol next to one of the layers we are interested in, which copies the following string:

(lambda root: root.body.submodule_calls[(3, 'layer_2.__call__')].submodule_calls[(1, 'attn.__call__')])

We can then pass that to pz.select(...).at(...) to select that layer:

attn_selection = pz.select(intercepted_gemma).at(
    (lambda root: root.body.submodule_calls[(3, 'layer_2.__call__')].submodule_calls[(1, 'attn.__call__')])
)
attn_selection

Now we use another Penzai utility to pull it out along with its activations:

example_gemma_wrapped_arg = unflaxify.ArgsAndKwargs.capture(
    tokens[None, :],
    positions,
    None,  # Attention cache is None.
    attention_mask,
)
from penzai.toolshed import isolate_submodel
captured = isolate_submodel.call_and_extract_submodel(
    attn_selection,
    example_gemma_wrapped_arg
)
%%autovisualize
captured

This lets us take a peek at the arguments that were passed to this attention layer, and also see the values it returned. We can even reproduce the output in a controlled setting by calling the submodel on the saved activations, which isn’t easy to do when using the Flax implementation alone:

%%autovisualize
captured.submodel(captured.saved_input)

To see the interpretation of these arguments and return values, we can cross-reference this with Gemma’s code:

class Attention(nn.Module):
  """Attention module."""

  <...>

  def __call__(
      self,
      x: jax.Array,
      segment_pos: jax.Array,
      cache: LayerCache | None,
      attn_mask: jax.Array,
  ) -> tuple[LayerCache | None, jax.Array]:
    <...>

    return new_cache, attn_output

Can we use this to look at the attention pattern itself? Unfortunately, not directly. Flax only allows you to intercept linen.Module method calls, and the attention pattern is computed directly in code, so there’s nothing to hook into. (In principle, we could use JAX’s Jaxpr tracing machinery to look inside this call, but this isn’t yet implemented in Penzai.)
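
As a rough illustration of that last point (this is plain JAX, not a Penzai or Gemma utility), jax.make_jaxpr can trace a miniature version of the attention-score computation and show the intermediate einsum and softmax operations that Flax’s method interception never sees:

def tiny_attention_scores(query, key):
  # Same einsum/softmax pattern as the Flax attention code, at a toy scale.
  logits = jnp.einsum('BTNH,BSNH->BTNS', query, key)
  return jax.nn.softmax(logits, axis=-1)

print(jax.make_jaxpr(tiny_attention_scores)(
    jnp.ones((1, 4, 2, 8)), jnp.ones((1, 4, 2, 8))
))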

In the next section, we’ll show how to re-implement Gemma in an idiomatic Penzai style. This will make it possible to look at the attention pattern and make other changes more directly.

Section 2: Porting the Gemma forward pass to Penzai#

To get the most out of Penzai’s analysis and visualization tools, we’d like to expose as much as possible in the model’s tree structure. This would enable us to insert new logic at any stage of the computation, and reduce the need for us to cross-reference the activations we see with the model code.

In this section, we’ll show how to re-implement each of the building blocks of Gemma as idiomatic Penzai layers, in a way that makes it easier to patch and inspect it afterward.

Feed-forward layer#

We’ll start with a relatively straightforward layer: the feedforward layer in each transformer block. Gemma’s feedforward layer uses GELU-based gated linear units (GEGLU), as proposed by Shazeer (2020). In Flax, Gemma’s feedforward layer is defined as:

class FeedForward(nn.Module):
  """Feed forward module."""
  features: int
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    w_gating = self.param(
        'gating_einsum',
        nn.initializers.zeros_init(),
        ((2, self.features, self.hidden_dim)),
    )
    ff_gate = jnp.dot(x, w_gating[0])
    gate_value = nn.gelu(ff_gate)

    ff1 = jnp.dot(x, w_gating[1])
    activations = gate_value * ff1

    w_linear = self.param(
        'linear',
        nn.initializers.zeros_init(),
        (self.hidden_dim, self.features),
    )
    outputs = jnp.dot(activations, w_linear)

    return outputs

The computation in this model is written as a sequence of Python operations. Unfortunately, that makes it hard to extract intermediate computations or patch them. In an idiomatic Penzai model, components like this should usually be broken down into smaller parts to make it easier to manipulate interactively.

We’ll make the following changes to port this to a Penzai layer:

  • The three dot products are fairly simple linear operations. We’ll rewrite them to use a standard Penzai building block for parameterized linear operations, pz.nn.Linear. We’ll also explicitly split the ‘gating_einsum’ parameter into two parameters, instead of having a single parameter and slicing it.

  • activations is computed as the product of two values that were computed independently. We’ll factor out this pattern into a general BranchAndMultiplyTogether combinator, which runs computations independently and then multiplies them together.

  • We’ll then define a Penzai equivalent of FeedForward as a subclass of pz.nn.Sequential, a standard Penzai combinator that just runs operations in sequence, and define a new classmethod from_config that initializes it. This pattern lets us associate configuration logic with a layer while preserving the ability to go in and insert new logic later.

  • We’ll configure all of these layers to use explicit axis names instead of axis indices. Penzai includes a lightweight “local” named axis system, where array dimensions can be referred to by either positional indices or names, and where it’s easy to “transpose” axes between the positional or named indexing patterns; a short sketch follows below.
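
Here’s a tiny sketch of that named-axis workflow, using only functions that appear elsewhere in this notebook (wrap, tag, untag, nmap, and unwrap); the shapes are made up for illustration:

x = pz.nx.wrap(jnp.ones((4, 8)), "batch", "embedding")
x.named_shape                     # {'batch': 4, 'embedding': 8}
# `untag` exposes "embedding" as a positional axis; `nmap` then vectorizes an
# ordinary JAX function over the remaining named axes.
summed = pz.nx.nmap(jnp.sum)(x.untag("embedding"))
summed.named_shape                # {'batch': 4}
summed.unwrap("batch")            # back to an ordinary jax.Array of shape (4,)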

Here’s our implementation:

@pz.pytree_dataclass
class BranchAndMultiplyTogether(pz.Layer):
  branches: list[pz.LayerLike]

  def __call__(self, arg):
    if not self.branches:
      raise ValueError(
          'BranchAndMultiplyTogether requires at least one branch.'
      )

    running_product = self.branches[0](arg)
    for branch in self.branches[1:]:
      running_product *= branch(arg)

    return running_product
@pz.pytree_dataclass(has_implicitly_inherited_fields=True)
class GemmaFeedForward(pz.nn.Sequential):

  @classmethod
  def from_config(
      cls,
      embedding_dim: int,
      hidden_dim: int,
      dtype: jax.typing.DTypeLike = jnp.float32,
  ) -> GemmaFeedForward:
    return cls([
        BranchAndMultiplyTogether(branches=[
            pz.nn.NamedGroup("gate", [
              pz.nn.add_parameter_prefix("gating_linear",
                  pz.nn.Linear.from_config(
                      input_axes={"embedding": embedding_dim},
                      output_axes={"neurons": hidden_dim},
                      initializer=pz.nn.zero_initializer,
                      dtype=dtype,
                  ),
              ),
              pz.nn.Elementwise(jax.nn.gelu),
            ]),
            pz.nn.add_parameter_prefix("value_linear",
                pz.nn.Linear.from_config(
                      input_axes={"embedding": embedding_dim},
                      output_axes={"neurons": hidden_dim},
                      initializer=pz.nn.zero_initializer,
                      dtype=dtype,
                )
            ),
        ]),
        pz.nn.add_parameter_prefix("out_linear",
            pz.nn.Linear.from_config(
                input_axes={"neurons": hidden_dim},
                output_axes={"embedding": embedding_dim},
                initializer=pz.nn.zero_initializer,
                dtype=dtype,
            )
        ),
    ])

One thing to note about this implementation is that we explicitly add parameter prefixes to each child layer using pz.nn.add_parameter_prefix. Penzai does not automatically track variable scoping (since this usually requires some sort of implicit state management), and instead gives you full control of how the parameters in your model are named. Constructor methods like from_config should ensure that all parameter names in the submodels they return are locally unique; callers can then use pz.nn.add_parameter_prefix to ensure uniqueness when combining submodels.
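
For example, here’s a small sketch reusing the toy SimplePzDense layer from Section 1 (the exact rendered names are illustrative): prefixing a submodel rewrites its parameter names while leaving its structure alone.

prefixed = pz.nn.add_parameter_prefix(
    "block_0", SimplePzDense.from_config(in_features=8, out_features=16)
)
# The parameter names gain a "block_0." prefix, e.g. "block_0.kernel" and
# "block_0.bias", but the layer itself is otherwise unchanged.
[param.name for param in
 pz.select(prefixed).at_instances_of(pz.nn.UninitializedParameter).get_sequence()]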

(Aside: You might wonder, why have parameter names at all, if parameters are stored as attributes in the model? The main answer is that Penzai models are often patched and re-configured after they are constructed, so the specific location of the parameter in the model PyTree may change. Giving all parameters an explicit name means they have a stable identifier that persists even when parts of the model are extracted or replaced, making it easier to e.g. save and restore the parameters from checkpoints.)

Let’s build our layer and see how it looks:

pz_ff_def = GemmaFeedForward.from_config(
    embedding_dim=2048,
    hidden_dim=16384,
    dtype=jnp.bfloat16,
)
pz_ff_def

You may notice that the Penzai version’s printed representation looks a bit similar to the Flax module’s code. This is no accident! Idiomatic Penzai models follow the “what you see is what you get” (WYSIWYG) principle: the runtime behavior of a model should always be directly visible from printing it out in IPython.

Now let’s capture the intermediates for a feedforward layer in the Flax model, build the equivalent Penzai version, and compare their outputs.

captured_ff = isolate_submodel.call_and_extract_submodel(
    pz.select(intercepted_gemma).at((lambda root: root.body.submodule_calls[(1, 'layer_0.__call__')].submodule_calls[(3, 'mlp.__call__')])),
    example_gemma_wrapped_arg
)
captured_ff
captured_ff_params = {
    param.name: param.value for param in captured_ff.select().at_instances_of(pz.nn.Parameter).get_sequence()
}
ff_param_mapping = {
    'gating_linear.weights': pz.nx.NamedArray.wrap(
            captured_ff_params['layer_0.mlp.gating_einsum'][0]
        ).tag("embedding", "neurons"),
    'value_linear.weights': pz.nx.NamedArray.wrap(
            captured_ff_params['layer_0.mlp.gating_einsum'][1]
        ).tag("embedding", "neurons"),
    'out_linear.weights': pz.nx.NamedArray.wrap(
            captured_ff_params['layer_0.mlp.linear']
        ).tag("neurons", "embedding"),
}
pz_ff = (
  pz_ff_def
  .select().at_instances_of(pz.nn.UninitializedParameter)
  .apply(lambda param: param.initialize_with_value(
      ff_param_mapping[param.name], strict_dtype=False,
  ))
)
%%autovisualize
named_arg = pz.nx.NamedArray.wrap(
    captured_ff.saved_input.args[0]
).tag("batch", "seq", "embedding")
pz_ff(named_arg)
%%autovisualize
captured_ff.saved_output
chex.assert_trees_all_close(
    pz_ff(named_arg).unwrap("batch", "seq", "embedding"),
    captured_ff.saved_output,
)

Looks like they match!

Positional embeddings#

Next, let’s look at positional embeddings. Gemma uses per-layer rotary positional embeddings (RoPE) as proposed by Su et al. (2021). In the Gemma Flax codebase, positional embeddings are just an ordinary Python function, since they don’t have any parameters:

def apply_rope(
    inputs: jax.Array,    # [B, L]
    positions: jax.Array, # [B, L]
    head_dim: int,
    max_wavelength: int = _MAX_WAVELENGTH,
) -> jax.Array:
  """Applies RoPE."""
  fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
  timescale = max_wavelength**fraction

  sinusoid_inp = (
      positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :]
  )
  sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]
  sin = jnp.sin(sinusoid_inp)
  cos = jnp.cos(sinusoid_inp)

  first_half, second_half = jnp.split(inputs, 2, axis=-1)
  first_part = first_half * cos - second_half * sin
  second_part = second_half * cos + first_half * sin
  out = jnp.concatenate([first_part, second_part], axis=-1)
  return out.astype(inputs.dtype)

You might notice a lot of dimension wrangling! This implementation assumes that inputs has shape (batch, tokens, heads, head_dim) and positions has shape (batch, tokens).

How should we represent this in the Penzai model? There are a few additional considerations to take into account:

  • The inputs and positions arguments depend on the training example, but the head_dim and max_wavelength arguments are configuration arguments. If we want to manipulate the positional embeddings in the same way as other parts of a Penzai model, we’d prefer to separate these.

  • Penzai models are built out of pytree_dataclasses so that JAX knows how to traverse them. In order to put this function into a Penzai combinator like Sequential, we’d like to ensure that this is also a pytree_dataclass, and that the head_dim and max_wavelength arguments are treated like part of the model structure, not as dynamic arrays.

  • The positions argument represents the position of each token in the sequence. This argument is constant throughout the entire model, across all of the layers, so we’d prefer not to have to thread it through every layer in the model. Also, by convention, Penzai layers always accept exactly one argument, usually produced by the previous layer.

  • By convention, Penzai layers do not make assumptions about the number of batch dimensions, and use named axes to refer to only the axes they care about.

We can address this by converting apply_rope into a non-parameterized Layer, treating inputs as the ordinary single-argument input for a Layer, using Penzai’s SideInputEffect to handle threading through the positions argument, and adapting it to use named axes. The resulting implementation is below:

from dataclasses import field

@pz.pytree_dataclass
class ApplyRoPE(pz.Layer):
  # The metadata annotations indicate that these values shouldn't be traversed
  # by JAX, and should always have concrete values instead of being traced.
  # (The actual PyTree flattening operations are defined by `pz.Struct`, the
  # superclass of `pz.Layer`, rather than being baked into pytree_datclass
  # itself.)
  embedding_axis: str = field(metadata={"pytree_node": False})
  max_wavelength: int = field(metadata={"pytree_node": False})

  # The `positions` attribute is a "side input". We expect some effect handler
  # to replace the value of this attribute with a concrete implementation.
  positions: pz.de.SideInputEffect[pz.nx.NamedArray]

  @classmethod
  def from_config(
      cls,
      positions_tag: Any,
      embedding_axis: str,
      max_wavelength: int = 10_000,
  ) -> "ApplyRoPE":
    return cls(
        embedding_axis=embedding_axis,
        max_wavelength=max_wavelength,
        positions=pz.de.SideInputRequest(tag=positions_tag),
    )

  def _apply_1d(self, input_slice: jax.Array, position: jax.Array) -> jax.Array:
    """Apply RoPE to a one-dimensional JAX array."""
    assert input_slice.ndim == 1
    assert position.ndim == 0
    # Infer `head_dim` from the input shape
    [head_dim] = input_slice.shape
    fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
    timescale = self.max_wavelength ** fraction
    # Since we're assuming `timescale` is a vector and `position` is a scalar,
    # we don't need any axis alignment.
    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    first_half, second_half = jnp.split(input_slice, 2)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    return jnp.concatenate([first_part, second_part])

  @pz.checked_layer_call
  def __call__(self, inputs: pz.nx.NamedArray) -> pz.nx.NamedArray:
    # SideInputEffect.ask() is how we retrieve a value from the effect handler.
    positions = self.positions.ask()

    # Embedding axis should not already be part of the positions.
    assert self.embedding_axis not in positions.named_shape
    # Every axis of the positions should appear in the inputs.
    assert not positions.positional_shape
    assert all(axis in inputs.named_shape for axis in positions.named_shape)

    # Unbind the embedding axis from the inputs, producing a 1-D
    # positional view.
    inputs_view = inputs.untag(self.embedding_axis)
    # Run the logic over our 1D view using `pz.nmap`, which vectorizes a
    # function over all named axes:
    out = pz.nx.nmap(self._apply_1d)(inputs_view, positions)
    # Finally, re-bind the embedding axis.
    out_named = out.tag(self.embedding_axis)
    return out_named.astype(inputs.dtype)

  # input_structure and output_structure are how we add optional shape and
  # structure annotations to our layer. These are checked by the
  # checked_layer_call decorator.
  def input_structure(self):
    return pz.chk.ArraySpec(
        named_shape={**pz.chk.var("B"), self.embedding_axis: pz.chk.var("F")},
        dtype=np.floating,
    )

  def output_structure(self):
    return self.input_structure()

We can look at it in treescope:

ApplyRoPE.from_config(positions_tag="positions", embedding_axis="embedding")

Note that the pytree-node fields are shown in italics, whereas the static fields are shown in a normal style. Treescope also helpfully tells us there’s an unhandled SideInputEffect, which means calling it directly will result in an error:

layer = ApplyRoPE.from_config(positions_tag="positions", embedding_axis="embedding")
try:
  layer(pz.nx.ones({"seq": 100, "embedding": 64}))
except pz.de.UnhandledEffectError:
  traceback.print_exc()

To provide a value for positions, we need to replace the SideInputRequest with a concrete value. We can do this using an effect handler:

handled_rope_layer = pz.de.WithSideInputsFromInputTuple.handling(
    ApplyRoPE.from_config(positions_tag="positions", embedding_axis="embedding"),
    tags=["positions"],
)
handled_rope_layer

The class method pz.de.WithSideInputsFromInputTuple.handling has replaced the SideInputRequest with a HandledSideInputRef, which indicates that the handler is now responsible for providing the value. We can then call the handled layer with a tuple of inputs and positions:

%%autovisualize
handled_rope_layer((
    pz.nx.ones({"seq": 50, "embedding": 32}),
    pz.nx.arange("seq", 50),
))

Why go through the trouble of using a side input? It makes it easy to pass the same positions argument to multiple layers, without needing to thread it through every intermediate layer. For instance, we can do something like this:

handled_sequential = pz.de.WithSideInputsFromInputTuple.handling(
    pz.nn.Sequential([
        ApplyRoPE.from_config(positions_tag="positions", embedding_axis="embedding"),
        HelloWorld(),  # <- our "Hello World" layer from Section 1
        ApplyRoPE.from_config(positions_tag="positions", embedding_axis="embedding"),
    ]),
    tags=["positions"],
)
handled_sequential
%%autovisualize
result = handled_sequential((
    pz.nx.ones({"seq": 50, "embedding": 32}),
    pz.nx.arange("seq", 50),
))
pz.show("Final result:", result)

Both of the ApplyRoPE layers received the same positions argument, without either the Sequential or HelloWorld layers having to worry about passing it through.

Let’s check to make sure our implementation behaves the same as the original Gemma implementation:

%%autovisualize
fake_token_embedding = pz.nx.nmap(jnp.sin)(
    0.333 * (pz.nx.arange("seq", 50) + pz.nx.arange("embedding", 32))
)[{"batch": np.newaxis, "heads": np.newaxis}]
fake_token_embedding
%%autovisualize
gemma.positional_embeddings.apply_rope(
    inputs=fake_token_embedding.unwrap("batch", "seq", "heads", "embedding"),
    positions=jnp.arange(50)[None, :],
    head_dim=32
)[0,:,0,:]
%%autovisualize
handled_rope_layer((
    fake_token_embedding,
    pz.nx.arange("seq", 50),
))[{"batch": 0, "heads": 0}]
chex.assert_trees_all_close(
    gemma.positional_embeddings.apply_rope(
        inputs=fake_token_embedding.unwrap("batch", "seq", "heads", "embedding"),
        positions=jnp.arange(50)[None, :],
        head_dim=32
    ),
    handled_rope_layer((
        fake_token_embedding,
        pz.nx.arange("seq", 50),
    )).unwrap("batch", "seq", "heads", "embedding")
)

Before we move on, it’s worth noting that WithSideInputsFromInputTuple isn’t magic, and it isn’t using any sort of mutable or global state! Its implementation of __call__ just makes a copy of its child, replaces all of the HandledSideInputRefs it owns with SideInputEffectImpls, and calls the result:

@pz.pytree_dataclass
class WithSideInputsFromInputTuple(effect_base.EffectHandler):
  handler_id: effect_base.HandlerId
  body: layer_base.LayerLike
  side_input_tags: tuple[Tag, ...]

  ...

  def __call__(self, argument: tuple[Any, ...]):
    inner_arg = argument[0]
    side_inputs = argument[1:]
    impls = {
        tag: SideInputEffectImpl(_value=val, _handler_id=self.handler_id)
        for tag, val in zip(self.side_input_tags, side_inputs)
    }
    handled_body = (
        selectors.select(self.body)
        .at_instances_of(HandledSideInputRef)
        .where(lambda ref: ref.handler_id == self.handler_id)
        .apply(lambda ref: impls[ref.tag])
    )
    return handled_body(inner_arg)

Conceptually, just as JAX operates in terms of function transformations like jit or vmap, you can think of Penzai as operating in terms of data structure transformations like WithSideInputsFromInputTuple which rewrite your model’s tree structure in a functional way.

If you’d like to learn more about Penzai’s named axis system or data effects system, check out their dedicated tutorial notebooks: “Named Axes in Penzai” and “Data Effects”.

Attention layer#

Next up is the attention layer. The logic of this layer in the Flax implementation is a bit more complex:

class Attention(nn.Module):
  """Attention module."""
  num_heads: int
  num_kv_heads: int
  features: int
  head_dim: int

  @property
  def use_qkv_einsum(self):
    return self.num_kv_heads == self.num_heads

  def setup(self):
    self.attn_vec_einsum = layers.Einsum(shape=(self.num_heads, self.head_dim, self.features))

    if self.use_qkv_einsum:
      self.qkv_einsum = layers.Einsum(shape=(3, self.num_heads, self.features, self.head_dim))
    else:
      self.q_einsum = layers.Einsum(shape=(self.num_heads, self.features, self.head_dim))
      self.kv_einsum = layers.Einsum(shape=(2, self.num_kv_heads, self.features, self.head_dim))

  def __call__(
      self,
      x: jax.Array,
      segment_pos: jax.Array,
      cache: LayerCache | None,
      attn_mask: jax.Array,
  ) -> tuple[LayerCache | None, jax.Array]:
    seq_len = x.shape[1]

    if self.use_qkv_einsum:
      query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x)
    else:
      query_proj = self.q_einsum('BTD,NDH->BTNH', x)
      key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)

    query_proj = positional_embeddings.apply_rope(
        query_proj, segment_pos, head_dim=self.head_dim,
    )
    query_scaled = query_proj * self.head_dim**-0.5
    key_proj = positional_embeddings.apply_rope(
        key_proj, segment_pos, head_dim=self.head_dim,
    )

    if not self.use_qkv_einsum:
      value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
      key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)

    if cache is not None:
      end_index = cache['end_index'][0]
      slice_indices = (0, end_index % cache['v'].shape[1], 0, 0)
      value_proj = jax.lax.dynamic_update_slice(
          cache['v'], value_proj, slice_indices,
      )
      key_proj = jax.lax.dynamic_update_slice(
          cache['k'], key_proj, slice_indices
      )

    logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
    padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
    probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
    encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
    attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)

    if cache is not None:
      new_cache = {'v': value_proj, 'k': key_proj, 'end_index': cache['end_index'] + seq_len}
    else:
      new_cache = None

    return new_cache, attn_output

This module is doing a lot of different things:

  • Depending on whether cache is provided, it either updates a key-value cache and returns it, or just runs without caching.

    • This particular reference implementation of Gemma does not use Flax’s own mutable state primitives, but some other implementations do, e.g. the one in MaxText.

  • Depending on the number of heads, it either computes queries, keys, and values in one call, or separately computes queries and keys/values.

  • The module takes four arguments, of which one (x) is an input from the previous layer, two (segment_pos and attn_mask) are common across all attention layers, and one (cache) is a state argument.

  • The module returns two values, one of which (attn_output) is intended to be passed onward, and one of which (new_cache) needs to be threaded out as an updated state.

In Penzai, by convention every layer does exactly one thing, and the computation is directly mirrored in the model’s structure. Instead of inferring different computation paths at runtime based on the arguments or configuration arguments, idiomatic Penzai models instead use different classes to represent substantially different computation paths.

To create a Penzai version of this attention layer:

  • We’ll specialize our implementation to always compute queries, keys, and values with three separate matrix multiplies.

    • This might have a slight performance penalty, but it will make it easier to patch the model.

  • Instead of repeating the keys and values when num_kv_heads is 1, we’ll simply omit the heads axis and allow it to broadcast automatically; a short broadcasting sketch follows this list. (Note that the only valid values of num_kv_heads in the Flax implementation are num_heads and 1.)

  • We’ll assume there is no key-value cache for now. In Section 3 we’ll define a different adaptation of the attention layer that does have a KV cache. (In idiomatic Penzai models, different runtime behaviors are represented by defining different classes.)

  • We’ll omit the segment_pos argument entirely. Since our positional embeddings already receive the token positions as a side input, we don’t have to thread them through the attention layer.

  • We’ll treat the attn_mask argument as a side input, since it will be shared across all the attention blocks.

  • We’ll refactor the overall computation to decompose it into logically-distinct components like we did for the FeedForward block:

    • The outermost component is in charge of routing the arrays between the query, key, value, attention, and output computations.

    • The innermost components each do a single thing, e.g. running a single tensor contraction, applying an attention mask, or taking a softmax.
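
Here is the short broadcasting sketch promised above (the shapes are made up for illustration): in Penzai’s named-axis system, arithmetic aligns axes by name, so a named array without a "heads" axis broadcasts against one that has it, just as if the computation were vmapped over "heads".

per_head_queries = pz.nx.wrap(jnp.ones((4, 3)), "heads", "projection")
shared_keys = pz.nx.wrap(jnp.ones((3,)), "projection")
# Multiplication aligns the "projection" axes by name and broadcasts over
# "heads", which `shared_keys` does not have.
product = per_head_queries * shared_keys
product.named_shape  # contains both "heads" (4) and "projection" (3)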

For convenience, we’ll also collect the configuration arguments into a dataclass. This GemmaTransformerConfig will just be used to simplify passing arguments during construction of the model; unlike in the Flax version, it won’t actually be part of the resulting model.

Here’s a general implementation which computes queries, keys, and values separately:

@dataclasses.dataclass
class GemmaTransformerConfig:
  # Main configuration:
  num_heads: int
  embedding_dim: int
  projection_dim: int
  single_kv_head: bool
  mlp_hidden_dim: int
  num_decoder_blocks: int
  vocab_size: int
  dtype: jax.typing.DTypeLike

pz_gemma_config = GemmaTransformerConfig(
    num_heads=flax_gemma_config.num_heads,
    embedding_dim=flax_gemma_config.embed_dim,
    projection_dim=flax_gemma_config.embed_dim // flax_gemma_config.num_heads,
    single_kv_head=(flax_gemma_config.num_kv_heads == 1),
    mlp_hidden_dim=flax_gemma_config.hidden_dim,
    num_decoder_blocks=flax_gemma_config.num_layers,
    vocab_size=flax_gemma_config.num_embed,
    dtype=jnp.bfloat16,
)
@pz.pytree_dataclass
class ApplyAttentionMask(pz.Layer):
  mask: pz.de.SideInputEffect[pz.nx.NamedArray]
  masked_out_value: jax.typing.ArrayLike

  @classmethod
  def from_config(
      cls,
      mask_tag: Any,
      masked_out_value: jax.typing.ArrayLike = -2.3819763e38,
  ) -> 'ApplyAttentionMask':
    return cls(
        mask=pz.de.SideInputRequest(tag=mask_tag),
        masked_out_value=masked_out_value,
    )

  def __call__(self, x: pz.nx.NamedArray) -> pz.nx.NamedArray:
    return pz.nx.nmap(jnp.where)(self.mask.ask(), x, self.masked_out_value)
@pz.pytree_dataclass
class GemmaAttention(pz.Layer):
  input_to_query: pz.LayerLike
  input_to_key: pz.LayerLike
  input_to_value: pz.LayerLike
  query_key_to_attn: pz.LayerLike
  attn_value_to_output: pz.LayerLike

  def __call__(self, x: pz.nx.NamedArray) -> pz.nx.NamedArray:
    query = self.input_to_query(x)
    key = self.input_to_key(x)
    value = self.input_to_value(x)
    attn = self.query_key_to_attn((query, key))
    output = self.attn_value_to_output((attn, value))
    return output

  @classmethod
  def from_config(cls, config: GemmaTransformerConfig) -> "GemmaAttention":
    num_heads = config.num_heads
    embedding_dim = config.embedding_dim
    projection_dim = config.projection_dim
    single_kv_head = config.single_kv_head

    if single_kv_head:
      kv_output_axes = {"projection": projection_dim}
      kv_einsum_heads = {}
    else:
      kv_output_axes = {"heads": num_heads, "projection": projection_dim}
      kv_einsum_heads = {"heads": "h"}

    return cls(
        input_to_query=pz.nn.Sequential([
            pz.nn.add_parameter_prefix(
                "query",
                pz.nn.Linear.from_config(
                    input_axes={"embedding": embedding_dim},
                    output_axes={
                        "heads": num_heads,
                        "projection": projection_dim,
                    },
                    dtype=config.dtype,
                ),
            ),
            ApplyRoPE.from_config(
                positions_tag="token_positions",
                embedding_axis="projection",
            ),
            pz.nn.ConstantRescale(by=(projection_dim**-0.5)),
        ]),
        input_to_key=pz.nn.Sequential([
            pz.nn.add_parameter_prefix(
                "key",
                pz.nn.Linear.from_config(
                    input_axes={"embedding": embedding_dim},
                    output_axes=kv_output_axes,
                    dtype=config.dtype,
                ),
            ),
            ApplyRoPE.from_config(
                positions_tag="token_positions",
                embedding_axis="projection",
            ),
        ]),
        input_to_value=pz.nn.Sequential([
            pz.nn.add_parameter_prefix(
                "value",
                pz.nn.Linear.from_config(
                    input_axes={"embedding": embedding_dim},
                    output_axes=kv_output_axes,
                    dtype=config.dtype,
                ),
            ),
        ]),
        query_key_to_attn=pz.nn.Sequential([
            pz.nn.NamedEinsum(
                (
                    {"seq": "tq", "heads": "h", "projection": "p"},
                    {"seq": "tkv", **kv_einsum_heads, "projection": "p"},
                ),
                {"seq": "tq", "heads": "h", "kv_seq": "tkv"}
            ),
            ApplyAttentionMask.from_config(mask_tag="attn_mask"),
            pz.nn.Softmax("kv_seq"),
        ]),
        attn_value_to_output=pz.nn.Sequential([
            pz.nn.NamedEinsum(
                (
                    {"seq": "tq", "heads": "h", "kv_seq": "tkv"},
                    {"seq": "tkv", **kv_einsum_heads, "projection": "p"},
                ),
                {"seq": "tq", "heads": "h", "projection": "p"}
            ),
            pz.nn.add_parameter_prefix(
                "output",
                pz.nn.Linear.from_config(
                    input_axes={
                        "heads": num_heads,
                        "projection": projection_dim,
                    },
                    output_axes={"embedding": embedding_dim},
                    dtype=config.dtype,
                ),
            ),
        ]),
    )

And this is what it looks like when we construct it and migrate over the parameters from the Flax version:

attn_param_mapping = {
    "query.weights": pz.nx.NamedArray.wrap(
        params['transformer']['layer_0']['attn']['q_einsum']['w']
    ).tag("heads", "embedding", "projection"),
    "key.weights": pz.nx.NamedArray.wrap(
        params['transformer']['layer_0']['attn']['kv_einsum']['w'][0,0]
    ).tag("embedding", "projection"),
    "value.weights": pz.nx.NamedArray.wrap(
        params['transformer']['layer_0']['attn']['kv_einsum']['w'][1,0]
    ).tag("embedding", "projection"),
    "output.weights": pz.nx.NamedArray.wrap(
        params['transformer']['layer_0']['attn']['attn_vec_einsum']['w']
    ).tag("heads", "projection", "embedding"),
}
attn_def = GemmaAttention.from_config(pz_gemma_config)
attn_layer = (
    attn_def.select()
    .at_instances_of(pz.nn.UninitializedParameter)
    .apply(lambda param: param.initialize_with_value(
        attn_param_mapping[param.name], strict_dtype=False,
    ))
)
%%autovisualize
attn_layer

Let’s run it to make sure the outputs are the same. We’ll again use WithSideInputsFromInputTuple to provide the necessary side inputs.

captured_attn = isolate_submodel.call_and_extract_submodel(
    pz.select(intercepted_gemma).at(
        (lambda root: root.body.submodule_calls[(1, 'layer_0.__call__')].submodule_calls[(1, 'attn.__call__')])
    ),
    example_gemma_wrapped_arg
)
saved_input_embedding = pz.nx.wrap(captured_attn.saved_input.args[0]).tag("batch", "seq", "embedding")
saved_positions = pz.nx.wrap(captured_attn.saved_input.args[1]).tag("batch", "seq")
saved_attn_mask = pz.nx.wrap(captured_attn.saved_input.args[3]).tag("batch", "seq", "kv_seq")
%%autovisualize
wrapped = pz.de.WithSideInputsFromInputTuple.handling(
    attn_layer, tags=["token_positions", "attn_mask"]
)
wrapped((saved_input_embedding, saved_positions, saved_attn_mask))
%%autovisualize
captured_attn.saved_output[1]
chex.assert_trees_all_close(
    wrapped((saved_input_embedding, saved_positions, saved_attn_mask)).unwrap("batch", "seq", "embedding"),
    captured_attn.saved_output[1],
)

In Section 1, we discussed how it is difficult to inspect the attention pattern in the Flax implementation, since the attention computation isn’t exposed. In contrast, it’s trivial to inspect it in our Penzai model, since the computation of the attention weights is its own layer. In fact, here are all of the intermediate activations throughout the entire attention computation:

%%autovisualize
wrapped_with_intermediates = interleave_intermediates.run_and_interleave_intermediates(
    wrapped,
    (saved_input_embedding, saved_positions, saved_attn_mask)
)
(
    wrapped_with_intermediates.select()
    .at_instances_of(GemmaAttention)
    .at_instances_of(interleave_intermediates.IdentityWithSavedActivations)
    .at_instances_of(pz.nx.NamedArray)
    .show_value()
)

We can pull out the attention pattern and visualize it at full resolution:

# Copied from the above printout by clicking the "copy" symbol after the array visualization:
path_fn = (lambda root: root.sublayers[1].body.sublayers[1].query_key_to_attn.sublayers[6].saved_activations[0])
saved_attention_pattern = path_fn(wrapped_with_intermediates)
pz.ts.render_array(
    # Using the copied function to pull out the value of the array we clicked:
    saved_attention_pattern,
    truncate=False,
    # This adds the actual token values to the hover tooltips:
    axis_item_labels={
        "seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
        "kv_seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
    },
    # This overlays the attention mask to hide masked-out locations:
    valid_mask=saved_attn_mask,
)

By default, treescope tries to emphasize detail by adjusting the colormap so that three standard deviations are visible, and outliers are truncated (shown as “+”). But you can extend the colormap range upward by clicking on any “+” cell, and you can also switch to a symmetric-logarithm colormap by clicking on cells with values near zero. Try seeing what attention patterns you can spot! (Clicking on cells also moves the hover tooltip that you clicked below the visualization, so you can copy it later if you want.)

See the “induction heads” notebook for more discussion of looking at attention heads.

Root-mean-squared layer normalization#

Gemma uses RMSNorm (Zhang & Sennrich, 2019) to normalize the inputs and outputs of the attention block. The Flax implementation defines it like this:

class RMSNorm(nn.Module):
  @nn.compact
  def __call__(self, x):
    scale = self.param('scale', nn.initializers.zeros_init(), (x.shape[-1]))
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
    normed_inputs = normed_inputs * (1 + scale)
    return normed_inputs

For our Penzai version, we’ll make a few minor changes:

  • We’ll separate out the normalization logic from the scaling logic, since they don’t depend on each other.

  • We’ll fold the (1 + scale) into scale so that the scaling logic can be expressed with pz.nn.Linear (see the quick check after this list).

  • We’ll re-write the computation to operate over a named axis instead of using the last positional axis.
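
To make the folding concrete, here’s a small numeric check (toy values, not from the checkpoint) that the Flax computation factors into a standardization step followed by an elementwise rescale with folded weights, as described in the bullets above:

x = jnp.array([1.0, -2.0, 0.5])
scale = jnp.array([0.1, 0.0, -0.3])
rms = jnp.sqrt(jnp.mean(jnp.square(x)) + 1e-6)
flax_style = (x / rms) * (1 + scale)    # normalize and scale in one step
standardized = x / rms                  # what RMSStandardize below computes
rescaled = standardized * (1 + scale)   # what the folded-weight Linear computes
chex.assert_trees_all_close(flax_style, rescaled)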

@pz.pytree_dataclass
class RMSStandardize(pz.Layer):
  across: str | tuple[str, ...] = field(metadata={"pytree_node": False})
  epsilon: float | jax.Array = 1e-6

  @pz.checked_layer_call
  def __call__(self, value: pz.nx.NamedArray) -> pz.nx.NamedArray:
    across = (self.across,) if isinstance(self.across, str) else self.across

    @pz.nx.nmap
    def _rms_standardize(x):
      var = jnp.mean(jnp.square(x))
      return x * jnp.reciprocal(jnp.sqrt(var + self.epsilon))

    return _rms_standardize(value.untag(*across)).tag(*across)

  def input_structure(self) -> Any:
    across = (self.across,) if isinstance(self.across, str) else self.across
    return pz.chk.ArraySpec.floating_named({
        **pz.chk.var("B"),
        **pz.chk.vars_for_axes("across", across),
    })

  def output_structure(self) -> Any:
    return self.input_structure()
@pz.pytree_dataclass(has_implicitly_inherited_fields=True)
class RMSLayerNorm(pz.nn.Sequential):

  @classmethod
  def from_config(
      cls,
      across_axes: dict[str, int],
      epsilon: float | jax.Array = 1e-6,
      dtype: jax.typing.DTypeLike = jnp.float32,
  ) -> RMSLayerNorm:
    return cls([
        RMSStandardize(across=tuple(across_axes.keys()), epsilon=epsilon),
        pz.nn.add_parameter_prefix(
            "scale",
            pz.nn.Linear.from_config(
                input_axes={},
                output_axes={},
                parallel_axes=across_axes,
                initializer=pz.nn.constant_initializer(1.0),
                dtype=dtype,
            ),
        ),
    ])

Checking for consistency:

rmsnorm_param_mapping = {
    "scale.weights": pz.nx.NamedArray.wrap(
        1 + params['transformer']['layer_0']['pre_attention_norm']['scale']
    ).tag("embedding"),
}
rmsnorm_def = RMSLayerNorm.from_config(
    {"embedding": flax_gemma_config.embed_dim}
)
rmsnorm_layer = (
    rmsnorm_def.select()
    .at_instances_of(pz.nn.UninitializedParameter)
    .apply(lambda param: param.initialize_with_value(
        rmsnorm_param_mapping[param.name], strict_dtype=False,
    ))
)
%%autovisualize
rmsnorm_layer
captured_rmsnorm = isolate_submodel.call_and_extract_submodel(
    pz.select(intercepted_gemma).at(
        (lambda root: root.body.submodule_calls[(1, 'layer_0.__call__')].submodule_calls[(0, 'pre_attention_norm.__call__')])
    ),
    example_gemma_wrapped_arg
)
%%autovisualize
rmsnorm_layer(
    pz.nx.wrap(captured_rmsnorm.saved_input.args[0])
    .tag("batch", "seq", "embedding")
)
%%autovisualize
captured_rmsnorm.saved_output
chex.assert_trees_all_close(
    rmsnorm_layer(
        pz.nx.wrap(captured_rmsnorm.saved_input.args[0])
        .tag("batch", "seq", "embedding")
    ).unwrap("batch", "seq", "embedding"),
    captured_rmsnorm.saved_output,
)

Transformer block layer#

Now we can put these pieces together to build the main transformer block. For reference, here’s the Flax implementation:

class Block(nn.Module):
  """Transformer block."""

  num_heads: int
  num_kv_heads: int
  embed_dim: int
  head_dim: int
  hidden_dim: int

  def setup(self):
    self.pre_attention_norm = layers.RMSNorm()
    self.attn = Attention(
        num_heads=self.num_heads,
        features=self.embed_dim,
        head_dim=self.head_dim,
        num_kv_heads=self.num_kv_heads,
    )
    self.pre_ffw_norm = layers.RMSNorm()
    self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)

  def __call__(
      self,
      x: jax.Array,
      segment_pos: jax.Array,
      cache: LayerCache | None,
      attn_mask: jax.Array,
  ) -> tuple[LayerCache | None, jax.Array]:
    inputs_normalized = self.pre_attention_norm(x)
    cache, attn_output = self.attn(
        inputs_normalized,
        segment_pos,
        cache,
        attn_mask,
    )
    attn_output += x
    residual = attn_output
    attn_output = self.pre_ffw_norm(attn_output)
    outputs = self.mlp(attn_output)
    outputs = residual + outputs
    return cache, outputs

This Flax module is mostly a sequence of operations run one after another, but it also has residual connections and explicitly-threaded state. We’ll refactor the residual connections into a Residual combinator, and we’ll drop the state since this version of the implementation doesn’t need it.
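
Roughly speaking, pz.nn.Residual(delta) computes x + delta(x), where delta is an arbitrary sublayer. A tiny sketch (toy values, using ConstantRescale as the delta branch):

residual_demo = pz.nn.Residual(pz.nn.ConstantRescale(by=2.0))
# x + 2*x = 3*x along the "features" axis:
residual_demo(pz.nx.wrap(jnp.array([1.0, 2.0])).tag("features"))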

As with the FeedForward layer, we’ll express this logic by creating a subclass of pz.nn.Sequential, and adding a class method that builds the sequence of operations:

@pz.pytree_dataclass(has_implicitly_inherited_fields=True)
class GemmaTransformerBlock(pz.nn.Sequential):

  @classmethod
  def from_config(cls, config: GemmaTransformerConfig) -> GemmaTransformerBlock:
    return cls(
        sublayers=[
            pz.nn.Residual(
                pz.nn.Sequential([
                    pz.nn.add_parameter_prefix(
                        "pre_attention_norm",
                        RMSLayerNorm.from_config(
                            {"embedding": config.embedding_dim},
                            dtype=config.dtype,
                        ),
                    ),
                    pz.nn.add_parameter_prefix(
                        "attn",
                        GemmaAttention.from_config(config),
                    ),
                ])
            ),
            pz.nn.Residual(
                pz.nn.Sequential([
                    pz.nn.add_parameter_prefix(
                        "pre_ffw_norm",
                        RMSLayerNorm.from_config(
                            {"embedding": config.embedding_dim},
                            dtype=config.dtype,
                        ),
                    ),
                    pz.nn.add_parameter_prefix(
                        "mlp",
                        GemmaFeedForward.from_config(
                            embedding_dim=config.embedding_dim,
                            hidden_dim=config.mlp_hidden_dim,
                            dtype=config.dtype,
                        ),
                    ),
                ])
            ),
        ],
    )

Let’s load it from the checkpoint:

tfblock_param_mapping = {
    **{f"attn.{name}": value for name,value in attn_param_mapping.items()},
    **{f"mlp.{name}": value for name,value in ff_param_mapping.items()},
    "pre_attention_norm.scale.weights": pz.nx.NamedArray.wrap(
        1 + params['transformer']['layer_0']['pre_attention_norm']['scale']
    ).tag("embedding"),
    "pre_ffw_norm.scale.weights": pz.nx.NamedArray.wrap(
        1 + params['transformer']['layer_0']['pre_ffw_norm']['scale']
    ).tag("embedding"),
}
tfblock_def = GemmaTransformerBlock.from_config(pz_gemma_config)
tfblock = (
    tfblock_def.select()
    .at_instances_of(pz.nn.UninitializedParameter)
    .apply(lambda param: param.initialize_with_value(
        tfblock_param_mapping[param.name], strict_dtype=False,
    ))
)
%%autovisualize
tfblock

And make sure it works:

captured_tfblock = isolate_submodel.call_and_extract_submodel(
    pz.select(intercepted_gemma).at(
        (lambda root: root.body.submodule_calls[(1, 'layer_0.__call__')])
    ),
    example_gemma_wrapped_arg
)
%%autovisualize
wrapped = pz.de.WithSideInputsFromInputTuple.handling(
    tfblock, tags=["token_positions", "attn_mask"]
)
saved_input_embedding = (
    pz.nx.wrap(captured_rmsnorm.saved_input.args[0])
    .tag("batch", "seq", "embedding")
)
wrapped((saved_input_embedding, saved_positions, saved_attn_mask))
%%autovisualize
captured_tfblock.saved_output[1]
chex.assert_trees_all_close(
    wrapped(
        (saved_input_embedding, saved_positions, saved_attn_mask)
    ).unwrap("batch", "seq", "embedding"),
    captured_tfblock.saved_output[1],
)

Token embedding layer#

Our last subcomponent is the embedding layer, which is responsible for mapping token IDs to vectors and mapping vectors back to (distributions over) token IDs. In Flax, Gemma’s embedding layer is defined as:

class Embedder(nn.Module):
  """Embedder module."""
  vocab_size: int
  embed_dim: int

  def setup(self):
    self.input_embedding_table = self.param(
        'input_embedding',
        nn.initializers.normal(),
        (self.vocab_size, self.embed_dim),
    )

  def encode(self, x: jax.Array) -> jax.Array:
    x = self.input_embedding_table[(x,)]
    x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
    return x

  def decode(self, x: jax.Array) -> jax.Array:
    return jnp.dot(x, self.input_embedding_table.T)

Unlike the other layers we’ve seen so far, this Flax module has multiple methods. This is a common way to express parameter sharing in Flax, but it’s not a common pattern in Penzai, because it’s hard to represent shared object identity inside a declarative what-you-see-is-what-you-get tree.

To port this to an idiomatic Penzai layer:

  • We’ll split up the single Embedder module into three Penzai classes:

    • The first, EmbeddingTable, will be a simple data structure in charge of owning the parameters and configuration.

    • The second and third will implement the encoding and decoding steps. We do this because, by convention, each Penzai layer does a single thing and takes a single input. (This makes it easier to chain multiple layers together and patch their logic.)

    • Instead of expressing parameter sharing by calling multiple methods on the same Python object, we’ll use the SideInputEffect to share the embedding table between separate encoding and decoding layer objects. The main difference from our previous uses of this effect is that the shared value will be stored in the model itself rather than being provided as an argument.

  • We’ll make the embedding table an explicit dataclass attribute, rather than adding it in a setup method. In Penzai, classes always have exactly the attributes that they are declared to have.

  • We’ll rewrite the configuration attributes and setup logic to live in a classmethod EmbeddingTable.from_config. By convention, initialization logic is usually defined inside a classmethod, to avoid changing the automatic dataclass __init__ method and to make it easier to bypass the initializer if needed.

  • We’ll factor out the jnp.sqrt(self.embed_dim) in encode into its own ConstantRescale layer, so that it’s not tightly coupled to the embedding lookup operation. (Note that different transformer implementations differ on where they put this; some implementations scale up the embedding weights and then divide by jnp.sqrt(self.embed_dim) before decoding instead.)

  • We’ll use Penzai’s named axis system to give the vocabulary and embedding dimensions informative names.

This leads us to the following implementation:

@pz.pytree_dataclass
class EmbeddingTable(pz.Struct):  # <- Struct is the base type of most Penzai pytree_dataclasses.
  embeddings: pz.nn.ParameterLike[pz.nx.NamedArray]
  vocabulary_axis: str = dataclasses.field(metadata={"pytree_node": False})

  @classmethod
  def from_config(
      cls,
      vocab_size: int,
      embedding_axes: dict[str, int],
      vocabulary_axis: str = "vocabulary",
      initializer: pz.nn.LinearOperatorWeightInitializer = (
          functools.partial(
              pz.nn.variance_scaling_initializer,
              scale=1.0, mode="fan_out", distribution="normal",
          )
      ),
      dtype: np.typing.DTypeLike = np.float32,
  ) -> EmbeddingTable:
    if vocabulary_axis in embedding_axes:
      raise ValueError(
          f"`vocabulary_axis` {vocabulary_axis} should not appear in"
          f" `embedding_axes` {embedding_axes}"
      )

    return cls(
        embeddings=pz.nn.UninitializedParameter(
            initializer=functools.partial(
                initializer,
                input_axes={},
                output_axes=embedding_axes,
                parallel_axes={vocabulary_axis: vocab_size},
                convolution_spatial_axes={},
                dtype=dtype,
            ),
            name="embeddings",
        ),
        vocabulary_axis=vocabulary_axis,
    )
@pz.pytree_dataclass
class EmbeddingLookup(pz.Layer):
  table: EmbeddingTable

  @pz.checked_layer_call
  def __call__(
      self, token_index: pz.nx.NamedArray
  ) -> pz.nx.NamedArray:
    """Retrieves tokens from the embedding table."""
    return self.table.embeddings.value[{self.table.vocabulary_axis: token_index}]

  def input_structure(self) -> Any:
    return pz.chk.ArraySpec(
        named_shape={**pz.chk.var("B")}, dtype=np.integer
    )

  def output_structure(self) -> Any:
    table_structure = self.table.embeddings.value_structure
    non_lookup_shape = dict(table_structure.named_shape)
    del non_lookup_shape[self.table.vocabulary_axis]
    return pz.chk.ArraySpec(
        named_shape={**pz.chk.var("B"), **non_lookup_shape},
        dtype=table_structure.dtype,
    )
@pz.pytree_dataclass
class EmbeddingDecode(pz.Layer):
  table: EmbeddingTable

  @pz.checked_layer_call
  def __call__(
      self, arg: pz.nx.NamedArray
  ) -> pz.nx.NamedArray:
    """Retrieves tokens from the embedding table."""
    contracting_axes = [
        name for name in self.table.embeddings.value.named_shape.keys()
        if name != self.table.vocabulary_axis
    ]
    return pz.nn.contract(contracting_axes, arg, self.table.embeddings.value)

  def input_structure(self) -> Any:
    table_shape = dict(
        self.table.embeddings.value_structure.named_shape
    )
    del table_shape[self.table.vocabulary_axis]
    return pz.chk.ArraySpec(
        named_shape={**pz.chk.var("B"), **table_shape},
        dtype=np.floating,
    )

  def output_structure(self) -> Any:
    table_shape = self.table.embeddings.value_structure.named_shape
    return pz.chk.ArraySpec(
        named_shape={
            **pz.chk.var("B"),
            self.table.vocabulary_axis: table_shape[self.table.vocabulary_axis]
        },
        dtype=np.floating
    )

Let’s make sure each of these works properly on its own:

emb_table_def = EmbeddingTable.from_config(
    vocab_size=flax_gemma_config.num_embed,
    embedding_axes={"embedding": flax_gemma_config.embed_dim},
    dtype=jnp.bfloat16,
)
emb_table_param_mapping = {
    "embeddings": pz.nx.NamedArray.wrap(
        params['transformer']['embedder']['input_embedding']
    ).tag("vocabulary", "embedding"),
}
emb_table = (
    emb_table_def.select()
    .at_instances_of(pz.nn.UninitializedParameter)
    .apply(lambda param: param.initialize_with_value(
        emb_table_param_mapping[param.name], strict_dtype=False,
    ))
)
%%autovisualize
emb_encoder = pz.nn.Sequential([
    EmbeddingLookup(emb_table),
    pz.nn.ConstantRescale(by=jnp.sqrt(flax_gemma_config.embed_dim).astype(emb_table.embeddings.value.dtype)),
])
emb_encoder
%%autovisualize
emb_encoder(pz.nx.wrap(tokens).tag("seq"))
captured_emb_encoder = isolate_submodel.call_and_extract_submodel(
    pz.select(intercepted_gemma).at(
        (lambda root: root.body.submodule_calls[(0, 'embedder.encode')])
    ),
    example_gemma_wrapped_arg
)
chex.assert_trees_all_close(
    emb_encoder(
        pz.nx.wrap(captured_emb_encoder.saved_input.args[0]).tag("batch", "seq")
    ).unwrap("batch", "seq", "embedding"),
    captured_emb_encoder.submodel(captured_emb_encoder.saved_input),
)
%%autovisualize
emb_decoder = EmbeddingDecode(emb_table)
emb_decoder
captured_emb_decoder = isolate_submodel.call_and_extract_submodel(
    pz.select(intercepted_gemma).at(
        (lambda root: root.body.submodule_calls[(20, 'embedder.decode')])
    ),
    example_gemma_wrapped_arg
)
%%autovisualize
emb_decoder(
    pz.nx.wrap(captured_emb_decoder.saved_input.args[0])
    .tag("batch", "seq", "embedding")
)
chex.assert_trees_all_close(
    emb_decoder(
        pz.nx.wrap(captured_emb_decoder.saved_input.args[0])
        .tag("batch", "seq", "embedding")
    ).unwrap("batch", "seq", "vocabulary"),
    captured_emb_decoder.submodel(captured_emb_decoder.saved_input),
)

What about parameter sharing? Penzai’s parameter utilities assume each parameter in your model PyTree is independent, which means we can’t just put the same embedding table inside both the encoding and decoding layers; that wouldn’t properly tie their weights.

We can express this using the same SideInputEffect we used to share the attention mask and RoPE positions. (In fact, we’ve already briefly seen this when looking at the intercepted Flax model.) Penzai includes a few utilities to help us set this up:

# Temporarily mark the initializer as shareable, so we can find it later.
shareable_emb_table_def = pz.nn.mark_shareable(
    pz.nn.add_parameter_prefix("embedder", emb_table_def)
)
# Use the same definition (with the same parameter name) twice:
encode_then_decode_def = pz.nn.Sequential([
    EmbeddingLookup(shareable_emb_table_def),
    pz.nn.ConstantRescale(by=jnp.sqrt(flax_gemma_config.embed_dim).astype(jnp.bfloat16)),
    HelloWorld(),
    EmbeddingDecode(shareable_emb_table_def),
])
encode_then_decode_def
# "Attach" the shared parameter to a single point in the tree. The shared
# parameter will now be "owned" by a `WithConstantSideInputs` handler.
shared_encode_then_decode_def = pz.nn.attach_shared_parameters(encode_then_decode_def)
shared_encode_then_decode_def
# Initialize it as normal:
qualified_emb_table_param_mapping = {
    "embedder.embeddings": pz.nx.NamedArray.wrap(
        params['transformer']['embedder']['input_embedding']
    ).tag("vocabulary", "embedding"),
}
shared_encode_then_decode = (
    shared_encode_then_decode_def.select()
    .at_instances_of(pz.nn.UninitializedParameter)
    .apply(lambda param: param.initialize_with_value(
        qualified_emb_table_param_mapping[param.name], strict_dtype=False,
    ))
)
shared_encode_then_decode
%%autovisualize
shared_encode_then_decode(pz.nx.wrap(tokens).tag("seq"))

Note that marking parameters as shareable is temporary: these annotations are used by attach_shared_parameters and then forgotten. In the final model tree, we can identify the shared parameters because they use SharedParameterLookup nodes instead of ordinary parameters.

This is an example of another one of Penzai’s design principles: layers should make as few assumptions as possible about the implementation of their children. In this case, EmbeddingTable.embeddings is annotated as having type ParameterLike, which means that it can be any PyTree-dataclass type that defines value and value_structure properties. Parameter instances store their values there, but SharedParameterLookup instances instead redirect .value to .ref.ask(). This means the EmbeddingTable doesn’t have to worry about whether its parameter is shared or not.
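
For instance, here’s one way (a sketch, assuming the lookup node class is exposed as pz.nn.SharedParameterLookup) to see which parameters ended up being shared after attach_shared_parameters ran:

# Select the lookup nodes that were substituted for the shareable parameters:
shared_encode_then_decode.select().at_instances_of(pz.nn.SharedParameterLookup)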

Putting it together: The top-level Transformer model#

At last, we can assemble the full model, which runs each of these sublayers in the appropriate order! The Flax implementation is pretty straightforward, so we’ll skip looking at it and dive right into the Penzai version.

Since the transformer object is intended to be the top-level module, we’ll have its __call__ take a structure of inputs and handle the unpacking of it. We’ll still make it a pz.Layer, though, so that it composes with the other Penzai utilities that assume __call__ takes one argument.

@pz.pytree_dataclass
class GemmaInputs(pz.Struct):
  tokens: pz.nx.NamedArray
  positions: pz.nx.NamedArray
  attention_mask: pz.nx.NamedArray


@pz.pytree_dataclass
class GemmaTransformer(pz.Layer):
  config: GemmaTransformerConfig = field(metadata={"pytree_node": False})
  body: pz.LayerLike

  def __call__(self, inputs: GemmaInputs) -> pz.nx.NamedArray:
    return self.body((inputs.tokens, inputs.positions, inputs.attention_mask))

  @classmethod
  def from_config(cls, config: GemmaTransformerConfig) -> GemmaTransformer:
    emb_table = pz.nn.mark_shareable(
        pz.nn.add_parameter_prefix(
            "embedder",
            EmbeddingTable.from_config(
                vocab_size=config.vocab_size,
                embedding_axes={"embedding": config.embedding_dim},
                dtype=config.dtype,
            ),
        )
    )
    sublayers = []
    sublayers.extend([
        EmbeddingLookup(emb_table),
        pz.nn.ConstantRescale(
            by=jnp.sqrt(config.embedding_dim).astype(config.dtype)
        ),
    ])
    for i in range(config.num_decoder_blocks):
      sublayers.append(
          pz.nn.add_parameter_prefix(
              f"block_{i}", GemmaTransformerBlock.from_config(config)
          )
      )
    sublayers.extend([
        pz.nn.add_parameter_prefix(
            "final_norm",
            RMSLayerNorm.from_config(
                across_axes={"embedding": config.embedding_dim},
                dtype=config.dtype,
            ),
        ),
        EmbeddingDecode(emb_table),
    ])
    return GemmaTransformer(
        config=config,
        body=pz.de.WithSideInputsFromInputTuple.handling(
            pz.nn.attach_shared_parameters(pz.nn.Sequential(sublayers)),
            tags=["token_positions", "attn_mask"],
        ),
    )

  @classmethod
  def from_pretrained(cls, flax_params: dict[str, Any]) -> GemmaTransformer:
    flax_gemma_config = gemma.transformer.TransformerConfig.from_params(
        flax_params
    )
    config = GemmaTransformerConfig(
        num_heads=flax_gemma_config.num_heads,
        embedding_dim=flax_gemma_config.embed_dim,
        projection_dim=flax_gemma_config.head_dim,
        single_kv_head=(flax_gemma_config.num_kv_heads == 1),
        mlp_hidden_dim=flax_gemma_config.hidden_dim,
        num_decoder_blocks=flax_gemma_config.num_layers,
        vocab_size=flax_gemma_config.num_embed,
        dtype=flax_params["transformer"]["embedder"]["input_embedding"].dtype,
    )
    model_def = cls.from_config(config)
    ftp = flax_params["transformer"]
    parameter_mapping = {
        "embedder.embeddings": pz.nx.NamedArray.wrap(
            ftp["embedder"]["input_embedding"]
        ).tag("vocabulary", "embedding"),
        "final_norm.scale.weights": pz.nx.NamedArray.wrap(
            1 + ftp["final_norm"]["scale"]
        ).tag("embedding"),
    }
    for i in range(config.num_decoder_blocks):
      parameter_mapping.update({
          f"block_{i}.pre_attention_norm.scale.weights": pz.nx.NamedArray.wrap(
              1 + ftp[f"layer_{i}"]["pre_attention_norm"]["scale"]
          ).tag("embedding"),
          f"block_{i}.pre_ffw_norm.scale.weights": pz.nx.NamedArray.wrap(
              1 + ftp[f"layer_{i}"]["pre_ffw_norm"]["scale"]
          ).tag("embedding"),
          f"block_{i}.mlp.gating_linear.weights": pz.nx.NamedArray.wrap(
              ftp[f"layer_{i}"]["mlp"]["gating_einsum"][0]
          ).tag("embedding", "neurons"),
          f"block_{i}.mlp.value_linear.weights": pz.nx.NamedArray.wrap(
              ftp[f"layer_{i}"]["mlp"]["gating_einsum"][1]
          ).tag("embedding", "neurons"),
          f"block_{i}.mlp.out_linear.weights": pz.nx.NamedArray.wrap(
              ftp[f"layer_{i}"]["mlp"]["linear"]
          ).tag("neurons", "embedding"),
          f"block_{i}.attn.output.weights": pz.nx.NamedArray.wrap(
              ftp[f"layer_{i}"]["attn"]["attn_vec_einsum"]["w"]
          ).tag("heads", "projection", "embedding"),
      })
      if config.single_kv_head:
        parameter_mapping.update({
            f"block_{i}.attn.query.weights": pz.nx.NamedArray.wrap(
                ftp[f"layer_{i}"]["attn"]["q_einsum"]["w"]
            ).tag("heads", "embedding", "projection"),
            f"block_{i}.attn.key.weights": pz.nx.NamedArray.wrap(
                ftp[f"layer_{i}"]["attn"]["kv_einsum"]["w"][0, 0]
            ).tag("embedding", "projection"),
            f"block_{i}.attn.value.weights": pz.nx.NamedArray.wrap(
                ftp[f"layer_{i}"]["attn"]["kv_einsum"]["w"][1, 0]
            ).tag("embedding", "projection"),
        })
      else:
        parameter_mapping.update({
            f"block_{i}.attn.query.weights": pz.nx.NamedArray.wrap(
                ftp[f"layer_{i}"]["attn"]["qkv_einsum"]["w"][0]
            ).tag("heads", "embedding", "projection"),
            f"block_{i}.attn.key.weights": pz.nx.NamedArray.wrap(
                ftp[f"layer_{i}"]["attn"]["qkv_einsum"]["w"][1]
            ).tag("heads", "embedding", "projection"),
            f"block_{i}.attn.value.weights": pz.nx.NamedArray.wrap(
                ftp[f"layer_{i}"]["attn"]["qkv_einsum"]["w"][2]
            ).tag("heads", "embedding", "projection"),
        })
    return (
        model_def.select()
        .at_instances_of(pz.nn.UninitializedParameter)
        .apply(
            lambda param: param.initialize_with_value(
                parameter_mapping[param.name],
                strict_dtype=False,
            )
        )
    )
pz_gemma_model = GemmaTransformer.from_pretrained(params)

We can now look inside the structure of the full pretrained model:

%%autovisualize
pz_gemma_model

And run it on the input:

jax.config.update("jax_traceback_filtering", 'off')
%%autovisualize
pz_gemma_output = pz_gemma_model(GemmaInputs(
    tokens=pz.nx.wrap(tokens[None, :]).tag("batch", "seq"),
    positions=pz.nx.wrap(positions).tag("batch", "seq"),
    attention_mask=pz.nx.wrap(attention_mask).tag("batch", "seq", "kv_seq"),
))
pz_gemma_output
chex.assert_trees_all_close(
    flax_gemma_output,
    pz_gemma_output.unwrap("batch", "seq", "vocabulary"),
)

And there we have it! A full port of Gemma to Penzai.

Since all the internals are exposed, we can easily inspect arbitrary parts of this model or insert arbitrary logic. For instance, let’s insert something in the middle of one of the MLPs:

# Copied by clicking in the above visualization:
selector_fn = (lambda root: root.body.body.body.sublayers[13].sublayers[1].delta.sublayers[1].sublayers[1])
# Insert our intermediate-printing "Hello World" layer from Section 1:
patched_model = pz_gemma_model.select().at(selector_fn).insert_before(HelloWorld())
# Look at it:
patched_model.select().at_instances_of(HelloWorld)
%%autovisualize
# Run the patched model:
patched_model(GemmaInputs(
    tokens=pz.nx.wrap(tokens[None, :]).tag("batch", "seq"),
    positions=pz.nx.wrap(positions).tag("batch", "seq"),
    attention_mask=pz.nx.wrap(attention_mask).tag("batch", "seq", "kv_seq"),
))

patched_model is an immutable copy of pz_gemma_model that includes our patching logic. Because of the functional nature of JAX (and Penzai), you never have to worry about rolling back patches or modifying hooks. And JAX automatically shares the array memory between the parameters of the two models.
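
If you want to convince yourself of the memory sharing, here’s a rough sanity check (a sketch, assuming the inserted HelloWorld layer carries no array attributes of its own): every array leaf of the patched model should be one of the original model’s array leaves rather than a copy.

original_ids = {
    id(leaf) for leaf in jax.tree_util.tree_leaves(pz_gemma_model)
    if isinstance(leaf, jax.Array)
}
# True if the patched model reuses exactly the original parameter buffers:
all(
    id(leaf) in original_ids
    for leaf in jax.tree_util.tree_leaves(patched_model)
    if isinstance(leaf, jax.Array)
)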

Section 3: Adding support for KV-Caching#

So far, we’ve focused on the ordinary transformer forward pass, which would be used during training and for log-probability scoring. However, the Flax implementation of Gemma also allows you to do autoregressive sampling using key-value caching. In this section, we’ll show how to add key-value caching to the model we’ve built so far, making it possible to efficiently run autoregressive sampling.

# Reload everything to ensure we have enough memory (for a TPU v2 kernel).
for array in jax.live_arrays():
  array.delete()

pz_gemma_model = GemmaTransformer.from_pretrained(
    gemma.params.nest_params(
        gemma.params.param_remapper(
            checkpointer.restore(ckpt_path, restore_args=restore_args)
        )
    )
)

tokens = jnp.array([vocab.bos_id()] + vocab.EncodeAsIds(example_input))
positions, attention_mask = get_attention_mask_and_positions(tokens[None, :], vocab.pad_id())

Penzai best practice: No conditional branching#

Recall again the implementation of the attention block in the Flax version of Gemma, excerpted below:

class Attention(nn.Module):
  """Attention module."""
  ...

  def __call__(
      self,
      x: jax.Array,
      segment_pos: jax.Array,
      cache: LayerCache | None,
      attn_mask: jax.Array,
  ) -> tuple[LayerCache | None, jax.Array]:
    seq_len = x.shape[1]

    if self.use_qkv_einsum:
      query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x)
    else:
      query_proj = self.q_einsum('BTD,NDH->BTNH', x)
      key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)

    query_proj = positional_embeddings.apply_rope(
        query_proj, segment_pos, head_dim=self.head_dim,
    )
    query_scaled = query_proj * self.head_dim**-0.5
    key_proj = positional_embeddings.apply_rope(
        key_proj, segment_pos, head_dim=self.head_dim,
    )

    if not self.use_qkv_einsum:
      value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
      key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)

    if cache is not None:
      end_index = cache['end_index'][0]
      slice_indices = (0, end_index % cache['v'].shape[1], 0, 0)
      value_proj = jax.lax.dynamic_update_slice(
          cache['v'], value_proj, slice_indices,
      )
      key_proj = jax.lax.dynamic_update_slice(
          cache['k'], key_proj, slice_indices
      )

    logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
    padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
    probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
    encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
    attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)

    if cache is not None:
      new_cache = {'v': value_proj, 'k': key_proj, 'end_index': cache['end_index'] + seq_len}
    else:
      new_cache = None

    return new_cache, attn_output

This implementation uses conditional branching in two places:

  • Depending on use_qkv_einsum, weights are either combined for queries, keys, and values, or kept separate.

  • Depending on whether the cache argument is passed, the key_proj and value_proj variables are either taken from the projection heads, or combined with the cache. This also determines whether a new cache is returned.

This is fine for the Flax implementation, because the structure of the computation primarily lives in the code. However, this is not a common pattern in Penzai. In general, idiomatic Penzai models should not contain Python conditional branches in their __call__ (except perhaps for shape checking or other assertions). This is a consequence of more general principles:

  • Idiomatic Penzai models should be “what you see is what you get”; it should be obvious what a model is going to do just by looking at its structure. Python conditional branching obscures this.

  • Idiomatic Penzai layers should be easy to patch and manipulate, which is easier if each layer does only one thing. Conditional branches usually mean your layer is doing multiple things depending on its configuration arguments.

  • JAX functions are usually JIT-compiled without Python control flow, and it’s often useful for Penzai models to stay close to their JAX lowerings.

How, then, do you get a single model to do multiple things, such as running in both training and kv-cache-based-sampling modes? Trick question! In general, you shouldn’t have to do this in Penzai. Instead, the Penzai approach is to create a copy of the model that does a different thing.

In fact, we’ve already seen this pattern at work! Every time we inserted logic to capture or print out intermediate values, we were making a copy that does a different thing. Those patched models didn’t use a conditional branch to decide whether or not to output intermediates. Instead, we just built new models that always output intermediates, by directly rewriting the model structure.

Penzai’s approach is centered on hot-swapping. Since layers make minimal assumptions about their children, we can implement different behaviors using different layer classes that have the same input and output structures. We can then have each of these implementations do a single thing, and still easily swap between the different implementations.
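
Schematically, hot-swapping is just a tree rewrite: select every instance of one layer class and replace it with an instance of another class that has the same input and output structure. Here’s a toy sketch (the ToyDouble and ToyHalve classes are made up for illustration, not part of the Gemma model):

@pz.pytree_dataclass
class ToyDouble(pz.Layer):
  def __call__(self, x):
    return x * 2


@pz.pytree_dataclass
class ToyHalve(pz.Layer):
  def __call__(self, x):
    return x / 2


toy_model = pz.nn.Sequential([ToyDouble(), ToyDouble()])
# Swap every ToyDouble for a ToyHalve; the rest of the model is untouched.
swapped_model = (
    toy_model.select()
    .at_instances_of(ToyDouble)
    .apply(lambda _old: ToyHalve())
)
swapped_model(jnp.array(8.0))  # 8 / 2 / 2 = 2.0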

In the case of our transformer, there are only two layers whose implementation needs to change in KV-cache mode:

  • The attention layer needs to be able to retrieve the keys and values from past tokens, compute attention over them, and update the key-value caches.

  • The top-level transformer wrapper needs to maintain the updated caches and make sure they stay in sync with the inputs.

We’ll assume the caller has correctly set up the positional embedding and attention mask side inputs so that they correctly reflect the current offsets. We’ll also assume the input still has a “seq” axis. This can be of length 1 if we are sampling one at a time, but it can also be longer if we are prefilling the cache with a prompt.

Let’s get started!

Adapting the attention block#

The attention block is where most of the interesting work will happen. Our high-level goal is to build a new version of the GemmaAttention layer that is a drop-in replacement, but runs the KV caching logic instead of the ordinary logic. To handle the mutable KV cache, we’ll use another one of Penzai’s effects, LocalState, which allows us to get and set mutable state variables in a functional way.
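
To build intuition for the state effect before we use it in the attention layer, here’s a minimal standalone sketch (the CountCalls layer is a made-up example): a layer requests a state variable with an initializer, and handle_local_states converts the whole thing into a pure function that threads a state dictionary through each call.

@pz.pytree_dataclass
class CountCalls(pz.Layer):
  count: pz.de.LocalStateEffect[jax.Array]

  def __call__(self, x):
    # Read the current count, bump it, and pass the input through unchanged.
    n = self.count.get()
    self.count.set(n + 1)
    return x


counter_def = CountCalls(
    count=pz.de.InitialLocalStateRequest(
        lambda: jnp.zeros((), dtype=jnp.int32), category="counter"
    )
)
stateless_counter, counter_state = pz.de.handle_local_states(
    counter_def, category="counter"
)
out, counter_state = stateless_counter((jnp.ones(3), counter_state))
out, counter_state = stateless_counter((out, counter_state))  # count is now 2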

We’ll re-use each of the same child blocks from the original GemmaAttention layer. None of them need to change; the only difference is that the key and value inputs to the query_key_to_attn and attn_value_to_output children will now come from the cache, so their sequence axis will be the full cache length rather than just the length of the current input chunk.

@pz.pytree_dataclass
class GemmaKVCachingAttention(pz.Layer):
  # Same as in GemmaAttention
  input_to_query: pz.LayerLike
  input_to_key: pz.LayerLike
  input_to_value: pz.LayerLike
  query_key_to_attn: pz.LayerLike
  attn_value_to_output: pz.LayerLike

  # New effects:
  kv_cache_end_index: pz.de.SideInputEffect[jax.Array]
  kv_cache: pz.de.LocalStateEffect[tuple[pz.nx.NamedArray, pz.nx.NamedArray]]

  def __call__(self, x: pz.nx.NamedArray) -> pz.nx.NamedArray:
    # Retrieve effectful inputs.
    kvc_end_index = self.kv_cache_end_index.ask()
    key_cache, value_cache = self.kv_cache.get()

    # Compute queries, keys, and values as normal.
    query = self.input_to_query(x)
    key = self.input_to_key(x)
    value = self.input_to_value(x)

    # Update the KV caches.
    new_key_cache = (
        pz.nx.nmap(jax.lax.dynamic_update_slice)(
            key_cache.untag("seq"),
            key.untag("seq"),
            (kvc_end_index,),
        )
        .tag("seq")
    )
    new_value_cache = (
        pz.nx.nmap(jax.lax.dynamic_update_slice)(
            value_cache.untag("seq"),
            value.untag("seq"),
            (kvc_end_index,),
        )
        .tag("seq")
    )
    self.kv_cache.set((new_key_cache, new_value_cache))

    # Run the rest on the updated KV caches.
    attn = self.query_key_to_attn((query, new_key_cache))
    output = self.attn_value_to_output((attn, new_value_cache))
    return output

  @classmethod
  def from_uncached(
      cls,
      original: GemmaAttention,
      cache_len: int,
      cached_axes: dict[str, int],  # <- We need this to initialize the cache.
      cache_dtype: jax.typing.DTypeLike = jnp.float32,
  ) -> GemmaKVCachingAttention:
    """Builds a cached attention from an uncached attention."""

    # Each layer that requests a state variable has to declare an initializer
    # (or a concrete initial state) for it at the time that it's built.
    def kv_cache_initializer():
      empty_cache = pz.nx.zeros(
          {**cached_axes, "seq": cache_len},
          dtype=cache_dtype,
      )
      return (empty_cache, empty_cache)

    return GemmaKVCachingAttention(
        input_to_query=original.input_to_query,
        input_to_key=original.input_to_key,
        input_to_value=original.input_to_value,
        query_key_to_attn=original.query_key_to_attn,
        attn_value_to_output=original.attn_value_to_output,
        kv_cache_end_index=pz.de.SideInputRequest("cache_end_index"),
        kv_cache=pz.de.InitialLocalStateRequest(
            kv_cache_initializer, category="kv_cache",
        ),
    )
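
The cache update above relies on jax.lax.dynamic_update_slice, which writes an update array into a larger buffer starting at a (possibly traced) offset. A tiny sketch of the pattern:

buffer = jnp.zeros((6,))
# Write a length-2 update into the buffer starting at index 3:
jax.lax.dynamic_update_slice(buffer, jnp.array([1.0, 2.0]), (3,))
# -> [0., 0., 0., 1., 2., 0.]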

Hot-swapping attention blocks#

We’ve set up our caching attention block so that it can instantiate itself as a copy of an initialized non-caching GemmaAttention block, so let’s try grabbing one of the GemmaAttention blocks from the original transformer, and seeing if we can run it in cached decoding mode.

We have to be a bit careful here, because we’ve already handled the SideInput effects in our Penzai Gemma model, and if we just remove a GemmaAttention block from the model we’ll break the link between the handler and its references:

selector_fn = (lambda root: root.body.body.body.sublayers[2].sublayers[0].delta.sublayers[1])
selector_fn(pz_gemma_model)

Treescope shows us this with a “Broken handler refs” message.

One option is to manually go in and replace these with unhandled effects by running something like

(
    pz.select(selector_fn(pz_gemma_model))
    .at_instances_of(pz.de.HandledSideInputRef)
    .where(lambda ref: ref.handler_id == 'WithSideInputsFromInputTuple_bc3f5')
    .apply(lambda ref: pz.de.SideInputRequest(tag=ref.tag))
)

This would let you handle them again normally.

In this case, though, it’s easier to use Penzai’s built-in tools for capturing intermediate values, which lets us re-play the side inputs exactly as they appeared in the original model, and also lets us capture the input and output embeddings for this attention layer in the process.

safely_extracted_attn = isolate_submodel.call_and_extract_submodel(
    pz.select(pz_gemma_model).at(selector_fn),
    GemmaInputs(
        tokens=pz.nx.wrap(tokens[None, :]).tag("batch", "seq"),
        positions=pz.nx.wrap(positions).tag("batch", "seq"),
        attention_mask=pz.nx.wrap(attention_mask).tag("batch", "seq", "kv_seq"),
    )
)
%%autovisualize
safely_extracted_attn.select().at_instances_of(GemmaAttention).at_children().show_value()

As we can see above, the GemmaAttention block has been extracted and placed into a new WithSideInputsFromInputTuple handler, and we can see all of the side inputs in the saved_input attribute.

Let’s swap out this captured attention block with our stateful one:

swapped_out_attn = (
    pz.select(safely_extracted_attn.submodel)
    .at_instances_of(GemmaAttention)
    .apply(lambda attn: GemmaKVCachingAttention.from_uncached(
        attn,
        cache_len=58,
        cached_axes={
            "batch": 1,
            "projection":pz_gemma_model.config.projection_dim,
        },
        cache_dtype=jnp.bfloat16,
    ))
)
swapped_out_attn.select().at_instances_of(GemmaKVCachingAttention).at_children().show_value()

To actually run this layer, we now need to handle two more effects (shown as the “Unhandled effects” annotation above): the SideInputEffect for the new cache offset, and the LocalStateEffect for the key-value cache itself.

We’ve seen how to handle SideInputEffect, so let’s just get that out of the way:

caching_attn_with_index = pz.de.WithSideInputsFromInputTuple.handling(
    swapped_out_attn,
    tags=["cache_end_index"],
)

To handle the state effect, we need to convert our stateful model into a pure functional one. This means that, instead of mapping x -> y, it will map (x, state_dict) -> (y, state_dict). Since state requests have their own initializers, we can separate our layer into a stateless function and an initial state dict:

%%autovisualize
stateless_caching_attn_w_idx, state_dict = pz.de.handle_local_states(
    caching_attn_with_index, category="kv_cache"
)
print("Stateless model:")
pz.show(stateless_caching_attn_w_idx)
print("State dict:")
pz.show(state_dict)

We now need to call this layer with:

  • embeddings, an attention mask, and token positions (saved from the original model)

  • a cache offset (for our new side input handler)

  • and a dictionary of states (in this case only one)

Let’s try running the model on the first half of the tokens (i.e., pre-filling the cache):

%%autovisualize pz.ts.ArrayAutovisualizer(maximum_size=10_000)
cached_out, state_dict_1 = stateless_caching_attn_w_idx((
    (
        # The original input, which we'll slice along the `seq` axis.
        (
            pz.select(safely_extracted_attn.saved_input)
            .at_instances_of(pz.nx.NamedArray)
            .apply(lambda arr: arr[{"seq": pz.slice[0:29]}])
        ),
        # The current cache offset (zero since we are at the start.)
        0,
    ),
    # The initial state dictionary.
    state_dict,
))

print("Cached out:")
pz.show(cached_out)
print("New state dict:")
for k, v in state_dict_1.items():
  print(f"{k}:")
  pz.show(pz.ts.render_array(v[0], truncate=False))
  pz.show(pz.ts.render_array(v[1], truncate=False))

Now let’s run it again and watch the states update. We’ll slice the projection axis while we visualize it, so that we can focus on how the values change.

cur_state_dict = state_dict_1
for timestep in range(29, 35):
  step_out, cur_state_dict = stateless_caching_attn_w_idx((
      (
          # The original input, which we'll slice along the `seq` axis.
          (
              pz.select(safely_extracted_attn.saved_input)
              .at_instances_of(pz.nx.NamedArray)
              .apply(lambda arr: arr[{"seq": pz.slice[timestep:timestep+1]}])
          ),
          # The new cache offset.
          timestep,
      ),
      # The current state dictionary.
      cur_state_dict,
  ))

  key_cache, value_cache = cur_state_dict["WithSideInputsFromInputTuple.body/WithSideInputsFromInputTuple.body/GemmaKVCachingAttention.kv_cache"]

  pz.show(
      "Step", timestep, "keys:",
      pz.ts.render_array(key_cache[{"projection": pz.slice[0:5]}], truncate=False),
      "\nStep", timestep, "values:",
      pz.ts.render_array(value_cache[{"projection": pz.slice[0:5]}], truncate=False),
      "\n--------------------------------"
  )

pz.show(
    "Final key cache (first few projection dimensions):",
    pz.ts.render_array(key_cache[{"projection": pz.slice[0:5]}], truncate=False),
)

As desired, we’re able to update the keys and values one token at a time. And we obtain the correct slice of the output embeddings:

%%autovisualize
step_out
%%autovisualize
safely_extracted_attn.saved_output[{"seq": pz.slice[34:35]}]

Let’s move on!

del safely_extracted_attn, cur_state_dict, key_cache, value_cache, cached_out, state_dict_1
gc.collect()

Adapting the top-level Transformer model#

We can now wrap this up in a convenient top-level interface by defining a new wrapper class, and providing a constructor method using the same hot-swapping strategy. Because we’re using hot swapping and Penzai’s effect system, we don’t have to thread anything through the transformer blocks; we just swap out the attention layers.

@pz.pytree_dataclass
class GemmaKVCachingState(pz.Struct):
  cache_len: int = field(metadata={"pytree_node": False})
  batch_axes: dict[str, int] = field(metadata={"pytree_node": False})
  kv_caches: dict[str, Any]
  cache_end_index: int | jax.Array


@pz.pytree_dataclass
class GemmaKVCachingInputs(pz.Struct):
  tokens: pz.nx.NamedArray
  positions: pz.nx.NamedArray
  attention_mask: pz.nx.NamedArray
  sampling_state: GemmaKVCachingState


@pz.pytree_dataclass
class GemmaKVCachingTransformer(pz.Layer):
  config: GemmaTransformerConfig = field(metadata={"pytree_node": False})
  body: pz.LayerLike

  def __call__(
      self, inputs: GemmaKVCachingInputs
  ) -> tuple[pz.nx.NamedArray, GemmaKVCachingState]:
    outs, kv_caches = self.body((
        (
            (inputs.tokens, inputs.positions, inputs.attention_mask),
            inputs.sampling_state.cache_end_index,
        ),
        inputs.sampling_state.kv_caches,
    ))
    return outs, GemmaKVCachingState(
        cache_len=inputs.sampling_state.cache_len,
        batch_axes=inputs.sampling_state.batch_axes,
        kv_caches=kv_caches,
        cache_end_index=(
            inputs.sampling_state.cache_end_index
            + inputs.tokens.named_shape["seq"]
        ),
    )

  @classmethod
  def from_uncached(
      cls,
      uncached: GemmaTransformer,
      cache_len: int,
      batch_axes: dict[str, int],  # <- We need this to initialize the cache.
  ) -> tuple[GemmaKVCachingTransformer, GemmaKVCachingState]:
    cached_axes = {
        **batch_axes,
        "projection": uncached.config.projection_dim,
    }
    if not uncached.config.single_kv_head:
      cached_axes["heads"] = uncached.config.num_heads
    caching_body = (
        pz.select(uncached.body)
        .at_instances_of(GemmaAttention)
        .apply(
            lambda attn: GemmaKVCachingAttention.from_uncached(
                attn,
                cache_len=cache_len,
                cached_axes=cached_axes,
                cache_dtype=uncached.config.dtype,
            )
        )
    )
    handled_body, initial_state = pz.de.handle_local_states(
        pz.de.WithSideInputsFromInputTuple.handling(
            caching_body, tags=["cache_end_index"]
        ),
        category="kv_cache",
    )
    inference_model = cls(config=uncached.config, body=handled_body)
    sampling_state = GemmaKVCachingState(
        cache_len=cache_len,
        batch_axes=batch_axes,
        kv_caches=initial_state,
        cache_end_index=0,
    )
    return inference_model, sampling_state
inference_gemma, initial_inference_state = GemmaKVCachingTransformer.from_uncached(
    pz_gemma_model,
    cache_len=58,
    batch_axes={"batch": 1},
)
inference_gemma
%%autovisualize
inference_gemma(GemmaKVCachingInputs(
    tokens=pz.nx.wrap(tokens[None, :]).tag("batch", "seq")[{"seq": pz.slice[0:28]}],
    positions=pz.nx.wrap(positions).tag("batch", "seq")[{"seq": pz.slice[0:28]}],
    attention_mask=pz.nx.wrap(attention_mask).tag("batch", "seq", "kv_seq")[{"seq": pz.slice[0:28]}],
    sampling_state=initial_inference_state,
))

The sampling loop#

Now that we have KV caching, we can use it as a building block for a sampling algorithm. Building a high-performance fully-featured sampler is out of scope for this tutorial, so let’s just implement something simple.

First, a prefilling function, which fills up a KV cache:

def prefill(
    model: GemmaKVCachingTransformer,
    initial_sampling_state: GemmaKVCachingState,
    prompt: pz.nx.NamedArray,
) -> tuple[pz.nx.NamedArray, GemmaKVCachingState]:
  # Token positions are just offsets along the token axis. For simplicity, we're
  # assuming there's no padding to worry about, and that the prompt is a fixed
  # length.
  query_positions = pz.nx.arange("seq", prompt.named_shape["seq"])
  # Tokens can attend to any kv-token position that they are after.
  key_value_positions = pz.nx.arange("kv_seq", initial_sampling_state.cache_len)
  attention_mask = query_positions >= key_value_positions
  # Run prefill:
  out_logits, new_state = model(GemmaKVCachingInputs(
      tokens=prompt,
      positions=query_positions,
      attention_mask=attention_mask,
      sampling_state=initial_sampling_state,
  ))
  # Extract the log probs from the final token, since that determines the probs
  # for the next sampled token.
  final_token_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      out_logits[{"seq": -1}].untag("vocabulary")
  ).tag("vocabulary")
  return final_token_log_probs, new_state
prompt_parts = [
    vocab.EncodeAsIds("Penzai includes a number of general-purpose tools for analyzing JAX neural networks. It also includes a declarative neural-network library"),
    vocab.EncodeAsIds("JAX is Autograd and XLA, brought together for high-performance numerical computing. JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers."),
    vocab.EncodeAsIds("Alice: Let's play 20 questions!\nBob: Sure! Is it something I'd find in a house?"),
    vocab.EncodeAsIds("from __future__ import annotations\nimport collections\nimport contextlib\nimport dataclasses\nimport functools\nimport itertools\nimport typing"),
]
prompt_parts = [ [vocab.bos_id()] + part[:24] for part in prompt_parts]
for part in prompt_parts:
  print(vocab.DecodeIds(part))
  print("-" * 80)
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
prompt = pz.nx.wrap(jnp.array(prompt_parts)).tag("batch", "seq")
prompt
inference_gemma, initial_inference_state = GemmaKVCachingTransformer.from_uncached(
    pz_gemma_model,
    cache_len=100,
    batch_axes={"batch": 4},
)
%%autovisualize
next_log_probs, sampling_state = prefill(inference_gemma, initial_inference_state, prompt)
{"next_log_probs":next_log_probs, "sampling_state":sampling_state}

Now let’s write a function that advances one token at a time:

def advance_one_token(
    model: GemmaKVCachingTransformer,
    state: GemmaKVCachingState,
    next_token: jax.Array,
) -> tuple[pz.nx.NamedArray, GemmaKVCachingState]:
  # Our query position is the current cache offset.
  query_positions = pz.nx.wrap(state.cache_end_index)[{"seq": np.newaxis}]
  # Tokens can attend to any kv-token position that they are after.
  key_value_positions = pz.nx.arange("kv_seq", state.cache_len)
  attention_mask = query_positions >= key_value_positions
  # Run and update just like before, but add a tokens axis:
  out_logits, new_state = model(GemmaKVCachingInputs(
      tokens=next_token[{"seq": np.newaxis}],
      positions=query_positions,
      attention_mask=attention_mask,
      sampling_state=state,
  ))
  # Extract the log probs from this token.
  final_token_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      out_logits.untag("seq").squeeze(0).untag("vocabulary")
  ).tag("vocabulary")
  return final_token_log_probs, new_state
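
As a quick illustration of the {"seq": np.newaxis} indexing used above (this snippet is only for demonstration, and demo_next_token is not used elsewhere): indexing a named array with np.newaxis under a new name adds a length-1 named axis, which is how we turn the per-batch next token into a length-1 "seq" axis.

# Illustrative only: adding a length-1 named "seq" axis via np.newaxis.
demo_next_token = pz.nx.wrap(jnp.array([5, 7, 9, 11])).tag("batch")
# Expected to contain both axes: a "batch" axis of size 4 and a "seq" axis of size 1.
demo_next_token[{"seq": np.newaxis}].named_shape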

And now we can do a simple iterative loop to run our sampling.

So far we’ve been running things in JAX’s eager mode, but we can easily JIT-compile the computation as well. Since every Penzai layer, input, and output is a PyTree, we can just wrap our model in a Jitted combinator and it all just works. We can even still look inside the model, because Jitted is also just a PyTree:

from penzai.toolshed import jit_wrapper
inference_gemma_jit = jit_wrapper.Jitted(inference_gemma)
inference_gemma_jit
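
If you want to double-check that wrapping the model in Jitted doesn't change its behavior, one option is to compare the prefill outputs of the eager and JIT-compiled models. This is an optional sanity check, not part of the sampling flow, and the eager_log_probs / jitted_log_probs names are introduced here just for this comparison; the difference should be small, with the exact magnitude depending on the bfloat16 precision of the weights.

# Optional sanity check: eager and JIT-compiled prefill should agree closely.
eager_log_probs, _ = prefill(inference_gemma, initial_inference_state, prompt)
jitted_log_probs, _ = prefill(inference_gemma_jit, initial_inference_state, prompt)
# Maximum absolute difference between the two sets of final-token log-probs:
float(jnp.abs(
    eager_log_probs.unwrap("batch", "vocabulary")
    - jitted_log_probs.unwrap("batch", "vocabulary")
).max())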
rng = jax.random.key(0)
next_log_probs, sampling_state = prefill(inference_gemma_jit, initial_inference_state, prompt)
outputs = []

while True:
  rng, key = jax.random.split(rng)
  # Split a key across named axes:
  batched_keys = pz.nx.random_split(key, sampling_state.batch_axes)
  next_token = pz.nx.nmap(jax.random.categorical)(
      batched_keys, next_log_probs.untag("vocabulary")
  )
  print([vocab.IdToPiece(int(tok)) for tok in next_token.unwrap("batch").tolist()], end=" ")
  outputs.append(next_token)
  # Are we done?
  if sampling_state.cache_end_index >= sampling_state.cache_len:
    break
  next_log_probs, sampling_state = advance_one_token(inference_gemma_jit, sampling_state, next_token)

stacked_outputs = pz.nx.stack(outputs, "seq")
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
stacked_outputs

Let’s see what Gemma 2B thinks the completions of our prompts could be:

for i in range(4):
  prompt_str = vocab.DecodeIds(prompt.unwrap("batch", "seq")[i, :].tolist())
  completion_str = vocab.DecodeIds(stacked_outputs.unwrap("batch", "seq")[i, :].tolist())
  pz.show(pz.ts.bolded(prompt_str), wrap=True)
  pz.show(pz.ts.with_color(completion_str, "blue"), wrap=True)

Seems like it’s sampling something reasonable! This is a pretty small model and we’re using temperature 1, so these aren’t the highest-quality samples. But it at least means we’ve probably implemented sampling correctly.
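
As an aside, the sampling temperature is easy to adjust in this loop: dividing the logits (or log-probabilities) by a temperature below 1 before calling jax.random.categorical sharpens the distribution and typically trades diversity for quality. The snippet below is a hedged sketch of how the sampling step in the loop above could be modified; the temperature value is arbitrary and this variant is not used in the rest of the notebook.

# Hypothetical variant of the sampling step with temperature scaling.
temperature = 0.7  # values below 1 sharpen the distribution
next_token = pz.nx.nmap(jax.random.categorical)(
    batched_keys, (next_log_probs / temperature).untag("vocabulary")
)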

Intervening on sampling through patching#

Because of the declarative, functional design of our Gemma reimplementation, we can still look at and intervene on intermediate values even in our stateful JIT-compiled model! This means it’s simple to try out modifications that would require a lot of work to set up in other sampling implementations.

As an example, let’s try knocking out a subset of the attention heads by forcing those heads to only attend to the beginning-of-sequence token. First, we define a simple layer that does the modification we want:

@pz.pytree_dataclass
class KnockOutAttentionHeads(pz.Layer):
  head_mask: pz.nx.NamedArray
  def __call__(self, attn_weights: pz.nx.NamedArray) -> pz.nx.NamedArray:
    # Build a "knocked-out" attention pattern that places all of its weight on
    # kv position 0, i.e. the beginning-of-sequence token.
    knocked_out_attn = pz.nx.wrap(
        jnp.zeros(
            [attn_weights.named_shape["kv_seq"]],
            attn_weights.dtype,
        ).at[0].set(1.0)
    ).tag("kv_seq")
    # Where head_mask is 1, keep the original attention weights; where it is 0,
    # substitute the knocked-out pattern for that head.
    return pz.nx.nmap(jnp.where)(self.head_mask, attn_weights, knocked_out_attn)
%%autovisualize
KnockOutAttentionHeads(
    head_mask=pz.nx.wrap(jnp.array([1,0,1,0,1,0,1,0])).tag("heads")
)(saved_attention_pattern)

Then we’ll find the places we want to insert it. Let’s knock out all the heads in some of the middle layers:

selection = (
    pz.select(inference_gemma_jit)
    .at_instances_of(GemmaKVCachingAttention)
    .pick_nth_selected((7, 8, 9))
    .at(lambda attn: attn.query_key_to_attn.sublayers[-1])
)
selection

And create a patched copy of our inference model that includes our modification:

patched_inference = (
    selection
    .insert_after(KnockOutAttentionHeads(pz.nx.zeros({"heads": 8})))
)
pz.select(patched_inference).at_instances_of(KnockOutAttentionHeads).show_value()

Now let’s run our sampling loop again, but using our patched copy:

rng = jax.random.key(0)
next_log_probs, sampling_state = prefill(patched_inference, initial_inference_state, prompt)
outputs = []

while True:
  rng, key = jax.random.split(rng)
  # Split a key across named axes:
  batched_keys = pz.nx.random_split(key, sampling_state.batch_axes)
  next_token = pz.nx.nmap(jax.random.categorical)(
      batched_keys, next_log_probs.untag("vocabulary")
  )
  print([vocab.IdToPiece(int(tok)) for tok in next_token.unwrap("batch").tolist()], end=" ")
  outputs.append(next_token)
  # Are we done?
  if sampling_state.cache_end_index >= sampling_state.cache_len:
    break
  next_log_probs, sampling_state = advance_one_token(patched_inference, sampling_state, next_token)

stacked_outputs = pz.nx.stack(outputs, "seq")
for i in range(4):
  prompt_str = vocab.DecodeIds(prompt.unwrap("batch", "seq")[i, :].tolist())
  completion_str = vocab.DecodeIds(stacked_outputs.unwrap("batch", "seq")[i, :].tolist())
  pz.show(pz.ts.bolded(prompt_str), wrap=True)
  pz.show(pz.ts.with_color(completion_str, "blue"), wrap=True)

The samples are generally degraded, but still somewhat reasonable.

What if we instead knock out all the heads in the earliest attention layers?

def go_sample(patched_inference):
  rng = jax.random.key(0)
  next_log_probs, sampling_state = prefill(patched_inference, initial_inference_state, prompt)
  outputs = []

  while True:
    rng, key = jax.random.split(rng)
    batched_keys = pz.nx.random_split(key, sampling_state.batch_axes)
    next_token = pz.nx.nmap(jax.random.categorical)(
        batched_keys, next_log_probs.untag("vocabulary")
    )
    outputs.append(next_token)
    if sampling_state.cache_end_index >= sampling_state.cache_len:
      break
    next_log_probs, sampling_state = advance_one_token(patched_inference, sampling_state, next_token)

  stacked_outputs = pz.nx.stack(outputs, "seq")

  for i in range(4):
    prompt_str = vocab.DecodeIds(prompt.unwrap("batch", "seq")[i, :].tolist())
    completion_str = vocab.DecodeIds(stacked_outputs.unwrap("batch", "seq")[i, :].tolist())
    pz.show(pz.ts.bolded(prompt_str), wrap=True)
    pz.show(pz.ts.with_color(completion_str, "blue"), wrap=True)
patched_inference = (
    pz.select(inference_gemma_jit)
    .at_instances_of(GemmaKVCachingAttention)
    .pick_nth_selected((0, 1, 2))
    .at(lambda attn: attn.query_key_to_attn.sublayers[-1])
    .insert_after(KnockOutAttentionHeads(pz.nx.zeros({"heads": 8})))
)
go_sample(patched_inference)

Knocking out these heads seems to severely affect the generated samples, suggesting that the model is relying heavily on them to parse the prompt.

What about the heads in the last few attention layers?

patched_inference = (
    pz.select(inference_gemma_jit)
    .at_instances_of(GemmaKVCachingAttention)
    .pick_nth_selected((15, 16, 17))
    .at(lambda attn: attn.query_key_to_attn.sublayers[-1])
    .insert_after(KnockOutAttentionHeads(pz.nx.zeros({"heads": 8})))
)
go_sample(patched_inference)

Interestingly, knocking out the heads in the last few attention layers seems to preserve the local coherence of the samples quite well, but the content drifts somewhat more strongly.

What about knocking out a quarter of the heads in every layer of the model?

patched_inference = (
    pz.select(inference_gemma_jit)
    .at_instances_of(GemmaKVCachingAttention)
    .at(lambda attn: attn.query_key_to_attn.sublayers[-1])
    .insert_after(KnockOutAttentionHeads(pz.nx.wrap(jnp.array([1,1,1,1,1,1,0,0])).tag("heads")))
)
go_sample(patched_inference)

These particular interventions aren’t controlled enough to let us say anything definitive about what the model is doing. But this example demonstrates how Penzai’s design principles make it possible to quickly patch model behavior, even at inference time. Our simple KV-caching logic immediately supports arbitrary interventions on the sampling process, without requiring us to pre-define what changes we want to make or to thread those changes through the model’s code.

Note: The structure of penzai.example_models.gemma#

In addition to this notebook, Penzai also includes an implementation of Gemma in penzai.example_models.gemma, which is the recommended implementation to use when experimenting with Gemma using Penzai. This implementation is very similar to the implementation given here, with a few minor differences:

  • The classes BranchAndMultiplyTogether, ApplyRoPE, ApplyAttentionMask, RMSStandardize, RMSLayerNorm, EmbeddingTable, EmbeddingLookup, and EmbeddingDecode have been moved into the Penzai standard library pz.nn, since they are not Gemma-specific.

  • The class GemmaAttention has been split into a base Attention combinator and a GemmaAttention subclass, to decouple the basic Attention control flow from the specific initialization logic of GemmaAttention.

  • Similarly, the class GemmaKVCachingAttention has been split into a KVCachingAttention class in the standard library and a more specific GemmaKVCachingAttention layer. The base class takes an extra axis-name attribute so that it does not make assumptions about particular axis names.

  • The top-level GemmaTransformer class’s from_pretrained method has been modified to not depend on the Flax implementation’s config, and instead loads directly from the flat_params checkpoint.

  • The top-level GemmaTransformer also supports computing activations in float32 even when weights are bfloat16, to allow studying the activations in more detail.

  • The GemmaInputs and GemmaKVCachingInputs input structures have been extended with convenience methods to make them easier to set up for simple cases.

  • The prefill and advance_one_token functions have been extended to also support padding tokens at the end of the prompt, which makes the attention mask and position computations slightly more complex. To track this, they use a SamplingState class that adds a few fields beyond GemmaKVCachingState.

  • Docstrings have been added with additional information on the design of each component.