Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Induction Heads in Gemma 7B#

One of Penzai’s primary goals is to support interpretability research on state-of-the-art models. In this notebook, we’ll use Penzai to try to find and intervene on induction heads (Elhage et al. 2021, Olsson et al. 2022) in the Gemma 7B open-weights model. We’ll be focusing on exploratory analysis and on Penzai’s tooling rather than on rigor; the goal is to show how you can use Penzai to quickly prototype ideas and generate hypotheses about network behavior (not to perfectly measure the presence of induction heads or exactly reproduce previous results).
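As a quick refresher on the concept (a toy illustration in plain Python, not Penzai code and not part of the original tutorial): an induction head completes repeated patterns. When the model sees a token that occurred earlier in the sequence, the head attends to the position just *after* that earlier occurrence and boosts the prediction that the sequence will repeat. The hypothetical helper below computes where such an idealized head would attend:

```python
def ideal_induction_attention(tokens):
    """Compute where an idealized induction head would attend.

    For each position i, if tokens[i] occurred earlier in the sequence,
    the head attends to the position just after that earlier occurrence
    (predicting the pattern will repeat); otherwise there is nothing to
    attend to (None).
    """
    last_seen = {}  # token -> most recent position it occurred at
    attend_to = []
    for i, tok in enumerate(tokens):
        attend_to.append(last_seen[tok] + 1 if tok in last_seen else None)
        last_seen[tok] = i
    return attend_to

# On the repeated sequence A B C A B C, the second "A" (position 3) attends
# back to position 1 (the "B" that followed the first "A"), and so on.
print(ideal_induction_attention(["A", "B", "C", "A", "B", "C"]))
# [None, None, None, 1, 2, 3]
```

Real induction heads only approximate this pattern, of course; later in the notebook we will look for attention heads whose weights concentrate on exactly these offsets.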

Along the way, we’ll discuss:

  • How to use JAX’s sharding support to automatically shard the model over a cluster of TPUs,

  • How to use Penzai’s pretty-printer (Treescope) to quickly look at model weights and activations,

  • How to extract intermediate values and intermediate subcomputations from a larger model for detailed analysis, using either Penzai’s manual patching tool or its data-effect system,

  • How to use Penzai’s named axis library to identify the characteristic patterns of induction heads,

  • And how to patch the Gemma model by intervening on intermediate subcomputations (in this case, the attention weights).

Let’s get started!

Note: This version of this tutorial uses the 7-billion-parameter Gemma model, which requires an accelerator with at least 24 GB of memory. (Colab “TPU v2” or Kaggle TPU kernels should work.) For a version with a smaller memory footprint, see the “Induction Heads in Gemma 2B” tutorial, which covers the same material.

Setting up and loading the model#

We’ll start by setting up the environment and loading the Gemma 7B model.


To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it by running the following cell:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations

import os
import dataclasses
import gc

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
from jax.experimental import mesh_utils
import sentencepiece as spm
import penzai
from penzai import pz

from penzai.example_models import gemma

Setting up Penzai#

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default Colab pretty-printer. We’ll also turn on automatic visualization of JAX and Numpy arrays. This will make it easy to look at our models and their outputs.


Loading Gemma#

Next we’ll load the weights from the Gemma checkpoint. We’ll use the 7B checkpoint for this tutorial.

This notebook should work in any kernel with enough memory to load the 7B model, which includes Colab’s “TPU v2” and “A100” kernels and Kaggle notebook TPU kernels. You can also run it on your own local GPU IPython runtime, connected either to Colab or to a different IPython frontend.

If you don’t have access to an accelerator with enough memory, you can open the “Induction Heads in Gemma 2B” tutorial instead, which walks through the analysis for the smaller model and should work on a Colab T4 GPU kernel. (Both tutorials cover the same material, but the locations of the induction heads and some aspects of the model predictions differ between the variants!)

When loading the arrays, we’ll shard them over their last positional axis, which ensures that they fit in memory on the “TPU v2” kernel. JAX and the Orbax checkpointer automatically take care of partitioning the arrays across the devices and exposing a uniform interface to the sharded arrays. In fact, most operations on partitioned arrays “just work” without having to do anything special. (You can read more about JAX’s automatic distributed arrays on this JAX documentation page.)
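To build intuition for what this sharding does, here is a small hypothetical helper (plain Python, not part of JAX, Orbax, or Penzai) that computes the per-device shard shape produced by partitioning an array over its last axis, mirroring the `sharding.reshape((1,) * (len(shape) - 1) + (n_devices,))` pattern used in the loading code below:

```python
def last_axis_shard_shape(shape, n_devices):
    """Per-device shard shape when an array is split over its last axis.

    All leading axes stay whole (the device layout has size 1 along them);
    only the last axis is divided evenly across the devices.
    """
    *leading, last = shape
    if last % n_devices != 0:
        raise ValueError(f"last axis {last} is not divisible by {n_devices}")
    return tuple(leading) + (last // n_devices,)

# A (3072, 24576) weight matrix sharded over 8 TPU cores: each core holds
# only a (3072, 3072) slice, i.e. one eighth of the array's memory.
print(last_axis_shard_shape((3072, 24576), 8))
# (3072, 3072)
```

This is why sharding over the last axis lets the 7B checkpoint fit on the “TPU v2” kernel: each of the 8 cores only has to store a fraction of every large weight matrix.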

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don’t have an API key already, you can:

  1. Visit https://www.kaggle.com and create an account if needed.

  2. Go to your account settings, then the ‘API’ section.

  3. Click ‘Create new token’ to download your key.

Next, if you are running this notebook in Google Colab:

  1. Click the “key” symbol on the left toolbar to open the “Secrets” tab.

  2. Add two new secrets, named “KAGGLE_USERNAME” and “KAGGLE_KEY”, and set their values based on the API key you downloaded.

  3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

import kagglehub

try:
  # In Colab, read the credentials from the notebook's secrets.
  from google.colab import userdata
  os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
  os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
  print("Kaggle credentials set.")
except ImportError:
  # Outside Colab, fall back to kagglehub's interactive login widget.
  kagglehub.login()
If everything went well, you should see:

Kaggle credentials set.

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven’t done that yet, you can do so here:

(Make sure you choose to “Verify via Kaggle Account” with the same account you used to log in above!)

Once you’ve agreed to the terms, you can run the next cell to download the Gemma weights:

weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using Orbax:

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,)),
    ),
    metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

Let’s take a look! Since we’ve registered Treescope as the default pretty-printer and turned on array visualization, we can just output the arrays from Colab and see a rich visualization of their values.

Try clicking to explore the structure of the arrays below!

(Note: It may take a while for the array summaries to load the first time, because JAX has to compile the summarization code. You can still look at array shapes before they finish, and it should be faster to run the second time.)