Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Gemma From Scratch: Inspecting and Re-implementing Gemma with Penzai

Penzai includes a number of general-purpose tools for analyzing JAX neural networks. It also includes a declarative neural-network library designed to take advantage of those tools. This notebook demonstrates how to apply this tooling to a real-world neural network: the Gemma pretrained transformer model, implemented in Flax.

You might benefit from reading this notebook if any of these apply:

  • You are interested in learning about Penzai’s design principles, especially if you are already familiar with Flax.

  • You want to reverse-engineer a model that is currently written in Flax using Penzai’s tools.

  • You want to implement a model in Penzai, and would like to learn about the best practices for model development.

  • You want to learn more about the Gemma implementation included in Penzai, which is used by the other tutorial notebooks.

This notebook is broken into three main sections plus a setup section.

  • In Section 0, we set up the environment and load the Gemma weights for further analysis.

  • In Section 1, we show how to apply Penzai’s analysis and visualization tooling to the official Flax implementation of Gemma, and how to convert it into a Penzai-compatible form. We also discuss the high-level differences between Flax and Penzai.

  • In Section 2, we break down Gemma’s training/scoring mode into its constituent pieces, and show how to re-implement each of those pieces in idiomatic Penzai style. We also discuss how Penzai enables easy inspection of intermediate values, and use intermediate values to test the implementation of each piece.

  • In Section 3, we apply the same decomposition to the stateful key-value-caching mode, and demonstrate how stateful operations are represented in Penzai models using Penzai’s “data effects” system. We use this to implement a simple JIT-compiled sampler that still supports patching intermediate activations.

Section 0: Setup

In this section, we’ll start by setting up our environment and loading the Gemma model.

Setting up the environment

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install Penzai and the Gemma reference implementation using the following cell:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
try:
  import gemma
except ImportError:
  !pip install "gemma @ git+https://www.github.com/google-deepmind/gemma.git"
from __future__ import annotations
from typing import Any

import os
import dataclasses
import traceback
import functools
import gc

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import chex

import flax.linen
from jax.experimental import mesh_utils
import sentencepiece as spm
import gemma.params
import gemma.sampler
import gemma.transformer
import penzai
from penzai import pz
import penzai.toolshed.unflaxify
import penzai.toolshed.isolate_submodel

Loading Gemma

Next we can load the Gemma model, using its official Flax reference implementation. We’ll use Gemma 2B for this notebook.

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don’t have an API key already, you can:

  1. Visit https://www.kaggle.com/ and create an account if needed.

  2. Go to your account settings, then the ‘API’ section.

  3. Click ‘Create new token’ to download your key.

Next, if you are running this notebook in Google Colab:

  1. Click the “key” symbol on the left toolbar to open the “Secrets” tab.

  2. Add two new secrets, named “KAGGLE_USERNAME” and “KAGGLE_KEY”, and set their values based on the API key you downloaded.

  3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

import kagglehub
try:
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  kagglehub.login()

If everything went well, you should see:

Kaggle credentials set.

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven’t done that yet, you can do so here:

https://www.kaggle.com/models/google/gemma/license/consent

(Make sure you choose to “Verify via Kaggle Account” with the same account you used to log in above!)

Once you’ve agreed to the terms, you can run the next cell to download the Gemma weights:

weights_dir = kagglehub.model_download('google/gemma/Flax/2b')
ckpt_path = os.path.join(weights_dir, '2b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using orbax.

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
structure = checkpointer.metadata(ckpt_path)

sharding = jax.sharding.SingleDeviceSharding(jax.local_devices()[0])
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array, sharding=sharding
    ),
    structure,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
params = gemma.params.nest_params(
    gemma.params.param_remapper(flat_params)
)
del flat_params
flax_gemma_config = gemma.transformer.TransformerConfig.from_params(
    params, cache_size=1024
)
flax_gemma = gemma.transformer.Transformer(flax_gemma_config)
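The nest_params call above converts the flat checkpoint dictionary into the nested dictionaries that the Flax module expects. Assuming the flat keys are "/"-separated module paths (for example "transformer/layer_0/attn/q_einsum"), the nesting step can be sketched in plain Python as follows. nest_flat_params is a hypothetical helper written for illustration, not the gemma library's code, and strings stand in for the weight arrays:

```python
def nest_flat_params(flat: dict) -> dict:
  """Hypothetical helper: turns {'a/b/c': v} into {'a': {'b': {'c': v}}}."""
  nested: dict = {}
  for path, value in flat.items():
    *parents, leaf = path.split("/")
    node = nested
    for key in parents:
      # Create intermediate dictionaries the first time a prefix is seen.
      node = node.setdefault(key, {})
    node[leaf] = value
  return nested


# Made-up flat checkpoint; strings stand in for arrays.
flat = {
    "transformer/layer_0/attn/q_einsum": {"w": "Q0"},
    "transformer/layer_0/mlp/gating_einsum": "G0",
    "transformer/embedder": {"input_embedding": "E"},
}
nested = nest_flat_params(flat)
print(nested["transformer"]["layer_0"]["attn"]["q_einsum"]["w"])  # Q0
```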

Section 1: Analyzing the Flax model with Penzai

In this section, we’ll give an overview of Penzai’s analysis and visualization tooling and demonstrate how to apply it to existing Flax models.

Looking at the model and its weights

Penzai ships with a powerful pretty-printer and array visualizer designed to help you quickly navigate through and understand the structure of large trees. If you’ve used Colab or Jupyter before, you may be familiar with printouts that look like this:

print(repr(params))

Penzai provides an interactive alternative, treescope, which lets you fold and unfold the children of deep trees like this one. Try clicking on the gray triangle markers to expand or collapse subtrees!

pz.show(params)
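To get a feel for what folding does, here is a rough, non-interactive text analogue: a hypothetical summarize helper (not part of Penzai or treescope) that expands nested dictionaries down to a chosen depth, collapses deeper subtrees to {...}, and shows array leaves only by dtype and shape. The toy_params tree is made up for illustration:

```python
import numpy as np


def summarize(tree: dict, max_depth: int, indent: str = "") -> list:
  """Hypothetical text analogue of folding: expands dict keys down to
  `max_depth` levels and collapses anything deeper into '{...}'."""
  lines = []
  for key, value in tree.items():
    if isinstance(value, dict):
      if max_depth <= 1:
        # "Folded" subtree: hide its children.
        lines.append(f"{indent}{key}: {{...}}")
      else:
        # "Unfolded" subtree: recurse one level deeper.
        lines.append(f"{indent}{key}:")
        lines.extend(summarize(value, max_depth - 1, indent + "  "))
    else:
      # Array leaves are shown by dtype and shape only.
      lines.append(f"{indent}{key}: {value.dtype}{list(value.shape)}")
  return lines


toy_params = {
    "transformer": {
        "embedder": {"input_embedding": np.zeros((16, 4), np.float32)},
        "layer_0": {"attn": {"q_einsum": {"w": np.zeros((2, 4, 3), np.float32)}}},
    }
}
print("\n".join(summarize(toy_params, max_depth=2)))
```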

Let’s make penzai.treescope the default Colab/IPython pretty-printer, and also register its autovisualization magic:

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()

Now everything we return from a Colab cell will be interactively pretty-printed:

params["transformer"]["layer_0"]

penzai.treescope also includes an n-dimensional array visualizer, which can help you understand the shape and content of arrays at a glance. You can hover or click on the cells of the visualization to inspect individual array elements.

Note that the truncate=True argument tells the visualizer to cut out the middle elements of each array to keep the visualization a reasonable size. This is similar to how printing a large array inserts ... in place of its middle elements.

pz.ts.render_array(params["transformer"]["layer_0"]["attn"]["q_einsum"]["w"], truncate=True)
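The truncation idea can be sketched in NumPy. The truncate_middle helper below is hypothetical and is not how treescope implements truncate=True; it simply keeps the first and last few entries along each long axis. The final line uses NumPy's own summarized printing (the threshold and edgeitems options of np.array2string) for comparison:

```python
import numpy as np


def truncate_middle(x: np.ndarray, edge_items: int) -> np.ndarray:
  """Hypothetical helper: keeps the first and last `edge_items` entries along
  every axis longer than 2 * edge_items, dropping the middle."""
  for axis, size in enumerate(x.shape):
    if size > 2 * edge_items:
      keep = np.r_[0:edge_items, size - edge_items:size]
      x = np.take(x, keep, axis=axis)
  return x


toy_w = np.arange(2048 * 256).reshape(2048, 256)
small = truncate_middle(toy_w, edge_items=3)
print(small.shape)  # (6, 6)

# NumPy's printer summarizes large arrays with "..." in a similar spirit.
print(np.array2string(np.arange(2000), threshold=10, edgeitems=3))
```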

Since we ran register_autovisualize_magic above, we can also automatically visualize arrays whenever we return something from a Colab cell using the %%autovisualize magic command. Treescope will automatically insert these visualizations inside the rendered tree itself and let you expand them as desired. Try clicking the triangles to look at different weights!

%%autovisualize
params["transformer"]["layer_0"]

We can also use treescope to print out the model itself. However, Flax models don’t show much about themselves when you construct them. In this case, the transformer model is a Python dataclass whose attributes are just its configuration:

flax_gemma

(You get something similar if you print it out without Treescope:)

print(flax_gemma)

Unfortunately, you can’t directly see the individual layers inside the Gemma model here, because in Flax those layers aren’t actually built until you call apply on the model and bind it to parameters. But as we will see later, we can use Penzai to get a better look at the internals of the Gemma model.

Looking at model inputs and outputs

Let’s run the model on some example text. We’ll start by tokenizing it:

example_input = "Penzai includes a number of general-purpose tools for analyzing JAX neural networks. It also includes a declarative neural-network library designed to take advantage of those tools. This notebook demonstrates how to apply this tooling to a real-world neural network: the Gemma pretrained transformer model."
print(example_input)
tokens = jnp.array([vocab.bos_id()] + vocab.EncodeAsIds(example_input))
tokens

We can apply treescope’s array visualizer to tokens too! Discrete data is shown using colors, with stripes used for numbers with a lot of digits. (Fun fact: since sentencepiece tokenizers tend to give lower IDs to more common tokens, more common tokens tend to have simpler-looking visualizations.)

pz.ts.render_array(tokens)

In fact, we can even pass the tokenizer to the autovisualizer, in which case hovering or clicking on array elements will tell you what token each ID is for. Try it below!

%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
tokens

Now let’s call the model on it (adapting the logic from this notebook):

def get_attention_mask_and_positions(
    example: jax.Array,
    pad_id: int,
) -> tuple[jax.Array, jax.Array]:
  """Builds position indices and the attention mask from the given tokens.

  Returns:
    A (positions, attention_mask) tuple.
  """
  pad_mask = example != pad_id
  current_token_position = gemma.transformer.build_positions_from_mask(pad_mask)
  attention_mask = gemma.transformer.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask
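For intuition about what get_attention_mask_and_positions returns, here is a NumPy sketch of the two quantities. It is our own re-implementation, not the code in gemma.transformer, and it assumes that positions count only non-padding tokens starting from 0 and that the attention mask combines a causal (lower-triangular) mask with the padding mask. The toy_ names avoid overwriting the notebook's variables:

```python
import numpy as np


def positions_and_causal_mask(tokens: np.ndarray, pad_id: int):
  """NumPy sketch (not gemma's code) for a [batch, seq] array of token IDs."""
  pad_mask = tokens != pad_id  # True for real (non-padding) tokens.
  # Count real tokens seen so far; the first real token gets position 0.
  counts = np.cumsum(pad_mask, axis=-1)
  positions = counts - (counts >= 1)
  seq_len = tokens.shape[-1]
  causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
  # Query i may attend to key j only if j <= i and key j is not padding.
  attention_mask = pad_mask[:, None, :] & causal[None, :, :]
  return positions, attention_mask


# Left-padded toy sequence with pad_id = 0.
toy_tokens = np.array([[0, 0, 5, 6, 7]])
toy_positions, toy_mask = positions_and_causal_mask(toy_tokens, pad_id=0)
print(toy_positions)  # [[0 0 0 1 2]]
```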
%%autovisualize
positions, attention_mask = get_attention_mask_and_positions(tokens[None, :], vocab.pad_id())

flax_gemma_output, new_vars = flax_gemma.apply(
    {'params': params['transformer']},
    tokens[None, :],
    positions,
    None, # Attention cache is None.
    attention_mask,
)
assert new_vars is None
flax_gemma_output

This visualization is mostly red because most of Gemma’s output logits are negative. We can map the logits to a probability distribution over the vocabulary using softmax:

%%autovisualize
jax.nn.softmax(flax_gemma_output)
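Softmax exponentiates each logit and normalizes over the vocabulary axis, so negative logits simply become small positive probabilities. Here is a minimal NumPy sketch of the same operation on made-up logits, using the standard trick of subtracting the maximum logit for numerical stability (which does not change the result):

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
  """Numerically stable softmax: subtracting the max leaves the result
  unchanged but keeps np.exp from overflowing on large logits."""
  shifted = logits - logits.max(axis=axis, keepdims=True)
  exps = np.exp(shifted)
  return exps / exps.sum(axis=axis, keepdims=True)


# Made-up logits for a 4-token vocabulary; mostly negative, like Gemma's.
toy_logits = np.array([[-3.0, -1.0, -10.0, 2.0]])
probs = softmax(toy_logits)
print(probs)
```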

Let’s find the most-likely prediction at each position, and compare it to the actual tokens:

%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
predictions = jnp.argmax(flax_gemma_output, axis=-1)
jnp.stack([predictions[0, :-1], tokens[1:]])
list(zip([vocab.IdToPiece(tok) for tok in predictions[0].tolist()], [vocab.IdToPiece(tok) for tok in tokens[1:].tolist()]))
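The comparison above is shifted by one because the logits at position i are the model's prediction for the token at position i + 1. A small NumPy sketch with made-up numbers shows how that alignment turns into a next-token accuracy:

```python
import numpy as np

# Made-up data: toy_predictions[i] is the argmax guess for toy_tokens[i + 1].
toy_tokens = np.array([2, 10, 11, 12, 13])      # a BOS ID followed by four tokens
toy_predictions = np.array([10, 11, 5, 13, 7])  # one guess per input position

# The last prediction has no target, and the first token (BOS) is never
# predicted, so drop one element from each end before comparing.
toy_matches = toy_predictions[:-1] == toy_tokens[1:]
toy_accuracy = toy_matches.mean()
print(toy_accuracy)  # 0.75
```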

Looking inside the Flax model with Flax utilities

How can we figure out what is happening inside Gemma? Of course, we can look at the code to see how it’s implemented, but what if we want to see intermediate activations, or inspect how data flows between Flax modules?

Flax does include a few utilities for this, which are described in the Flax guides and don’t require using Penzai. One option is to use “tabulate” to list out all of the submodule calls:

print(flax_gemma.tabulate(
    jax.random.key(42),
    tokens[None, :],
    positions,
    None, # Attention cache is None.
    attention_mask,
    console_kwargs={"width": 120}
))

Another option is to use capture_intermediates to return intermediate activations:

flax_gemma.apply(
    {'params': params['transformer']},
    tokens[None, :],
    positions,
    None, # Attention cache is None.
    attention_mask,
    capture_intermediates=True,
)

Flax also includes an advanced “intercept_methods” utility which allows you to intercept module calls and apply custom logic.

Flax models vs. Penzai models

What if we want to do more complex operations, like looking at the inputs passed to the submodules, changing the outputs of submodules, or extracting and running submodules individually? This is possible using flax.linen.intercept_methods, but it can be somewhat difficult to reason about. An alternative is to convert the Flax model to a Penzai model, and then use Penzai’s tree-rewriting tools to inspect and modify it. To this end, Penzai provides a utility, unflaxify, which recursively intercepts every Flax method call and encapsulates it into an equivalent Penzai layer.

Before we show how this works, let’s briefly pause to discuss how Penzai models differ from Flax models, and how the conventions and design ideas of the two libraries differ more broadly.

According to the “Flax philosophy” documentation, Flax aims to “offer an API familiar to those experienced with Keras/Sonnet/PyTorch” with “an implicit variable management API to save the user from having to manually thread thousands of variables through a complex tree of functions.” To this end, Flax modules are defined as if they own stateful variables and parameters, which they can modify imperatively, and Flax runs logic under the hood to transform this stateful view into a functional computation that works with JAX. Flax modules generally look something like this:

Initializer = jax.nn.initializers.Initializer

class SimpleFlaxDense(flax.linen.Module):
  features: int
  kernel_init: Initializer = flax.linen.initializers.lecun_normal()
  bias_init: Initializer = flax.linen.initializers.zeros_init()

  @flax.linen.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # Shape info.
    y = jnp.dot(inputs, kernel)
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

class SimpleFlaxMLP(flax.linen.Module):
  out_dims: int

  @flax.linen.compact
  def __call__(self, x):
    x = SimpleFlaxDense(128)(x)
    x = flax.linen.relu(x)
    x = SimpleFlaxDense(self.out_dims)(x)
    return x
SimpleFlaxMLP(out_dims=32)
flax_mlp_params = SimpleFlaxMLP(out_dims=32).init(jax.random.key(10), jnp.ones((4, 8)))
flax_mlp_params
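After init runs, the parameters are just a nested dictionary of arrays, and apply is a pure function of them. To make this functional view concrete, here is a hand-written NumPy sketch of the same two-layer MLP with every parameter passed explicitly. The SimpleFlaxDense_0 / SimpleFlaxDense_1 keys mirror the automatic names Flax assigns to the two submodules, but the initializer is a simplification (a plain normal with variance 1/fan_in, whereas lecun_normal targets the same variance using a truncated normal), and init_mlp / apply_mlp are hypothetical names:

```python
import numpy as np


def init_mlp(rng: np.random.Generator, in_dims: int, out_dims: int,
             hidden: int = 128) -> dict:
  """Builds a parameter tree shaped like the one SimpleFlaxMLP.init returns."""
  def dense(fan_in: int, fan_out: int) -> dict:
    # Simplified LeCun-style init: normal with variance 1 / fan_in.
    return {
        "kernel": rng.normal(0.0, 1.0 / np.sqrt(fan_in), (fan_in, fan_out)),
        "bias": np.zeros((fan_out,)),
    }
  return {"params": {
      "SimpleFlaxDense_0": dense(in_dims, hidden),
      "SimpleFlaxDense_1": dense(hidden, out_dims),
  }}


def apply_mlp(variables: dict, x: np.ndarray) -> np.ndarray:
  """Pure function: every parameter is passed in explicitly."""
  p = variables["params"]
  h = x @ p["SimpleFlaxDense_0"]["kernel"] + p["SimpleFlaxDense_0"]["bias"]
  h = np.maximum(h, 0.0)  # ReLU
  return h @ p["SimpleFlaxDense_1"]["kernel"] + p["SimpleFlaxDense_1"]["bias"]


variables = init_mlp(np.random.default_rng(10), in_dims=8, out_dims=32)
mlp_out = apply_mlp(variables, np.ones((4, 8)))
print(mlp_out.shape)  # (4, 32)
```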