Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Induction Heads in Gemma 7B

One of Penzai’s primary goals is to support interpretability research on state-of-the-art models. In this notebook, we’ll use Penzai to try to find and intervene on induction heads (Elhage et al. 2021, Olsson et al. 2022) in the Gemma 7B open-weights model. We’ll be focusing on exploratory analysis and on Penzai’s tooling rather than on rigor; the goal is to show how you can use Penzai to quickly prototype ideas and generate hypotheses about network behavior (not to perfectly measure the presence of induction heads or exactly reproduce previous results).

Along the way, we’ll discuss:

  • How to use JAX’s sharding support to automatically shard the model over a cluster of TPUs,

  • How to use Penzai’s pretty-printer (Treescope) to quickly look at model weights and activations,

  • How to extract intermediate values and intermediate subcomputations from a larger model for detailed analysis, using either Penzai’s manual patching tool pz.select or using Penzai’s data-effect system,

  • How to use Penzai’s named axis library to identify the characteristic patterns of induction heads,

  • And how to patch the Gemma model by intervening on intermediate subcomputations (in this case, the attention weights).

Let’s get started!

Note: This version of this tutorial uses the 7-billion-parameter Gemma model, which requires an accelerator with at least 24GB of RAM. (Colab “TPU v2” or Kaggle TPU kernels should work.) For a version with a smaller memory footprint, see the “Induction Heads in Gemma 2B” tutorial, which covers the same material.

Setting up and loading the model

We’ll start by setting up the environment and loading the Gemma 7B model.

Imports

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]

from __future__ import annotations

import os
import dataclasses
import gc

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
from jax.experimental import mesh_utils
import sentencepiece as spm
import penzai
from penzai import pz

from penzai.example_models import gemma

Setting up Penzai

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default Colab pretty-printer. We’ll also turn on automatic visualization of JAX and NumPy arrays. This will make it easy to look at our models and their outputs.

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())
pz.ts.register_context_manager_magic()

Loading Gemma

Next we’ll load the weights from the Gemma checkpoint. We’ll use the 7B checkpoint for this tutorial.

This notebook should work in any kernel with enough memory to load the 7B model, which includes Colab’s “TPU v2” and “A100” kernels and Kaggle notebook TPU kernels. You can also run this using your own local GPU IPython runtime, either connected to Colab or a different IPython frontend.

If you don’t have access to an accelerator with enough memory, you can open the “Induction Heads in Gemma 2B” tutorial instead, which walks through the analysis for the smaller model and should work on a Colab T4 GPU kernel. (Both tutorials cover the same material, but the locations of the induction heads and some aspects of the model predictions differ between the variants!)

When loading the arrays, we’ll shard them over their last positional axis, which ensures that they fit in memory on the “TPU v2” kernel. JAX and the Orbax checkpointer automatically take care of partitioning the arrays across the devices and exposing a uniform interface to the sharded arrays. In fact, most operations on partitioned arrays “just work” without having to do anything special. (You can read more about JAX’s automatic distributed arrays on this JAX documentation page.)

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don’t have an API key already, you can:

  1. Visit https://www.kaggle.com/ and create an account if needed.

  2. Go to your account settings, then the ‘API’ section.

  3. Click ‘Create new token’ to download your key.

Next, if you are running this notebook in Google Colab:

  1. Click the “key” symbol on the left toolbar to open the “Secrets” tab.

  2. Add two new secrets, named “KAGGLE_USERNAME” and “KAGGLE_KEY”, and set their values based on the API key you downloaded.

  3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

import kagglehub
try:
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  kagglehub.login()

If everything went well, you should see:

Kaggle credentials set.

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven’t done that yet, you can do so here:

https://www.kaggle.com/models/google/gemma/license/consent

(Make sure you choose to “Verify via Kaggle Account” with the same account you used to log in above!)

Once you’ve agreed to the terms, you can run the next cell to download the Gemma weights:

weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using Orbax:

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
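To make “shard them over their last positional axis” concrete, here is a small pure-Python sketch of the device grid that the restore_args lambda above builds with sharding.reshape(...), and the per-device block shape it implies. The helper names sharding_grid and per_device_shard_shape are hypothetical. The sketch assumes every array has rank at least 1 and that the last axis splits evenly across devices, and it only does the shape arithmetic; it never touches JAX.

```python
def sharding_grid(ndim: int, n_devices: int) -> tuple:
    # Mirrors the tuple passed to sharding.reshape(...) above: a size-1 grid
    # dimension for every axis except the last, which is split n_devices ways.
    # (Assumes ndim >= 1.)
    return (1,) * (ndim - 1) + (n_devices,)


def per_device_shard_shape(shape: tuple, n_devices: int) -> tuple:
    # Shape of the block of the array that each device would hold.
    grid = sharding_grid(len(shape), n_devices)
    for dim, parts in zip(shape, grid):
        if dim % parts != 0:
            # This sketch only handles even splits.
            raise ValueError(f"axis of size {dim} does not split into {parts} parts")
    return tuple(dim // parts for dim, parts in zip(shape, grid))


# An illustrative 2-D weight of shape (256000, 3072) spread over 8 devices:
print(per_device_shard_shape((256000, 3072), 8))  # (256000, 384)
```

Because only the last axis is split, each device keeps every row of a weight matrix but only a slice of its columns.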

Let’s take a look! Since we’ve registered Treescope as the default pretty-printer and turned on array visualization, we can just output the arrays from Colab and see a rich visualization of their values.

Try clicking to explore the structure of the arrays below!

(Note: It may take a while for the array summaries to load the first time, because JAX has to compile the summarization code. You can still look at array shapes before they finish, and it should be faster to run the second time.)

flat_params

The next step is to build the Gemma model using these weights. Since we’re interested in studying model activations, we’ll configure it to compute the activations in float32 precision, even though the weights are stored in bfloat16 precision.

model = gemma.model_core.GemmaTransformer.from_pretrained(
    flat_params,
    upcast_activations_to_float32=True
)

Let’s look at it!

model

Penzai models are designed to reveal as much information as possible when pretty printed. Try clicking the triangles to expand different layers, and look at the structure of the computation and of the parameters!

We’ll be taking a closer look at the attention layers later in this notebook. (If you’d like to learn more about the overall structure of this model, you can read the separate “Gemma from Scratch” tutorial.)

del flat_params
gc.collect()

Looking at outputs

Before we can look at the induction heads, we’ll need some input to run the model on. Taking inspiration from Olsson et al. (2022), let’s try running the model on a repeated sequence of random tokens.

The Gemma model is trained on natural text, so if we pick token IDs uniformly at random, it tends to get confused (since that’s not a natural distribution over tokens). Instead, we’ll run it on random numeric digits, which are likely to have shown up somewhere in the training data:

example_text = (
    "01976954310149754605"
    + "01976954310149754605"
)

The Gemma tokenizer tokenizes each digit separately:

tokens = jnp.array([vocab.bos_id()] + vocab.EncodeAsIds(example_text))
tokens

Treescope visualizes integer arrays using “digitbox” patterns, where each base-10 digit of the integer (in this case, of the token ID) is shown as a colored stripe. This can be used to visualize patterns across multiple examples. We can make the correspondence more visible by telling the autovisualizer about our tokenizer:

%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
tokens

Try hovering or clicking on the above visualization; you should see the token ID and token string for each token.

Next, let’s tag it with axis names. Penzai includes a lightweight named-axis system, which allows you to associate names with arbitrary axes. There’s a separate tutorial about how to use the named axis system, but the short version is:

  • Named arrays are represented using the pz.nx.NamedArray dataclass, which is just a combination of an array and a sequence of axis names.

  • It’s OK for only a subset of axes to have names. You can bind positional axes to names using .tag(name1, name2, ...), and unbind names back into positional axes using .untag(name1, name2, ...).

  • You can run ordinary JAX functions using pz.nx.nmap (e.g. pz.nx.nmap(jax.nn.softmax)(array, axis=0)). JAX functions only see the positional axes, and automatically vectorize over named axes (using jax.vmap under the hood), so if you want to run it over a named axis you need to unbind the name first (e.g. pz.nx.nmap(jax.nn.softmax)(array.untag("vocab"), axis=0).tag("vocab")). You can also use array instance methods like array.sum(), but they again only operate over positional axes.
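The “functions only see positional axes” rule can be mimicked in plain NumPy. The sketch below is a rough analogue of pz.nx.nmap(jax.nn.softmax)(array.untag("vocab"), axis=0).tag("vocab"), not Penzai’s implementation: the hypothetical helper map_over_other_axes exposes one chosen axis to the function and loops over all remaining axes, where Penzai would use jax.vmap instead of a Python loop.

```python
import numpy as np


def softmax(x, axis=0):
    # Numerically stable softmax over a single positional axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def map_over_other_axes(fn, x, keep_axis):
    # Move the axis `fn` should see to the end, flatten every other axis into
    # one batch axis, and apply `fn` to each 1-D slice in turn (a Python loop
    # standing in for jax.vmap). Then restore the original axis order.
    moved = np.moveaxis(x, keep_axis, -1)
    batch = moved.reshape(-1, moved.shape[-1])
    out = np.stack([fn(row) for row in batch])
    return np.moveaxis(out.reshape(moved.shape), -1, keep_axis)


x = np.arange(15.0).reshape(3, 5)  # think of the axes as ("seq", "vocab")
probs = map_over_other_axes(softmax, x, keep_axis=1)
print(probs.sum(axis=1))  # each row sums to 1
```

The function itself never needs to know which axes are being batched over, which is the same property that lets nmap lift ordinary JAX functions to named arrays.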

We can wrap our token array like this:

token_seq = pz.nx.wrap(tokens).tag("seq")
token_seq

Treescope knows how to visualize NamedArrays, so the “axis 0” annotation now shows as “seq”.

There are also some more complicated utilities for visualizing named arrays of tokens in particular:

from penzai.toolshed import token_visualization
token_visualization.show_token_array(token_seq, vocab)

We can now run our model by simply calling it with an appropriately-constructed input. We’ll need to build an attention mask and an array of positions (for the positional embeddings), but we can use a helper function for this:

%%with pz.ts.using_expansion_strategy(max_height=50)

example_arg = gemma.model_core.GemmaInputs.from_basic_segments(token_seq)
example_arg

Now we can call our model and look at the output log-probabilities:

logits = model(example_arg)
# Map softmax over the vocabulary
log_probs = pz.nx.nmap(jax.nn.log_softmax)(logits.untag("vocabulary")).tag("vocabulary")
log_probs

We’re most interested in the log-probabilities of the correct token, so let’s look at those. To do this, we’ll first drop the last prediction and the first token, so that each step’s prediction lines up with the next step’s ground truth:

# Indexing with a dictionary indexes the named axes; pz.slice helps slice them.
sliced_preds = log_probs[{"seq": pz.slice[:-1]}]
correct_next_token = token_seq[{"seq": pz.slice[1:]}]

Then we’ll index into the vocabulary axis using the correct tokens:

log_prob_of_correct_next = sliced_preds[{"vocabulary": correct_next_token}]
log_prob_of_correct_next

With Penzai’s named axis system, axes with the same name always broadcast together. In this case, it matched up the “seq” axis in both arrays, which is indeed what we wanted.

Note that the dictionary-style indexing of named arrays is syntactic sugar around lower-level tagging and indexing operations. We could have gotten the same result with:

# Unbind the names, slice the positional array, rebind the names if needed.
sliced_preds = log_probs.untag("seq")[:-1].tag("seq")
correct_next_token = token_seq.untag("seq")[1:].tag("seq")
sliced_preds.untag("vocabulary")[correct_next_token]
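For comparison, here is the same shift-and-gather written in plain NumPy on positional arrays, assuming logits has shape [seq, vocab] and the token array has shape [seq]. The function name log_prob_of_next_token is hypothetical; np.take_along_axis plays the role of indexing the vocabulary axis with the correct tokens.

```python
import numpy as np


def log_softmax(logits, axis=-1):
    # Numerically stable log-softmax.
    z = logits - logits.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def log_prob_of_next_token(logits, tokens):
    # logits: [seq, vocab]; tokens: [seq].
    log_probs = log_softmax(logits, axis=-1)
    preds = log_probs[:-1]  # the prediction made at step t ...
    targets = tokens[1:]    # ... is scored against the token at step t + 1
    # Pick out, for each step, the log-prob assigned to the actual next token.
    return np.take_along_axis(preds, targets[:, None], axis=-1)[:, 0]


toy_tokens = np.array([0, 2, 1, 3])
uniform_logits = np.zeros((4, 4))
print(log_prob_of_next_token(uniform_logits, toy_tokens))  # log(1/4) at every step
```

As in the named version, the result has one entry fewer than the token sequence, because the final prediction has no ground-truth token to compare against.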

Penzai provides some utilities for visualizing token scores:

# Log probs (redder is smaller)
token_visualization.show_token_scores(correct_next_token, log_prob_of_correct_next, vocab)
# Probabilities (bluer is larger)
token_visualization.show_token_scores(correct_next_token, pz.nx.nmap(jnp.exp)(log_prob_of_correct_next), vocab)

(Try hovering your mouse over the boxes!)

We see that the model assigns around 5-10% probability to the correct digit during the first repetition of the sequence (roughly chance level for ten possible digits), then quickly ramps up to about 80-100% after seeing the first four digits of the repetition. Let’s try to figure out how!

Looking at attention patterns

The first step to understanding the attention patterns used by Gemma is to extract what those patterns actually are. Penzai has a few different tools we can use for this. We’ll start by showing how you could implement this logic yourself using Penzai’s tree-rewriting utility pz.select, and then discuss higher-level wrappers that make this particular use case easier.

Injecting logic with pz.select

Penzai models are designed to be easy to patch in an interactive setting. In particular, it’s easy to create copies of your model that include new bits of logic.

Penzai’s primary tool for manipulating the structure of Python objects is pz.select. This can be used to identify parts of an object you want to change, and then make changes to it. Under the hood, it’s built on top of jax.tree_util, so it works on any type that’s been registered with JAX. You can use it like this:

my_list = [1, 2, "Hello ", 4]

# Append something to all strings:
my_patched_list = (
    pz.select(my_list).at_instances_of(str).apply(lambda s: s + "World!")
)

# my_list isn't modified:
pz.show("my_list:", my_list)

# But my_patched_list includes our change:
pz.show("my_patched_list:", my_patched_list)
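The key property here is that pz.select(...).apply(...) builds a new structure instead of mutating the old one. A toy, pure-Python version of “select every str and apply a function” might look like the following. apply_at_instances is a hypothetical helper that only understands lists, tuples, and dicts, whereas pz.select works on any JAX pytree and supports many more selection operations.

```python
from typing import Any, Callable


def apply_at_instances(tree: Any, cls: type, fn: Callable[[Any], Any]) -> Any:
    # Return a rebuilt copy of `tree` in which every instance of `cls` has been
    # replaced by fn(instance). The input structure is never mutated.
    if isinstance(tree, cls):
        return fn(tree)
    if isinstance(tree, list):
        return [apply_at_instances(x, cls, fn) for x in tree]
    if isinstance(tree, tuple):
        return tuple(apply_at_instances(x, cls, fn) for x in tree)
    if isinstance(tree, dict):
        return {k: apply_at_instances(v, cls, fn) for k, v in tree.items()}
    # Leaves of other types are shared between the old and new structures.
    return tree


my_list = [1, 2, "Hello ", 4]
toy_patched_list = apply_at_instances(my_list, str, lambda s: s + "World!")
print(my_list)           # [1, 2, 'Hello ', 4]
print(toy_patched_list)  # [1, 2, 'Hello World!', 4]
```

Sharing the untouched leaves is also why patched Penzai models are cheap to create: unchanged parts of the structure are reused rather than copied.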

pz.select produces a Selection object that tracks a particular part (or parts) of a larger structure. We can print them out to see what part is selected:

pz.select(my_list).at_instances_of(str)

There are a lot of different ways to build and manipulate Selections. Another useful thing you can do is pass a function that picks out a particular subtree of your tree:

pz.select(my_list).at(lambda root: root[1])

(Aside: You may have noticed the little grey “copy” buttons next to each line of a treescope rendering. If you click one of those, it will copy a function that picks out that subtree, which can be useful either for extracting the value or for passing to pz.select(...).at(...).)

See the separate “Selectors” tutorial for details about the operations pz.select supports! For now, we’ll discuss the features we need as we use them.

Recall our model object. We can use pz.select to pick out the softmax inside one of its attention blocks:

%%autovisualize None

selected = (
    pz.select(model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .pick_nth_selected(1)
    .at_instances_of(pz.nn.Softmax)
)
selected

In Penzai models, the pretty-printed representation isn’t just a summary of what’s inside your model, it’s actually a complete specification of all of the steps that occur when your model runs. So, if we insert something new into our model using pz.select, the resulting model will run the logic we inserted along with the rest of its operations!

For instance, let’s define a simple layer that shows its intermediate value:

@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(pz.Layer):  # <- pz.Layer is the base class of Penzai layers.
  def __call__(self, intermediate_value):
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value

By convention:

  • Every layer is a Python dataclass and a JAX pytree. This makes it easy for JAX and Penzai tools to understand it, and makes it safe to pass across JAX transformations.

  • Most model components are subclasses of the abstract base class pz.Layer, which means that they must define __call__. To allow composing together multiple layers, __call__ always takes a single positional argument as input (but in some cases this positional argument can be a structure like a tuple or list).

We can instantiate our layer:

DisplayIntermediateValue()

And call it:

layer = DisplayIntermediateValue()
output = layer(123)
pz.show("Final:", output)

Now let’s try inserting our new layer into the Gemma model.

%%autovisualize None

# Make a patched copy of our model:
patched = (
    pz.select(model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .pick_nth_selected(1)
    .at_instances_of(pz.nn.Softmax)
    .insert_after(DisplayIntermediateValue())
)

# Find the thing we inserted into it:
pz.select(patched).at_instances_of(DisplayIntermediateValue)

You might notice that the new DisplayIntermediateValue() layer is inside a layer called Sequential. Sequential is a layer combinator that runs all of its children in sequence, passing the output of one layer to the input of the next. So inserting DisplayIntermediateValue here means it will receive the output of the softmax as its own input, and its own output will be returned as the final answer of query_key_to_attn. In fact, GemmaAttention is also a layer combinator; it simply combines the outputs from query_key_to_attn and input_to_value and passes them to the attn_value_to_output layer.

By convention, idiomatic Penzai models generally express as much as possible in terms of these layer combinators, and defer the actual logic to small primitive layers like Softmax or ApplyAttentionMask. Structuring models this way naturally exposes all of the parts of the model that we might want to inspect or modify, making it easy to insert new logic like we’ve just done.
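The combinator idea can be sketched without Penzai. Below, a hypothetical frozen Sequential dataclass feeds each child’s output into the next, and a hypothetical insert_after helper returns a new combinator with an extra layer spliced in, leaving the original untouched. This roughly mirrors inserting DisplayIntermediateValue after the softmax; it is an illustration of the pattern, not Penzai’s pz.nn.Sequential class.

```python
import dataclasses
from typing import Any, Callable, Tuple


@dataclasses.dataclass(frozen=True)
class Sequential:
    # Runs each sublayer in order, passing one output to the next input.
    sublayers: Tuple[Callable[[Any], Any], ...]

    def __call__(self, x):
        for layer in self.sublayers:
            x = layer(x)
        return x


def insert_after(seq: Sequential, index: int, new_layer) -> Sequential:
    # Return a new Sequential with `new_layer` placed right after position
    # `index`; the frozen original is not modified.
    layers = seq.sublayers
    return dataclasses.replace(
        seq, sublayers=layers[: index + 1] + (new_layer,) + layers[index + 1 :]
    )


seen = []


def record(x):
    # Like DisplayIntermediateValue: observe the value, return it unchanged.
    seen.append(x)
    return x


base = Sequential((lambda x: 2 * x, lambda x: x + 1))
toy_patched = insert_after(base, 0, record)
print(toy_patched(5), seen)  # 11 [10]
print(base(5), len(base.sublayers))  # 11 2
```

Because the inserted layer returns its input unchanged, both versions compute the same result; only the patched one gets to observe the intermediate value.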

If we call our new patched copy of the model, we get to see a summary of the attention weights, pretty-printed by the default renderer:

patched(example_arg)

Looks like we successfully intercepted the attention weights! However, since the full matrix is quite big, treescope has automatically truncated it to save space and to avoid sending huge amounts of data. This means it’s only showing ten of the 16 heads, and only the first and last eight tokens of the query and key/value axes, with dark shading indicating missing values.

(Sidenote: We haven’t actually modified the original model. Calling it still runs the original logic and thus doesn’t print out the attention matrices:

model(example_arg)

Nevertheless, the two models are sharing the same accelerator memory, because pz.select re-uses parts of the structure that haven’t changed.)

Let’s instead try actually pulling the value out. To do that, we’ll write a layer that stores its activation inside one of its own attributes when it runs.

Because pz.Layer instances and pz.pytree_dataclass nodes are immutable, we can’t do this with a pz.Layer subclass and we can’t use pz.pytree_dataclass. Instead, we’ll make our new “layer” just an ordinary Python class that implements __call__. We can also define a helper function that is in charge of pulling the value out of it, so that we don’t have to worry about tracking mutable state across multiple notebook cells.

class SaveIntermediateValue:
  def __init__(self):
    self.intermediate_value = None

  def __call__(self, intermediate_value):
    self.intermediate_value = intermediate_value
    return intermediate_value

def get_activation_after(model_with_selected_part, model_input):
  # Create our mutable object:
  saver = SaveIntermediateValue()
  # Insert it into a temporary copy of the model:
  mutably_patched = model_with_selected_part.insert_after(saver)
  # Run the patched model:
  output = mutably_patched(model_input)
  # We're done with the mutable copy of our model now, so we can just let it
  # go out of scope.
  # Grab the stored value out of `saver`:
  return {"output": output, "intermediate_value": saver.intermediate_value}

results = get_activation_after((
    pz.select(model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .pick_nth_selected(1)
    .at_instances_of(pz.nn.Softmax)
), example_arg)
results

We are now free to inspect the attention patterns in more detail. We can use pz.ts.render_array, which produces the same type of figure as the default autovisualizer, but gives us control over how it renders things:

pz.ts.render_array(
    results["intermediate_value"],
    truncate=False,  # <- False is the default value, but it's True in the autovisualizer
    # This adds the actual token values to the hover tooltips:
    axis_item_labels={
        "seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
        "kv_seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
    },
    # Put query position on the Y axis.
    rows=["seq"]
)

We can see a few different behaviors here. Head 2 appears to be attending to “two tokens back”. Head 3 seems to attend to every previous occurrence of each token. And many of the other heads seem to be doing something “fuzzier”, falling back to attending to the beginning-of-sequence token.
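A rough way to put a number on “attends two tokens back” is to average each head’s attention along the diagonal k positions below the main diagonal. The NumPy sketch below assumes an attention array laid out as [heads, seq, kv_seq] (queries on rows, keys on columns), for instance the extracted attention weights converted to a positional array in that axis order. offset_scores is a hypothetical helper, and the data in the example is synthetic.

```python
import numpy as np


def offset_scores(attn: np.ndarray, k: int) -> np.ndarray:
    # attn: [heads, seq, kv_seq]. The diagonal at offset -k holds
    # attn[h, i + k, i]: query i + k attending to the key k positions earlier.
    diag = np.diagonal(attn, offset=-k, axis1=-2, axis2=-1)  # [heads, seq - k]
    return diag.mean(axis=-1)


# Synthetic check: head 0 always attends 2 back (or to position 0 when it
# can't), head 1 spreads attention uniformly over the causal prefix.
n = 6
two_back = np.zeros((n, n))
two_back[np.arange(n), np.maximum(np.arange(n) - 2, 0)] = 1.0
uniform = np.tril(np.ones((n, n)))
uniform /= uniform.sum(axis=-1, keepdims=True)
print(offset_scores(np.stack([two_back, uniform]), k=2))  # head 0 scores 1.0
```

Scanning k over a range and looking for heads with a sharp peak is one quick way to separate fixed-offset heads from the fuzzier ones.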

JAX note: Since mutably_patched contains a mutable non-pytree part (our SaveIntermediateValue block), if we tried to run mutably_patched(model_input) under jax.jit directly, we’d get a JAX error. However, the inputs and outputs of get_activation_after don’t include the SaveIntermediateValue object, so it’s still fine to jax.jit it:

results = jax.jit(get_activation_after)(selected, example_arg)
results

A general guideline when working with JAX and Penzai models is to keep mutable state local and easy to reason about, and make sure that the “resting” state of the model is immutable.

del results
gc.collect()

Using Penzai’s effect system

In our above solution, we temporarily inserted some mutable state into a copy of the model, ran it, then extracted the result and discarded the mutable copy. This is a useful and powerful pattern, so Penzai has explicit support for it in the form of the data effects system.

At a high level, the data effects system works like this:

  • You declare what effects you want your model to have by adding “request” markers to its layers. Request markers are immutable, so it’s safe to hold them in the model. (This is kind of like how we passed a selection to get_activation_after to identify which parts we wanted to target.)

  • “Handler” objects detect request markers and take ownership of them, replacing them with “ref” markers. Refs are also immutable, so it’s safe to hold them.

  • When your model actually runs, the handler objects temporarily swap in mutable Python objects, which your layers can access and call methods on. (This is basically what get_activation_after did above.)

  • The handler extracts any information it needs and returns it, and the temporary mutable copy of the model is discarded, leaving only the original immutable version (with the ref objects in it).
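The four steps above can be illustrated with a toy, pure-Python analogue. Everything here (Request, Ref, Collector, Tell, Pipeline, CollectingHandler) is a hypothetical stand-in written for this sketch, not Penzai’s pz.de API: requests and refs are immutable markers, the handler swaps its own refs for live mutable collectors only inside a throwaway copy, and the handled model itself stays immutable.

```python
import dataclasses
import itertools
from typing import Any, Tuple


@dataclasses.dataclass(frozen=True)
class Request:
    # Immutable marker: "this layer wants a side-output effect".
    tag: str


@dataclasses.dataclass(frozen=True)
class Ref:
    # Immutable marker owned by the handler with this ID.
    handler_id: int
    tag: str


class Collector:
    # Mutable object that only exists while the handled model is running.
    def __init__(self, tag: str, sink: list):
        self.tag = tag
        self.sink = sink

    def tell(self, value):
        self.sink.append((self.tag, value))


@dataclasses.dataclass(frozen=True)
class Tell:
    side_out: Any

    def __call__(self, x):
        self.side_out.tell(x)
        return x


@dataclasses.dataclass(frozen=True)
class Pipeline:
    layers: Tuple[Any, ...]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def _swap_side_outputs(pipeline, should_swap, make_new):
    # Rebuild the pipeline, replacing matching Tell.side_out values.
    return Pipeline(tuple(
        dataclasses.replace(layer, side_out=make_new(layer.side_out))
        if isinstance(layer, Tell) and should_swap(layer.side_out) else layer
        for layer in pipeline.layers
    ))


_handler_ids = itertools.count()


@dataclasses.dataclass(frozen=True)
class CollectingHandler:
    handler_id: int
    body: Pipeline

    @classmethod
    def handling(cls, body: Pipeline) -> "CollectingHandler":
        # Take ownership: every Request becomes a Ref to this handler.
        hid = next(_handler_ids)
        owned = _swap_side_outputs(
            body, lambda s: isinstance(s, Request), lambda s: Ref(hid, s.tag))
        return cls(hid, owned)

    def __call__(self, x):
        # Swap this handler's Refs for live Collectors in a throwaway copy,
        # run it, and return the output together with everything reported.
        sink = []
        live = _swap_side_outputs(
            self.body,
            lambda s: isinstance(s, Ref) and s.handler_id == self.handler_id,
            lambda s: Collector(s.tag, sink))
        return live(x), sink


toy_model = Pipeline((lambda x: 2 * x, Tell(Request("intermediate_value")), lambda x: x + 1))
handled = CollectingHandler.handling(toy_model)
out, told = handled(3)
print(out, told)  # 7 [('intermediate_value', 6)]
```

Note that toy_model still holds its original Request and handled holds only a Ref; the mutable Collector never outlives a single call.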

You can read more about the data effects system in the separate data effects tutorial. For now, let’s see how it can help us extract attention scores!

The particular effect we’re interested in is the SideOutput effect, which automates the process of extracting intermediate values from models. We can write a layer type that reports its values to a side output:

@pz.pytree_dataclass
class TellIntermediate(pz.Layer):
  # side_out is an attribute where the effect implementation will be inserted.
  # For now, we initialize it to a request, which is immutable.
  side_out: pz.de.SideOutputEffect = pz.de.SideOutputRequest("intermediate_value")

  def __call__(self, intermediate_value):
    # When called, assume `side_out` has been replaced with an implementation
    # that we can report values to:
    self.side_out.tell(intermediate_value)
    return intermediate_value

(This helper layer is also defined in pz.de.TellIntermediate for convenience.)

Then we can insert them into a copy of our model. Let’s insert them after the softmax of every attention block:

%%autovisualize None

model_with_requests = (
    pz.select(model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .at_instances_of(pz.nn.Softmax)
    .insert_after(TellIntermediate())
)

# Show one of them, as an example:
pz.select(model_with_requests).at_instances_of(TellIntermediate).pick_nth_selected(0).show_value()

Next we “handle” them with the CollectingSideOutputs handler:

%%autovisualize None

model_with_collector = pz.de.CollectingSideOutputs.handling(
    pz.select(model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .at_instances_of(pz.nn.Softmax)
    .insert_after(TellIntermediate())
)

pz.select(model_with_collector).at_instances_of(TellIntermediate).pick_nth_selected(0).show_value()

Notice two changes:

  • The outermost wrapper is now a CollectingSideOutputs layer, with a unique handler ID.

  • Inside the GemmaAttention blocks, the side_out attribute of the TellIntermediate layer we added has been replaced with a HandledSideOutputRef referencing that handler ID.

Now, when we call our patched model, this handler will find all of the HandledSideOutputRef objects with the matching handler ID and replace them with mutable Python objects. Then, once it’s done running, it will return the model’s output together with a list of all of the values that were reported:

collected_out = model_with_collector(example_arg)
collected_out

The second output here is a list of all of the calls to .tell inside the layer we added. Each entry is a tuple (layer_keypath, tag, value), where the layer_keypath tells us what part of the model each output came from.

Let’s stack all the attention patterns together:

all_attentions = pz.nx.stack(
    [intermediate.value for intermediate in collected_out[1]],
    "blocks",
)
# ^ shorthand for nmap(jnp.stack)([...]).tag("blocks")
del collected_out

And let’s look at them!

Treescope tip for viewing large arrays: Holding Alt and scrolling will zoom in or out, and Shift+scrollwheel scrolls you along the X-axis instead of the Y-axis.

pz.ts.render_array(
    all_attentions,
    # Annotate the sequence axes with the token names:
    axis_item_labels={
        "seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
        "kv_seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
    },
    # Customize the row/column assignments
    rows=["seq", "blocks"], columns=["kv_seq", "heads"],
    # Overlay the causal mask:
    valid_mask=example_arg.attention_mask,
)

Can you spot any interesting behavior? A few things you can look for:

  • Heads that attend a fixed offset back from the current token

  • Heads that attend to previous occurrences of the same token

  • Heads that seem to find a single token and attend to it throughout the sequence

  • Heads that behave differently on the second repetition than on the first

    • In particular, can you find any possible induction heads: Heads that attend to repeated tokens but shifted by one, so that they are attending to the token that will come next?

Identifying induction heads

Let’s focus our search on induction heads. Briefly summarizing the definition of an induction head from Elhage et al. (2021) and Olsson et al. (2022), an induction head:

  • Finds a previous occurrence of the current token (or, at a more abstract level, a previous repetition of the current sequence content)

  • Attends to the token after the previous occurrence

  • Increases the likelihood of repeating the token it attended to (which, if the pattern continues, would also be the token after the current one)

Our digit sequence has length 20, and the BOS token sits at position 0. So we’re looking for attention heads that attend 19 tokens back: for example, given the token at position 25 as the query, they’d find its previous repetition (position 5) and attend to the token after it (position 6).
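It can help to spell out this prediction before looking at the plots. The pure-Python sketch below (induction_targets is a hypothetical helper) lists, for every query position, the key positions an idealized induction head would attend to: one past each earlier occurrence of the current token. It assumes, as the text does, that each digit is its own token and that BOS sits at position 0.

```python
def induction_targets(tokens):
    # For query position q, an idealized induction head attends to p + 1 for
    # every earlier position p < q that holds the same token as position q.
    return {
        q: [p + 1 for p in range(q) if tokens[p] == tok]
        for q, tok in enumerate(tokens)
    }


digit_tokens = ["<bos>"] + list("01976954310149754605" * 2)
targets = induction_targets(digit_tokens)
print(targets[25])  # [6, 19]
```

For every query in the second repetition (positions 21 to 40), the list includes q - 19, the “19 tokens back” pattern. It also contains extra targets, such as position 3 for query 12, which come from digits that already repeat inside the first copy.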

Let’s try swapping the facet order of the above plot. On the outside, we’ll show the two sequences, and in each inner facet, we’ll show the matrix of blocks and heads. Each inner 2D rectangle will then be a “fingerprint” of all of the different attention heads that attended between the two given tokens.

pz.ts.render_array(
    all_attentions,
    axis_item_labels={
        "seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
        "kv_seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
    },
    # Swap row and column order
    rows=["blocks", "seq"], columns=["heads", "kv_seq"],
    valid_mask=example_arg.attention_mask,
)

Let’s focus on the part of the attention matrix where the query is in the second repetition, and the keys are in the first (i.e. the bottom-left quadrant of the above plot). We’ll slice the matrices to just look at the first ten tokens:

pz.ts.render_array(
    all_attentions[{"seq": pz.slice[20:30], "kv_seq": pz.slice[0:10]}],
    axis_item_labels={
        "seq": [repr(vocab.IdToPiece(int(t))) for t in tokens[20:30]],
        "kv_seq": [repr(vocab.IdToPiece(int(t))) for t in tokens[0:10]],
    },
    rows=["blocks", "seq"], columns=["heads", "kv_seq"],
)

We’re looking for attention heads that attend “one in the future” relative to the repetition. Since “seq” is on the (outer) Y axis and “kv_seq” is on the (outer) X axis, that means we’re looking for attention heads that are active “one to the right of the block diagonal”, because they are attending to the next key/value token relative to the query.

Interestingly, we can see some clear patterns in the fingerprints. There seems to be a fairly consistent set of heads that are active along the diagonal, and a different set of heads that are consistently active one above the diagonal. Perhaps these are induction heads!

If you hover or click on the high-intensity elements above the block diagonal (e.g. in the facet in the right column and second-to-bottom row), treescope will show you the named axes that this element corresponds to. Some possible candidates as induction heads:

  • {‘blocks’:5, ‘heads’:0}

  • {‘blocks’:14, ‘heads’:15}

  • {‘blocks’:20, ‘heads’:13}

  • {‘blocks’:21, ‘heads’:2}

  • {‘blocks’:21, ‘heads’:5}

Let’s slice the matrix so we can focus on the off-diagonal we’re interested in:

# Start one after the BOS token
offset = pz.nx.wrap(jnp.arange(1, 21)).tag("offset")
# Query is offset + 20 (the second repetition)
# Key is offset + 1 (the token right after the matching token in the first repetition)
induction_off_diagonal = all_attentions[{"seq": offset + 20, "kv_seq": offset + 1}]
pz.ts.render_array(induction_off_diagonal, vmax=1)

We’re looking for heads that consistently attend along this extracted diagonal. Above, this will look like solid “stripes” along the offset axis, for a particular block and head index.

Something interesting you might observe: Of the heads that seem to activate here, most seem to have a “delay” of one or two tokens before they activate, even though the token at offset 0 is already a repetition. Any guesses why?

Let’s summarize the stripes by taking an average over the offset dimension:

off_diagonal_avg = induction_off_diagonal.untag("offset").mean()
off_diagonal_avg

This is a summary of how “induction-head-like” these attention patterns are. Let’s sort them based on these scores:

# Convert back to positional, flatten, and sort:
positional_avgs = off_diagonal_avg.untag("blocks", "heads").unwrap()
flat_avgs = positional_avgs.reshape([-1])
block_index, head_index = jnp.unravel_index(jnp.argsort(flat_avgs)[::-1], positional_avgs.shape)
for i, (bi, hi) in enumerate(zip(block_index, head_index)):
  val = positional_avgs[bi, hi]
  print(i, "block:", bi, "head:", hi, "score:", val)
  if val < 0.1:
    break
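The `argsort`/`unravel_index` combination above is easy to misread, so here is the same ranking on a tiny made-up score matrix in NumPy. `rank_heads` and `toy_scores` are hypothetical; unlike the loop above, this version stops before recording the first below-threshold head rather than printing it.

```python
import numpy as np

def rank_heads(scores: np.ndarray, threshold: float = 0.1):
    """Return (block, head, score) triples sorted by descending score,
    stopping once scores fall below `threshold`."""
    order = np.argsort(scores.reshape(-1))[::-1]          # flat indices, best first
    blocks, heads = np.unravel_index(order, scores.shape)  # back to (block, head)
    ranked = []
    for b, h in zip(blocks, heads):
        if scores[b, h] < threshold:
            break
        ranked.append((int(b), int(h), float(scores[b, h])))
    return ranked

toy_scores = np.array([
    [0.02, 0.30, 0.01],   # block 0, heads 0..2
    [0.75, 0.05, 0.12],   # block 1, heads 0..2
])
print(rank_heads(toy_scores))  # [(1, 0, 0.75), (0, 1, 0.3), (1, 2, 0.12)]
```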

Let’s look at the attention patterns of these candidates:

top_block_indices = pz.nx.wrap(block_index[:10]).tag("best_heads")
top_head_indices = pz.nx.wrap(head_index[:10]).tag("best_heads")
top_attn_patterns = all_attentions[{"blocks": top_block_indices, "heads": top_head_indices}]
pz.ts.render_array(
    top_attn_patterns,
    vmax=1,
    rows=["seq"],
    valid_mask=example_arg.attention_mask,
    axis_item_labels={
        "seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
        "kv_seq": [repr(vocab.IdToPiece(int(t))) for t in tokens],
        # Add info on which candidate this is to the hover tooltip:
        "best_heads": {
          i: f"block {block_index[i]} head {head_index[i]}" for i in range(10)
        },
    }
)

It looks like roughly the first six candidates we identified are crisply attending to the token that should be copied (the one right after the previous repetition), which is what we’d predict an induction head would do. The others seem to attend more diffusely.

Interestingly, during the first repetition of the sequence, many of these heads also seem to attend to {'seq':12, 'kv_seq':3}, {'seq':15, 'kv_seq':5}, and {'seq':17, 'kv_seq':9}, albeit somewhat weakly. What are these tokens?

If you hover over the cell immediately to the left of each of them, you’ll see that these are internal repetitions in the original sequence! So these heads are looking at tokens that appear after earlier occurrences of the current token, even when that doesn’t help predict the sequence (yet).

     kv_seq:3  seq:12
        ┌────────┐
       ⇣↓        │
<bos> 01976954310149754605 01976954310149754605

       kv_seq:5   seq:15
          ┌─────────┐
         ⇣↓         │
<bos> 01976954310149754605 01976954310149754605

          kv_seq:9  seq:17
              ┌───────┐
             ⇣↓       │
<bos> 01976954310149754605 01976954310149754605
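To double-check these positions, here is a small pure-Python sketch that lists, for each query position, the positions immediately following earlier occurrences of the same token, i.e. where an idealized induction head would look. It assumes each digit is a single token (as the offset of 20 used above implies) and uses a hypothetical helper `induction_targets`.

```python
def induction_targets(tokens: list[str]) -> dict[int, list[int]]:
    """Map each query position to the positions right after earlier copies of its token."""
    return {
        q: [k + 1 for k in range(q) if tokens[k] == tok]
        for q, tok in enumerate(tokens)
    }

digits = "01976954310149754605"
tokens = ["<bos>"] + list(digits) + list(digits)  # 41 positions
targets = induction_targets(tokens)
print(targets[12], targets[15], targets[17])  # [3, 11] [5] [9, 14]
print(targets[21])  # [2, 12, 20]
```

Some queries have several candidate targets; for example, position 21 (the first repeated token) already has three earlier copies of its token to choose from.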

This seems like confirmatory evidence that these are induction heads!

It also suggests a conjecture: perhaps these heads do not immediately activate at token 21 (the first repeated token), and only become strongly active around token 22 or 23, because the sequence has already included spurious digit repetitions that weren’t followed by copying. Maybe some circuit is detecting this and inhibiting the induction heads until there’s evidence that copying is actually happening.

del all_attentions
gc.collect()

Intervening on attention patterns and activations

If these are really induction heads, we should expect them to actually copy the value of the token they attend to. How can we test whether or not this is happening?

We could formalize this question in a few different ways. Some concrete questions:

  • If we drop these attention heads from the model entirely, does the model lose its ability to predict these tokens?

  • If we locally perturb these attention heads to make them attend to their target tokens less, does the model’s overall accuracy go down?

  • Assuming these heads are improving accuracy, does that happen because they directly adjust the residual stream in the direction of the correct token, or because they copy information used by later attention heads, or because they copy information used by later MLP layers?

Let’s try to figure it out!

Simple ablation: Knocking out attention heads

A simple thing we could try would be to just disable a subset of the heads, and see what happens to the model’s predictions. If we’re right that these are induction heads, we should expect that disabling them will reduce the model’s predictive accuracy.

How should we disable a head? One idea would be to zero out its attention weights, but that might push the activations out of distribution. As an alternative, recall that the (candidate) induction heads above seem to attend preferentially to the beginning-of-sequence token when there is nothing to copy. That suggests a natural “default value” to patch in: force the head to attend to the beginning-of-sequence token.

(Note that attention heads are a particularly easy part of the model to patch like this! For other intermediate values where there isn’t a sensible default, we might consider swapping in activations from a different input sequence (“activation patching”, Meng et al. 2021) or swapping in an average of activations across many sequences (“mean ablation”, Wang et al. 2022).)
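For comparison, here is a minimal NumPy sketch of those two alternatives, operating on a standalone array of per-head outputs rather than inside a model. The `[batch, heads, d]` layout and the helper names are hypothetical; in a real experiment the substitution would happen mid-forward-pass so that later layers see the patched values.

```python
import numpy as np

def mean_ablate(acts: np.ndarray, head: int) -> np.ndarray:
    """Replace one head's output with its average over a batch of inputs.

    acts: hypothetical [batch, heads, d] array of per-head outputs.
    """
    out = acts.copy()
    out[:, head, :] = acts[:, head, :].mean(axis=0)  # broadcasts over the batch
    return out

def activation_patch(clean: np.ndarray, corrupted: np.ndarray, head: int) -> np.ndarray:
    """Keep the clean run, but swap in one head's output from a different input."""
    out = clean.copy()
    out[:, head, :] = corrupted[:, head, :]
    return out

rng = np.random.default_rng(0)
clean = rng.normal(size=(4, 3, 2))
corrupted = rng.normal(size=(4, 3, 2))
ablated = mean_ablate(clean, head=1)
patched = activation_patch(clean, corrupted, head=1)
```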

Here’s a layer that will do the job:

@pz.pytree_dataclass
class KnockOutAttentionHeads(pz.Layer):
  """Layer that redirects masked-out heads to attend to BOS.

  Attributes:
    head_mask: NamedArray with 1s for heads we want to keep, and 0s for heads
      that should be rewritten to point to BOS. Values between 0 and 1 will
      smoothly interpolate between them.
  """
  head_mask: pz.nx.NamedArray

  def __call__(self, attn_weights: pz.nx.NamedArray) -> pz.nx.NamedArray:
    knocked_out_attn = pz.nx.wrap(
        jnp.zeros(
            [attn_weights.named_shape["kv_seq"]],
            attn_weights.dtype,
        ).at[0].set(1.0)
    ).tag("kv_seq")
    return knocked_out_attn + self.head_mask * (attn_weights - knocked_out_attn)
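One property worth checking: the layer computes `(1 - head_mask) * bos + head_mask * attn_weights`, a convex combination of two distributions, so each row should still sum to 1 for any mask value in [0, 1]. BOS is also always causally visible, so no future token leaks in. Here is a NumPy sketch of the same interpolation on random causal attention patterns; the helper `knock_out` and the `[heads, seq, kv_seq]` layout are assumptions for illustration.

```python
import numpy as np

def knock_out(attn: np.ndarray, keep) -> np.ndarray:
    """Interpolate attention rows toward a one-hot distribution on BOS.

    keep = 1 leaves `attn` unchanged; keep = 0 sends all attention to BOS.
    """
    bos = np.zeros(attn.shape[-1])
    bos[0] = 1.0
    return bos + keep * (attn - bos)

rng = np.random.default_rng(0)
num_heads, seq = 3, 5
scores = rng.normal(size=(num_heads, seq, seq))
causal = np.tril(np.ones((seq, seq), dtype=bool))
scores = np.where(causal, scores, -np.inf)            # mask out future positions
attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)               # softmax over kv_seq

keep = np.array([1.0, 0.5, 0.0])[:, None, None]        # one value per head
out = knock_out(attn, keep)
print(out.sum(axis=-1))  # every row still sums to 1
```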

We can use it like this:

# Knock out every other head
knockout_layer = KnockOutAttentionHeads(
    head_mask=pz.nx.wrap(jnp.array(
        [1,0,1,0,1,0,1,0,1,0]
    ).astype(jnp.bfloat16)).tag("best_heads")
)

# Show the results (which should alternate kept and knocked-out)
pz.ts.render_array(
    knockout_layer(top_attn_patterns),
    vmax=1,
    rows=["seq"],
    valid_mask=example_arg.attention_mask,
)

Let’s insert it into our model! We’ll use a helper function to automate the process:

def knock_out_heads(model, head_mask_per_block):
  parts = list(head_mask_per_block.untag("blocks"))
  return (
      pz.select(model)
      .at_instances_of(gemma.model_core.GemmaAttention)
      .at_instances_of(pz.nn.Softmax)
      .insert_after("<placeholder>", and_select=True)  # <- Inserting a dummy object and selecting it so we can use `set_sequence`
      .set_sequence(
          KnockOutAttentionHeads(part) for part in parts
      )
  )

Let’s start by knocking out all ten of the possible induction heads we found earlier:

top_heads_mask = pz.nx.wrap(
    jnp.ones(positional_avgs.shape, dtype=jnp.bfloat16)
    .at[block_index[:10], head_index[:10]]
    .set(0.0)
).tag("blocks", "heads")
top_heads_mask
%%autovisualize None

ablated_model = knock_out_heads(model, top_heads_mask)
pz.select(ablated_model).at_instances_of(KnockOutAttentionHeads).at_instances_of(pz.nx.NamedArray).show_value()

Now we can try running it:

ablated_logits = ablated_model(example_arg)
# Apply log-softmax along the vocabulary axis
ablated_log_probs = pz.nx.nmap(jax.nn.log_softmax)(ablated_logits.untag("vocabulary")).tag("vocabulary")
# Score the prediction at position i against the token at position i + 1
ablated_sliced_preds = ablated_log_probs[{"seq": pz.slice[:-1]}]
correct_next_token = token_seq[{"seq": pz.slice[1:]}]
ablated_log_prob_of_correct_next = ablated_sliced_preds[{"vocabulary": correct_next_token}]
ablated_log_prob_of_correct_next
# Log probs (redder is smaller)
token_visualization.show_token_scores(correct_next_token, ablated_log_prob_of_correct_next, vocab)
# Probabilities (bluer is larger)
token_visualization.show_token_scores(correct_next_token, pz.nx.nmap(jnp.exp)(ablated_log_prob_of_correct_next), vocab, vmax=1)
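The shift-by-one bookkeeping above is easy to get wrong, so here is a NumPy sketch of the same computation for a single sequence: the prediction made at position i is scored against the token at position i + 1. The helper names are hypothetical, and the logits are made up.

```python
import numpy as np

def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    shifted = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def log_prob_of_next_token(logits: np.ndarray, tokens: np.ndarray) -> np.ndarray:
    """logits: [seq, vocab]; tokens: [seq]. Returns [seq - 1] log-probs."""
    log_probs = log_softmax(logits)[:-1]  # predictions made at positions 0 .. seq-2
    targets = tokens[1:]                  # the tokens those predictions should match
    return log_probs[np.arange(len(targets)), targets]

tokens = np.array([0, 3, 1, 2])
uniform = np.zeros((4, 4))                               # no preference at all
confident = 10.0 * np.eye(4)[np.append(tokens[1:], 0)]   # high logit on the next token
print(log_prob_of_next_token(uniform, tokens))    # log(1/4) everywhere
print(log_prob_of_next_token(confident, tokens))  # close to 0
```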

Comparing this to the original accuracies, it’s clear that we’ve severely crippled the ability of the model to imitate the repeating pattern:

token_visualization.show_token_scores(correct_next_token, pz.nx.nmap(jnp.exp)(log_prob_of_correct_next), vocab, vmax=1)

Does this apply to other types of sequences too? Let’s try.

Since our ablated model is an independent copy of our model (only sharing parameters), we are free to call both models on new inputs without having to worry about mutable state.

extra_examples = []

# Some other random digit sequences:
for i in range(100, 104):
  s = "".join(str(d) for d in jax.random.choice(jax.random.key(i), jnp.arange(10), (20,)))
  extra_examples.append(s + s)

# Some other types of repetition:
extra_examples.extend([
    " ".join("injunction double scamp cosmic stroll lucrative" for _ in range(5)),
    "France: Paris, Chile: Santiago, Greece: Athens, China: Beijing, Belgium: Brussels, Norway: Oslo, Canada: Ottawa, Croatia: Zagreb, Algeria: Algiers",
    "The west palace gate has fallen! I repeat, the west palace gate has fallen! We must fall back! I repeat, we must fall back!",
])

# A number the model has probably memorized:
extra_examples.append("3.1415926535897932384626433832795028841")
all_toks = []
for extra_example in extra_examples:
  subtoks = [vocab.bos_id()] + vocab.EncodeAsIds(extra_example)
  subtoks = subtoks + [vocab.pad_id()] * (40 - len(subtoks))
  all_toks.append(subtoks[:40])

new_example_batch = pz.nx.wrap(
    jnp.array(all_toks).astype(jnp.int32)
).tag("batch", "seq")
token_visualization.show_token_array(new_example_batch, vocab)
new_example_arg = gemma.model_core.GemmaInputs.from_basic_segments(new_example_batch)
correct_next_token = new_example_batch[{"seq": pz.slice[1:]}]
orig_logits = model(new_example_arg)
orig_log_probs = pz.nx.nmap(jax.nn.log_softmax)(orig_logits.untag("vocabulary")).tag("vocabulary")
orig_sliced_preds = orig_log_probs[{"seq": pz.slice[:-1]}]
orig_log_prob_of_correct_next = orig_sliced_preds[{"vocabulary": correct_next_token}]

print("Original")
token_visualization.show_token_scores(
    correct_next_token, pz.nx.nmap(jnp.exp)(orig_log_prob_of_correct_next), vocab, vmax=1
)
ablated_logits = ablated_model(new_example_arg)
ablated_log_probs = pz.nx.nmap(jax.nn.log_softmax)(ablated_logits.untag("vocabulary")).tag("vocabulary")
ablated_sliced_preds = ablated_log_probs[{"seq": pz.slice[:-1]}]
ablated_log_prob_of_correct_next = ablated_sliced_preds[{"vocabulary": correct_next_token}]

print("Ablated")
token_visualization.show_token_scores(
    correct_next_token, pz.nx.nmap(jnp.exp)(ablated_log_prob_of_correct_next), vocab, vmax=1
)
# Plot differences:
pz.ts.render_array(
    ablated_log_prob_of_correct_next - orig_log_prob_of_correct_next,
    valid_mask=(correct_next_token != vocab.pad_id()),
)

It looks like knocking out these heads does reduce accuracy for other number sequences. It also somewhat decreases accuracy for a sequence of random words. On the other hand, it seems to have very little effect on the other sentences, suggesting that the model is using a different mechanism for those.

Let’s return to the digit sequences for now. We knocked out all ten candidate heads at once; what is the effect of each of them individually?

One way to figure this out would be to make a new ablated model copy with a different head mask. We could try knocking out individual heads one at a time and see what accuracy we get, or try various combinations of heads.

Another option would be to try an approach inspired by attribution patching (Nanda 2023): approximate the neural network as a linear function of the attention weights (ignoring nonlinear factors), and use automatic differentiation to figure out the gradients of the correct-token loss with respect to our knockout mask. In effect, what this does is tell us “if we increased/decreased this attention head’s influence a little bit, how much would that increase/decrease the final accuracy?”
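In symbols, the linearization looks like this (the notation is ours, chosen for this sketch): let \(m \in [0,1]^{B \times H}\) be the head mask, with \(B = 28\) blocks and \(H = 16\) heads here, and let \(\mathcal{L}(m)\) be the average log-probability of the correct tokens. Around a reference mask \(m_0\),

\[
\mathcal{L}(m_0 + \delta) \approx \mathcal{L}(m_0) + \sum_{b=1}^{B} \sum_{h=1}^{H} \left.\frac{\partial \mathcal{L}}{\partial m_{b,h}}\right|_{m = m_0} \delta_{b,h},
\]

so a single backward pass estimates the effect of perturbing every head at once. For example, lowering one head’s entry from 1 to \(1 - \varepsilon\) (i.e. \(\delta_{b,h} = -\varepsilon\)) is predicted to change \(\mathcal{L}\) by about \(-\varepsilon \, \partial \mathcal{L} / \partial m_{b,h}\).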

Luckily, JAX makes this incredibly easy! We can simply ask for the gradient with respect to the patching mask.

def get_ablated_avg_log_prob(head_mask, model, example_arg, loss_mask):
  ablated_model = knock_out_heads(model, head_mask)
  ablated_logits = ablated_model(example_arg)
  ablated_log_probs = pz.nx.nmap(jax.nn.log_softmax)(ablated_logits.untag("vocabulary")).tag("vocabulary")
  ablated_sliced_preds = ablated_log_probs[{"seq": pz.slice[:-1]}]
  correct_next_token = example_arg.tokens[{"seq": pz.slice[1:]}]
  lp_correct = ablated_sliced_preds[{"vocabulary": correct_next_token}]
  lp_correct = pz.nx.nmap(jnp.where)(loss_mask, lp_correct, 0.0)
  return jnp.mean(lp_correct.untag("seq").unwrap())

Let’s start from a fully-unablated model and see how much slightly-ablating each head hurts:

accuracy_grads = jax.jit(jax.grad(get_ablated_avg_log_prob, argnums=0))(
    pz.nx.ones({"blocks": 28, "heads": 16}),
    model,
    example_arg,
    # Focus on improvements to the second repetition
    loss_mask=(pz.nx.arange("seq", 40) > 21)
)
accuracy_grads

Focusing on the induction heads:

print(jnp.max((accuracy_grads * (1-top_heads_mask)).untag("blocks", "heads").unwrap()))
pz.ts.render_array(
    accuracy_grads,
    valid_mask=1-top_heads_mask
)

If we linearize around the un-ablated model, it looks like small changes to the attention weights don’t actually do very much. What if we linearize around the ablated version?

accuracy_grads = jax.jit(jax.grad(get_ablated_avg_log_prob, argnums=0))(
    top_heads_mask,
    model,
    example_arg,
    loss_mask=(pz.nx.arange("seq", 40) > 21)
)
accuracy_grads
print(jnp.max((accuracy_grads * (1-top_heads_mask)).untag("blocks", "heads").unwrap()))
pz.ts.render_array(
    accuracy_grads,
    valid_mask=1-top_heads_mask
)

In contrast, when we start from the ablated version, adding back these heads has a big impact: if we add back \(\varepsilon\) of block 21 head 2, it increases the log-probability of correct tokens by 1.75 \(\varepsilon\) on average!

As Neel Nanda notes, using a linear approximation can sometimes be misleading, especially if there are “backup heads” that take over when some heads are turned off. One conjecture about why we see such a big difference here is that these different attention heads may be compensating for each other, so that the difference between 9 and 10 heads doesn’t matter much, but the difference between 0 and 1 head is critical.
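To see how redundancy can hide a head’s importance from a gradient, consider a deliberately simple toy (not a model of Gemma): four redundant heads each add to a single score, and the correct-token log-probability is the log-sigmoid of that score. Once the score saturates, any single head barely matters. We estimate gradients with central finite differences rather than `jax.grad` so that the sketch needs only NumPy; `toy_log_prob` and its constants are hypothetical.

```python
import numpy as np

def toy_log_prob(mask: np.ndarray, strength: float = 4.0, bias: float = -2.0) -> float:
    """Hypothetical toy: log(sigmoid(z)) where z sums redundant head contributions."""
    z = strength * mask.sum() + bias
    return -np.log1p(np.exp(-z))

def finite_diff_grad(f, x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Central finite-difference estimate of the gradient of f at x."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step.flat[i] = eps
        grad.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

g_on = finite_diff_grad(toy_log_prob, np.ones(4))    # all heads enabled
g_off = finite_diff_grad(toy_log_prob, np.zeros(4))  # all heads knocked out
print(g_on)   # tiny: the other heads already saturate the output
print(g_off)  # large: adding back any one head helps a lot
```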

Exercise for the reader: What’s the smallest set of heads you need to drop out before the model stops being able to copy the integer digits well?


If you run the ablation steps with different masks, you should find that masking out the following four heads causes the model to fail to solve the task:

  • Block 20, head 13 (index 0 of our sorted candidate array)

  • Block 21, head 1 (index 5 of our sorted candidate array)

  • Block 21, head 2 (index 3 of our sorted candidate array)

  • Block 21, head 5 (index 4 of our sorted candidate array)

Adding back any of these heads makes it able to solve the task again!

Interestingly, these four heads are also the four with the largest gradients when starting from the ablated model above (with all ten candidates masked out). So it’s likely that these four heads are indeed compensating for each other.

Path analysis with batched rewiring

We’ve identified some induction heads, and verified that they are important for accurately copying repeated sequences of integer digits. But we might still wonder, how are these heads interacting with the rest of the model?

  • Do they directly adjust the residual stream in the direction of the correct token?

  • Are they telling other attention heads where to attend to?

  • Are they storing information used by later attention heads?

  • Are they copying information used by later MLP layers?

Distinguishing these requires us to reason about the different computation paths that the model might be using.

In a library like TransformerLens, you might accomplish something like this by caching different activations across different runs of the model, and using mutable hooks to change which values get swapped out at each iteration (e.g. as described in this path-patching tutorial by Callum McDougall). However, in Penzai, we can use our model-editing powers to do this in a more declarative way. Instead of repeatedly running the model with different conditions, we’ll run it once over a parallel set of “counterfactual” states, and add instructions for how to “rewire” them to separate the behavior of different computation paths.

To see how this works, let’s consider a simple subquestion. In very broad strokes, there are two possible ways each attention head could increase the probability of the correct token:

  • It could directly change the residual-stream embedding in a direction that increases the probability of the token,

  • or it could pass information to some later layer using the residual stream, and rely on some later part of the network to adjust the embedding.

How could we distinguish between these? Suppose we made two copies of the residual stream. One copy (the “indirect stream”) could be in charge of receiving the updates from each transformer block and passing them to the other blocks. The other (the “direct stream”) could receive updates from all the other blocks and pass them to the final layer norm and unembedding layer. This decomposition is essentially the “path expansion” trick of Elhage et al. (2021).

Now further suppose that we also made two copies of each block itself. Both copies would read from the indirect stream, but one would write to the direct stream and the other would write back to the indirect stream:

                                                                                                                    
                   Indirect Stream                                                                                  
               ┌────┬─────────────(+)──┬─────────────(+)────────────┬─────────────(+)                               
               │    │              ▲   │              ▲             │              ▲                                
               │    │              │   │              │             │              │                                
   ┌───────┐   │    │  ┌─────────┐ │   │  ┌─────────┐ │             │  ┌─────────┐ │                                
──►│ Input ├───┤    ├──┤ Block 1 │ │   ├──┤ Block 2 │ │             ├──┤ Block L │ │                                
   │ Embed │   │    │  │  Copy A ├─┘   │  │  Copy A ├─┘             │  │  Copy A ├─┘                                
   └───────┘   │    │  └─────────┘     │  └─────────┘               │  └─────────┘                                  
               │    │                  │                   ...      │                                               
               │    │  ┌─────────┐     │  ┌─────────┐               │  ┌─────────┐      ┌───────────┐ ┌─────────┐   
               │    └──┤ Block 1 │     └──┤ Block 2 │               └──┤ Block L │    ┌─┤  Output   ├─┤ Output  ├─►
               │       │  Copy B ├─┐      │  Copy B ├─┐                │  Copy B ├─┐  │ │ LayerNorm │ │ Unembed │   
               │       └─────────┘ │      └─────────┘ │                └─────────┘ │  │ └───────────┘ └─────────┘   
               │                   │                  │                            │  │                             
               │                   ▼                  ▼                            ▼  │                             
               └──────────────────(+)────────────────(+)──────────────────────────(+)─┘                             
                   Direct Stream                                                                                    
                                                                                                                    

We could then separately ablate the induction heads in Copy B to see the effect of removing the direct path, or ablate them in Copy A to see the effect of removing them in the indirect path.
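Before adding ablations, it’s worth convincing ourselves that this path expansion is exact: since Copy A and Copy B read the same input, the direct stream ends up equal to the original residual stream, even for nonlinear blocks. Here is a NumPy sketch with made-up `tanh` blocks standing in for transformer blocks; `keep_a` and `keep_b` are hypothetical scalar knobs that scale every block’s output in one copy (0.0 means knocked out).

```python
import numpy as np

rng = np.random.default_rng(0)
d, num_blocks = 8, 3
weights = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(num_blocks)]

def block(layer: int, x: np.ndarray, keep: float = 1.0) -> np.ndarray:
    # Hypothetical nonlinear residual block; `keep` scales its output.
    return keep * np.tanh(x @ weights[layer])

def original_model(x: np.ndarray) -> np.ndarray:
    for layer in range(num_blocks):
        x = x + block(layer, x)
    return x

def path_expanded_model(x, keep_a=1.0, keep_b=1.0):
    indirect, direct = x.copy(), x.copy()  # both streams start from the embedding
    for layer in range(num_blocks):
        out_a = block(layer, indirect, keep_a)  # Copy A: reads indirect, writes indirect
        out_b = block(layer, indirect, keep_b)  # Copy B: reads indirect, writes direct
        indirect = indirect + out_a
        direct = direct + out_b
    return direct  # only the direct stream feeds the output

x = rng.normal(size=d)
print(np.allclose(path_expanded_model(x), original_model(x)))  # True
```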

Finally, we can apply one more trick: Let’s think of Copy A and Copy B as being different minibatch elements of the same computation. In other words, let’s add a new length-2 batch axis (let’s call it “worlds”) to our inputs and all of our intermediate values, such that intermediate[{"worlds": 0}] is the value this intermediate would have in the indirect stream or in copy A, and intermediate[{"worlds": 1}] is the value it would have in the direct stream or in copy B. Then we just need to make two changes to our model:

  • When we are applying our KnockOutAttentionHeads layer, we’ll add a “worlds” axis to the mask. If we’re ablating the direct path, mask[{"worlds": 0}] will be 1 everywhere, but mask[{"worlds": 1}] will have zeros at the parts we want to mask out.

  • Inside each residual block, before the layer norm, we’ll add a new “rewiring” step where we copy intermediate[{"worlds": 0}] to intermediate[{"worlds": 1}]. This ensures that the “Copy B” versions still see the indirect stream as input, instead of reading the direct stream.

Because Penzai models let you insert logic anywhere, and all of our model’s layers vectorize over batch axes by name, these are both pretty easy to accomplish! We’ll use a simple helper class to make it a bit clearer what we’re doing.

@dataclasses.dataclass(frozen=True)
class From:
  """A connection between two parallel computations."""
  source: str
  weight: float | pz.nx.NamedArray = 1.0

@pz.pytree_dataclass
class RewireComputationPaths(pz.Layer):
  """Rewires computation across parallel model runs along a "worlds" axis."""
  worlds_axis: str = dataclasses.field(metadata={"pytree_node": False})
  world_ordering: tuple[str, ...] = dataclasses.field(metadata={"pytree_node": False})
  taking: dict[str, From | tuple[From, ...]] = dataclasses.field(metadata={"pytree_node": False})

  def path_matrix(self) -> pz.nx.NamedArray:
    # Build a matrix that maps the "from" indices to the "to" indices as a
    # linear operation.
    result = [[0 for _ in self.world_ordering] for _ in self.world_ordering]
    assert len(self.taking) == len(self.world_ordering)
    assert set(self.taking.keys()) == set(self.world_ordering)
    for dest, connections in self.taking.items():
      if isinstance(connections, From):
        connections = (connections,)
      for connection in connections:
        from_ix = self.world_ordering.index(connection.source)
        to_ix = self.world_ordering.index(dest)
        result[to_ix][from_ix] += connection.weight
    return pz.nx.nmap(jnp.array)(result)  # <- Allows the weights to be named arrays

  def __call__(self, inputs: pz.nx.NamedArray) -> pz.nx.NamedArray:
    mat = self.path_matrix().astype(inputs.dtype)
    rewired = pz.nx.nmap(jnp.dot)(mat, inputs.untag(self.worlds_axis))
    return rewired.tag(self.worlds_axis)

Here’s how we could use it to rewire both worlds to read from the indirect stream:

read_both_from_indirect = RewireComputationPaths(
    worlds_axis="worlds",
    world_ordering=("indirect", "direct"),
    taking={
        "indirect": From("indirect"),
        "direct": From("indirect"),
    }
)
pz.show(read_both_from_indirect)
pz.show(read_both_from_indirect.path_matrix())
%%autovisualize None
read_both_from_indirect(pz.nx.wrap(jnp.array([123., 456.])).tag("worlds")).untag("worlds").unwrap()
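To make the `path_matrix` logic concrete without Penzai, here is a simplified NumPy stand-in for the same construction (rows are destination worlds, columns are source worlds), plus an example with fractional weights, which the `From.weight` field permits. The functions below are hypothetical re-implementations, not the class above.

```python
import numpy as np

def path_matrix(world_ordering, taking):
    """taking maps each destination world to a list of (source, weight) pairs."""
    n = len(world_ordering)
    mat = np.zeros((n, n))
    for dest, sources in taking.items():
        for source, weight in sources:
            mat[world_ordering.index(dest), world_ordering.index(source)] += weight
    return mat

def rewire(mat, values):
    # Mix values along the leading "worlds" axis.
    return mat @ values

both_from_indirect = path_matrix(
    ("indirect", "direct"),
    {"indirect": [("indirect", 1.0)], "direct": [("indirect", 1.0)]},
)
print(rewire(both_from_indirect, np.array([123.0, 456.0])))  # [123. 123.]

averaged = path_matrix(
    ("a", "b"),
    {"a": [("a", 1.0)], "b": [("a", 0.5), ("b", 0.5)]},
)
print(rewire(averaged, np.array([2.0, 4.0])))  # [2. 3.]
```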

Now let’s insert it into our model. We need to insert a rewiring layer every time the model reads from the residual stream, which in practice means we need to insert it in every residual block.

world_ordering = ("indirect", "direct")
rewired_model = (
    pz.select(model)
    .at_instances_of(pz.nn.Residual)
    .at(lambda r: r.delta.sublayers[0])  # <- assuming each residual contains a Sequential
    .insert_before(
        RewireComputationPaths(
            worlds_axis="worlds",
            world_ordering=world_ordering,
            taking={
                "indirect": From("indirect"),
                "direct": From("indirect"),
            },
        )
    )
)
%%autovisualize None
pz.select(rewired_model).at_instances_of(RewireComputationPaths).show_value()

If we run it right now, it should behave exactly the same as the original model, because we haven’t actually done anything different across the two “worlds”:

rewired_logits = rewired_model(
    gemma.model_core.GemmaInputs.from_basic_segments(
        pz.nx.stack([token_seq, token_seq], "worlds")
    )
)

rewired_log_probs = pz.nx.nmap(jax.nn.log_softmax)(rewired_logits.untag("vocabulary")).tag("vocabulary")
rewired_sliced_preds = rewired_log_probs[{"seq": pz.slice[:-1]}]
correct_next_token = example_arg.tokens[{"seq": pz.slice[1:]}]
lp_correct = rewired_sliced_preds[{"vocabulary": correct_next_token}]
lp_correct

Now let’s try ablating the heads along one of the paths. We’ll take the minimal set of induction heads we found in the previous section (in the “exercise”), and ablate their direct path; we’ll still allow their output to be processed by the later layers.

per_world_head_mask = pz.nx.wrap(
    jnp.ones(positional_avgs.shape + (2,), dtype=jnp.bfloat16)
    .at[np.array([20,21,21,21]), np.array([13,1,2,5]), 1]
    .set(0.0)
).tag("blocks", "heads", "worlds")
pz.ts.render_array(
    per_world_head_mask, axis_item_labels={"worlds": world_ordering}
)
ablated_rewired_model = knock_out_heads(rewired_model, per_world_head_mask)
ablated_rewired_logits = ablated_rewired_model(
    gemma.model_core.GemmaInputs.from_basic_segments(
        pz.nx.stack([token_seq, token_seq], "worlds")
    )
)

ablated_rewired_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
    ablated_rewired_logits.untag("vocabulary")
).tag("vocabulary")
ablated_rewired_sliced_preds = ablated_rewired_log_probs[{"seq": pz.slice[:-1]}]
correct_next_token = example_arg.tokens[{"seq": pz.slice[1:]}]
lp_correct = ablated_rewired_sliced_preds[{"vocabulary": correct_next_token}]
pz.nx.nmap(jnp.exp)(lp_correct)

We can see that now there’s a difference between the two “worlds”!

Note that our experiment was designed so that only the output of the “direct” world answers our question, so it doesn’t make a huge amount of sense to compare the outputs of both worlds. However, since nothing in the “indirect” world is ablated, its output is just that of the original model, so we can also read it as an “unablated” or “clean” baseline.

The accuracy does seem to decrease a bit relative to the unablated version. But the model still seems to be doing an OK job at copying near the end of the sequence, so perhaps the model is using the indirect path to adjust the logits? Let’s run a larger comparison to get a better sense.

We’ll run it again, but with four parallel worlds:

  • In the “original” world, we won’t do any ablation or rewiring.

  • In the “fully_ablated” world, we’ll ablate the four induction heads, but we won’t do any rewiring.

  • In the “original_ablate_direct” world, we’ll read from the “original” world, but we’ll ablate the induction heads. Since none of the layers read from the “original_ablate_direct” world itself, it represents an ablated direct path from “original” to the output.

  • In the “ablated_restore_direct” world, we’ll read from the “fully_ablated” world, but we’ll turn the induction heads back on. Since none of the layers read from the “ablated_restore_direct” world itself, it represents an unablated direct path from “fully_ablated” to the output.

world_ordering = (
    "original",
    "original_ablate_direct",
    "ablated_restore_direct",
    "fully_ablated",
)

ablate_critical_heads_mask = pz.nx.wrap(
    jnp.ones(positional_avgs.shape, dtype=jnp.bfloat16)
    .at[np.array([20,21,21,21]), np.array([13,1,2,5])]
    .set(0.0)
).tag("blocks", "heads")
unablated_mask = pz.nx.ones({"blocks": 28, "heads": 16})

world_mask_map = {
    "original": unablated_mask,
    "original_ablate_direct": ablate_critical_heads_mask,
    "ablated_restore_direct": unablated_mask,
    "fully_ablated": ablate_critical_heads_mask,
}

per_world_head_mask = pz.nx.stack([
    world_mask_map[world] for world in world_ordering
], "worlds")

pz.ts.render_array(
    per_world_head_mask, axis_item_labels={"worlds": world_ordering}
)
read_rewirer = RewireComputationPaths(
    worlds_axis="worlds",
    world_ordering=world_ordering,
    taking={
        "original": From("original"),
        "original_ablate_direct": From("original"),
        "fully_ablated": From("fully_ablated"),
        "ablated_restore_direct": From("fully_ablated"),
    },
)
rewired_model = (
    pz.select(model)
    .at_instances_of(pz.nn.Residual)
    .at(lambda r: r.delta.sublayers[0])  # <- assuming each residual contains a Sequential
    .insert_before(read_rewirer)
)
ablated_rewired_model = knock_out_heads(rewired_model, per_world_head_mask)

read_rewirer.path_matrix()

Let’s run it!

ablated_rewired_logits = ablated_rewired_model(
    gemma.model_core.GemmaInputs.from_basic_segments(
        pz.nx.stack([token_seq] * len(world_ordering), "worlds")
    )
)
ablated_rewired_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
    ablated_rewired_logits.untag("vocabulary")
).tag("vocabulary")
ablated_rewired_sliced_preds = ablated_rewired_log_probs[{"seq": pz.slice[:-1]}]
correct_next_token = example_arg.tokens[{"seq": pz.slice[1:]}]
lp_correct = ablated_rewired_sliced_preds[{"vocabulary": correct_next_token}]

pz.ts.render_array(
    lp_correct,
    axis_item_labels={"worlds": world_ordering}
)
pz.ts.render_array(
    pz.nx.nmap(jnp.exp)(lp_correct),
    axis_item_labels={"worlds": world_ordering}
)
print(world_ordering)
token_visualization.show_token_scores(
    correct_next_token,
    lp_correct,
    vocab,
    vmax=3
)

Interestingly, it looks like we again have some sort of compensation behavior; the model is able to make accurate predictions using either path alone.

It also looks like the probabilities may be interacting in a nonlinear way. Let’s instead look at the logits themselves, before the softmax normalization. Since the softmax output is unchanged when the same constant is added to every logit, only differences between logits are meaningful, so we’ll compare the logit of the correct answer to the average logit of all digits:
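A quick NumPy check of this reasoning: adding the same constant to every logit leaves the log-softmax unchanged, and the “correct logit minus mean digit logit” quantity is unchanged as well. The vocabulary, the digit ids and the values below are made up.

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    shifted = x - x.max()
    return shifted - np.log(np.exp(shifted).sum())

logits = np.array([2.0, -1.0, 0.5, 3.0])  # made-up logits over a 4-token vocabulary
shifted_logits = logits + 7.0             # same logits plus a constant

digit_ids = np.array([0, 2, 3])           # pretend tokens 0, 2 and 3 are "digits"

def correct_minus_mean_digit(l: np.ndarray, correct_id: int) -> float:
    return float(l[correct_id] - l[digit_ids].mean())

print(np.allclose(log_softmax(logits), log_softmax(shifted_logits)))  # True
print(correct_minus_mean_digit(logits, 3), correct_minus_mean_digit(shifted_logits, 3))
```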

digit_token_ids = pz.nx.wrap(vocab.EncodeAsIds("0123456789")).tag("digits")
logits_per_digit = ablated_rewired_logits[{"seq": pz.slice[:-1]}][{"vocabulary": digit_token_ids}]
logits_per_digit
avg_digit_logit = logits_per_digit.untag("digits").mean()
logit_digit_deltas = (logits_per_digit - avg_digit_logit)
pz.ts.render_array(logit_digit_deltas, axis_item_labels={"worlds": world_ordering})
logits_correct_relative = (
    ablated_rewired_logits[{"seq": pz.slice[:-1]}][{"vocabulary": correct_next_token}]
    - avg_digit_logit
)
logits_correct_relative
print(world_ordering)
token_visualization.show_token_scores(
    correct_next_token,
    logits_correct_relative,
    vocab,
)

Hypothesis: Do these vectors form a parallelogram? In other words, can we linearly decompose the effects of the two paths? Let’s try comparing differences between conditions (e.g. a difference of differences).

If the effects of the direct and indirect paths combine additively (i.e. there is no interaction between them), we should expect the first two rows to be the same (showing the influence of the direct path), and the second two rows to also be the same (showing the influence of the indirect path):
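To see what these four rows measure, here is a NumPy sketch with a hypothetical outcome model in which each path contributes a fixed amount, plus an optional interaction term that appears only when both paths are active. The two “direct” differences (and the two “indirect” differences) then disagree by exactly that interaction term.

```python
import numpy as np

def toy_outcomes(direct, indirect, interaction=0.0, base=0.0):
    # Hypothetical outcome model for the four worlds used above.
    return {
        "original": base + direct + indirect + interaction,  # both paths on
        "original_ablate_direct": base + indirect,           # direct path off
        "ablated_restore_direct": base + direct,             # indirect path off
        "fully_ablated": base,                               # both paths off
    }

def path_effects(o):
    # Same four differences, in the same order, as the "comparison" rows.
    return np.array([
        o["original"] - o["original_ablate_direct"],
        o["ablated_restore_direct"] - o["fully_ablated"],
        o["original"] - o["ablated_restore_direct"],
        o["original_ablate_direct"] - o["fully_ablated"],
    ])

additive = path_effects(toy_outcomes(direct=1.5, indirect=2.5, base=-1.0))
interacting = path_effects(toy_outcomes(direct=1.5, indirect=2.5, interaction=0.5, base=-1.0))
print(additive)     # rows 1-2 match, rows 3-4 match
print(interacting)  # each pair differs by the interaction term
```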

pz.nx.stack([
    # Influence of the direct path.
    (
        logits_correct_relative[{"worlds": world_ordering.index("original")}]
        - logits_correct_relative[{"worlds": world_ordering.index("original_ablate_direct")}]
    ),
    (
        logits_correct_relative[{"worlds": world_ordering.index("ablated_restore_direct")}]
        - logits_correct_relative[{"worlds": world_ordering.index("fully_ablated")}]
    ),
    # Influence of the indirect path.
    (
        logits_correct_relative[{"worlds": world_ordering.index("original")}]
        - logits_correct_relative[{"worlds": world_ordering.index("ablated_restore_direct")}]
    ),
    (
        logits_correct_relative[{"worlds": world_ordering.index("original_ablate_direct")}]
        - logits_correct_relative[{"worlds": world_ordering.index("fully_ablated")}]
    ),
], "comparison")

It’s not exact, but it’s fairly close. The differences are likely because of remaining nonlinear interactions. In particular:

  • We still have a final layer norm layer, which applies a nonlinear transformation by normalizing by the second moment of its input. That could adjust the logit values.

  • Our induction heads are split over two transformer blocks, block 20 and block 21. This introduces a path that we haven’t accounted for: block 20’s head could send a message to block 21’s heads (through the indirect stream), and then block 21’s heads use this to change their output (through the direct stream). This path is blocked by both of our interventions, because we’re ablating all four heads at once; we either ablate block 20’s message in the indirect stream or the response from block 21’s heads in the direct stream.
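A toy model makes the parallelogram check concrete. In the sketch below (made-up numbers, not Gemma measurements), the relative logit is an additive function of whether each path is restored, plus an interaction term `c` standing in for a path, like the block 20 to block 21 one, that only exists when both paths are intact. With `c = 0` the paired rows agree exactly; any interaction shows up as the gap between them. The function names are this sketch's own.

```python
import numpy as np

def toy_logit(direct_on, indirect_on, a=1.5, b=2.0, c=0.0):
    """Toy relative logit: additive path effects plus an interaction term c.

    direct_on / indirect_on are 0 or 1, indicating whether each path from the
    induction heads is restored. a, b, c are illustrative values only.
    """
    return a * direct_on + b * indirect_on + c * direct_on * indirect_on

def comparison_rows(**params):
    worlds = {
        "original": toy_logit(1, 1, **params),
        "original_ablate_direct": toy_logit(0, 1, **params),
        "ablated_restore_direct": toy_logit(1, 0, **params),
        "fully_ablated": toy_logit(0, 0, **params),
    }
    return np.array([
        # Influence of the direct path, measured two ways.
        worlds["original"] - worlds["original_ablate_direct"],
        worlds["ablated_restore_direct"] - worlds["fully_ablated"],
        # Influence of the indirect path, measured two ways.
        worlds["original"] - worlds["ablated_restore_direct"],
        worlds["original_ablate_direct"] - worlds["fully_ablated"],
    ])

additive = comparison_rows(c=0.0)
interacting = comparison_rows(c=0.5)
print(additive)      # rows 0/1 match and rows 2/3 match
print(interacting)   # each pair now differs by exactly c
```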

Overall, though, it seems that the indirect path is doing more to increase the logit for the correct next digit (relative to the other digits) than the direct path.

Let’s do a bit more analysis to try to decompose this further. A few possible hypotheses for how the indirect path could work:

  • An MLP layer might be amplifying the output of the induction heads.

    • Gemma uses GEGLU MLP layers, which have multiplicative interactions between two sets of features. So we could further decompose this into two cases: it could be amplifying the head’s output in a linear way (e.g. using it only in the linear features) or amplifying it in a nonlinear way (e.g. using it in the “gating” features with a GELU activation).

  • Another attention head might be amplifying the output of induction heads.

    • It could do this by attending to this token and copying its value in a linear way.

    • It could also in principle do this by using the induction head outputs to modify the queries and the keys, changing how information is routed. (This seems a bit unlikely, because the induction heads have already moved the information into the right place, but it’s possible.)

  • Or, it could be some combination of these.
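To make the linear-versus-gating distinction concrete, here is a small NumPy GEGLU with random weights. The sizes, the weights, and the tanh-approximate GELU are illustrative assumptions, not Gemma's actual parameters. Feeding a hypothetical induction-head output only into the linear "up" branch gives an effect that scales exactly linearly with it; feeding it through the GELU gate does not.

```python
import numpy as np

def gelu_tanh(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def geglu(gate_input, up_input, w_gate, w_up, w_down):
    """GEGLU MLP with separate inputs for the gating and linear branches.

    In a real block both arguments are the same residual-stream vector;
    splitting them lets us ask which branch a perturbation enters through.
    """
    return (gelu_tanh(gate_input @ w_gate) * (up_input @ w_up)) @ w_down

rng = np.random.default_rng(0)
d_model, d_hidden = 8, 32
w_gate = rng.normal(size=(d_model, d_hidden)) / np.sqrt(d_model)
w_up = rng.normal(size=(d_model, d_hidden)) / np.sqrt(d_model)
w_down = rng.normal(size=(d_hidden, d_model)) / np.sqrt(d_hidden)

x = rng.normal(size=d_model)         # residual stream without the head output
head_out = rng.normal(size=d_model)  # hypothetical induction-head contribution
baseline = geglu(x, x, w_gate, w_up, w_down)

# Linear use: the head output only enters the "up" branch, so its effect
# scales linearly with its magnitude.
lin_effect = geglu(x, x + head_out, w_gate, w_up, w_down) - baseline
lin_effect_2x = geglu(x, x + 2 * head_out, w_gate, w_up, w_down) - baseline

# Gated use: the head output passes through the GELU, so doubling it does
# not simply double its effect.
gate_effect = geglu(x + head_out, x, w_gate, w_up, w_down) - baseline
gate_effect_2x = geglu(x + 2 * head_out, x, w_gate, w_up, w_down) - baseline
```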

We can try to isolate the effects of these different paths with a more complex rewiring configuration, by progressively enabling more computation paths:

  • In the “fully_ablated” setting, we’ll knock out the induction heads as before.

  • In the “restore_direct” setting, we’ll restore the induction heads, but rewire every layer’s input to read from the “fully_ablated” world. Thus, the only path from those induction heads is the direct path to the output.

  • In the “restore_direct_attnvalue” setting, we’ll rewire the query and key projections of the attention blocks to read from the “fully_ablated” world (freezing the attention pattern), but we’ll allow the value projections to act normally. This additionally enables paths that pass through later attention heads without changing their attention pattern.

  • In the “restore_direct_attnall” setting, we’ll further allow the query and key projections to see the output of the induction heads, allowing attention patterns to change.

  • In the “restore_direct_attnall_linmlp” setting, we’ll start with “restore_direct_attnall”, but we’ll linearize the MLP layers around their inputs in the “fully_ablated” setting, and then evaluate that linear approximation at inputs that include the output of the induction heads (and of the attention circuits). This additionally enables linear paths through the MLP layers.

  • Finally, as before, we’ll have an “original” setting where nothing is rewired to read from “fully_ablated”, and both linear and nonlinear computation paths are included.
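The plan above is a nested schedule: each setting keeps every path enabled by the previous one and adds exactly one more kind. A plain-Python restatement (the path labels are names chosen for this sketch, not identifiers from the notebook) makes that property checkable:

```python
# Which kinds of paths out of the induction heads each setting enables.
# Everything not listed reads from the "fully_ablated" world.
PATHS = ("direct", "attn_value", "attn_query_key", "mlp_linear", "mlp_nonlinear")

schedule = {
    "fully_ablated": set(),
    "restore_direct": {"direct"},
    "restore_direct_attnvalue": {"direct", "attn_value"},
    "restore_direct_attnall": {"direct", "attn_value", "attn_query_key"},
    "restore_direct_attnall_linmlp": {"direct", "attn_value", "attn_query_key", "mlp_linear"},
    "original": set(PATHS),
}

order = list(schedule)
for prev, nxt in zip(order, order[1:]):
    added = schedule[nxt] - schedule[prev]
    # Each step must keep the old paths and enable exactly one new kind.
    assert schedule[prev] <= schedule[nxt] and len(added) == 1
    print(f"{prev} -> {nxt}: adds {added.pop()}")
```

Because consecutive settings differ by one kind of path, the difference between their scores roughly isolates that path's contribution.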

By comparing the relative logit of the correct digit across these settings, we should be able to tell roughly how much adding each type of path improves the model’s predictions.

How should we linearize the MLP layers? One way to do this would be to capture intermediates, then build this linear approximation by hand. But there’s an easier way, by combining Penzai’s compositionality with JAX’s function transformations. We’ll use the following combinator layer, which splits its input into two, preprocesses each copy, linearizes its child layer around the first copy, and then evaluates that linear approximation at the second:

@pz.pytree_dataclass
class LinearizeAndAdjust(pz.Layer):
  linearize_around: pz.LayerLike
  evaluate_at: pz.LayerLike
  target: pz.LayerLike

  def __call__(self, inputs):
    primal_point = self.linearize_around(inputs)
    eval_point = self.evaluate_at(inputs)
    # f(b) ~= f(a) + (b-a) f'(a)
    tangent_in = jax.tree_util.tree_map(
        lambda ppt, ept: (ept - ppt).order_like(ppt),
        primal_point,
        eval_point,
        is_leaf=pz.nx.is_namedarray,
    )
    primal_out, tangent_out = jax.jvp(
        self.target, (primal_point,), (tangent_in,)
    )
    return jax.tree_util.tree_map(
        lambda p_out, t_out: p_out + t_out,
        primal_out,
        tangent_out,
        is_leaf=pz.nx.is_namedarray,
    )
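Stripped of the named-axis bookkeeping, this combinator computes a first-order Taylor expansion. The NumPy sketch below uses an elementwise `tanh` with a hand-written Jacobian-vector product in place of the wrapped layer and `jax.jvp`, so it runs without Penzai; the function names are this sketch's own.

```python
import numpy as np

def f(x):
    # Stand-in for a nonlinear layer (elementwise tanh).
    return np.tanh(x)

def f_jvp(x, v):
    # Jacobian-vector product of tanh at x along v: diag(1 - tanh(x)^2) @ v.
    return (1.0 - np.tanh(x) ** 2) * v

def linearize_and_adjust(f, f_jvp, linearize_around, evaluate_at):
    """First-order Taylor expansion: f(b) ~= f(a) + J_f(a) (b - a)."""
    tangent_in = evaluate_at - linearize_around
    return f(linearize_around) + f_jvp(linearize_around, tangent_in)

a = np.array([0.3, -1.2, 2.0])           # e.g. the layer input in "fully_ablated"
b = a + np.array([0.01, -0.02, 0.005])   # the same layer's input in another world

approx = linearize_and_adjust(f, f_jvp, a, b)
exact = f(b)
```

A world evaluated at its own linearization point gets back the exact layer output, which is why the “fully_ablated” world is unaffected by the linearization.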

We can now perform our analysis by applying a sequence of structural patches to the model, inserting rewiring and linearization points one at a time as needed.

world_ordering = (
    "fully_ablated",
    "restore_direct",
    "restore_direct_attnvalue",
    "restore_direct_attnall",
    "restore_direct_attnall_linmlp",
    "original",
)
# Set up the ablation of our attention heads, as before.
ablate_critical_heads_mask = pz.nx.wrap(
    jnp.ones(positional_avgs.shape, dtype=jnp.bfloat16)
    .at[np.array([20,21,21,21]), np.array([13,1,2,5])]
    .set(0.0)
).tag("blocks", "heads")
unablated_mask = pz.nx.ones({"blocks": 28, "heads": 16})

world_mask_map = {
    "fully_ablated": ablate_critical_heads_mask,
    "restore_direct": unablated_mask,
    "restore_direct_attnvalue": unablated_mask,
    "restore_direct_attnall": unablated_mask,
    "restore_direct_attnall_linmlp": unablated_mask,
    "original": unablated_mask,
}
per_world_head_mask = pz.nx.stack([
    world_mask_map[world] for world in world_ordering
], "worlds")

pz.ts.render_array(
    per_world_head_mask, axis_item_labels={"worlds": world_ordering}
)
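The per-world mask is just a `[worlds, blocks, heads]` array of ones with zeros at the ablated `(block, head)` pairs. The NumPy sketch below builds it with a shortened world list, then applies it to random attention patterns for block 21. How the knockout is applied is an assumption here (the notebook’s `knock_out_heads` helper may differ): a masked-out head's pattern is replaced by one that attends only to the first (BOS) token. The shapes and the random patterns are placeholders, and the causal mask is omitted for brevity.

```python
import numpy as np

num_blocks, num_heads, seq = 28, 16, 6
world_ordering = ("fully_ablated", "restore_direct", "original")  # shortened

# 1.0 keeps a head, 0.0 knocks it out; (block, head) pairs from the text above.
ablate = np.ones((num_blocks, num_heads), dtype=np.float32)
ablate[[20, 21, 21, 21], [13, 1, 2, 5]] = 0.0
unablated = np.ones_like(ablate)

world_masks = {"fully_ablated": ablate, "restore_direct": unablated, "original": unablated}
per_world_mask = np.stack([world_masks[w] for w in world_ordering])  # [worlds, blocks, heads]

# Placeholder attention patterns for one block: [worlds, heads, query, key].
rng = np.random.default_rng(0)
scores = rng.normal(size=(len(world_ordering), num_heads, seq, seq))
attn = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)  # rows sum to 1

# Pattern that puts all attention on position 0 (the BOS token).
bos_pattern = np.zeros_like(attn)
bos_pattern[..., 0] = 1.0

block = 21
m = per_world_mask[:, block][:, :, None, None]  # [worlds, heads, 1, 1]
knocked = m * attn + (1.0 - m) * bos_pattern
```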
# Start with the original unmodified model checkpoint.
rewired_model = model

# Linearize the final output layer norm around the ablated input.
# This isn't strictly necessary, but it shouldn't affect the relative
# differences between the logits much, and we aren't particularly interested in
# the effect of this layer norm.
rewired_model = (
    pz.select(rewired_model)
    .at((lambda root: root.body.body.body.sublayers[-2]))
    .apply(lambda layernorm_layer: LinearizeAndAdjust(
        # Linearize around fully_ablated always
        linearize_around=RewireComputationPaths(
            worlds_axis="worlds",
            world_ordering=world_ordering,
            taking={
                k: From("fully_ablated") for k in world_ordering
            },
        ),
        # But evaluate it at each world's own input.
        evaluate_at=pz.nn.Identity(),
        target=layernorm_layer,
    ))
)

# Knock out the attention heads using the mask we defined above.
rewired_model = knock_out_heads(rewired_model, per_world_head_mask)

# Rewire the attention queries and keys.
rewired_model = (
    pz.select(rewired_model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .at(lambda attn: (attn.input_to_query.sublayers[0], attn.input_to_key.sublayers[0]))
    .insert_before(RewireComputationPaths(
        worlds_axis="worlds",
        world_ordering=world_ordering,
        taking={
            # Ablating the induction head -> attention pattern paths
            "fully_ablated": From("fully_ablated"),
            "restore_direct": From("fully_ablated"),
            "restore_direct_attnvalue": From("fully_ablated"),
            # Restoring the induction head -> attention pattern paths
            "restore_direct_attnall": From("restore_direct_attnall"),
            "restore_direct_attnall_linmlp": From("restore_direct_attnall_linmlp"),
            "original": From("original"),
        },
    ))
)

# Rewire the attention values.
rewired_model = (
    pz.select(rewired_model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .at(lambda attn: attn.input_to_value.sublayers[0])
    .insert_before(RewireComputationPaths(
        worlds_axis="worlds",
        world_ordering=world_ordering,
        taking={
            # Ablating the induction head -> attention value paths
            "fully_ablated": From("fully_ablated"),
            "restore_direct": From("fully_ablated"),
            # Restoring the induction head -> attention value paths
            "restore_direct_attnvalue": From("restore_direct_attnvalue"),
            "restore_direct_attnall": From("restore_direct_attnall"),
            "restore_direct_attnall_linmlp": From("restore_direct_attnall_linmlp"),
            "original": From("original"),
        },
    ))
)

# Linearize and rewire the MLP blocks.
rewired_model = (
    pz.select(rewired_model)
    .at_instances_of(gemma.model_core.GemmaFeedForward)
    .apply(lambda mlp: LinearizeAndAdjust(
        linearize_around=RewireComputationPaths(
            worlds_axis="worlds",
            world_ordering=world_ordering,
            taking={
                # Ablating the induction head -> MLP nonlinear paths
                "fully_ablated": From("fully_ablated"),
                "restore_direct": From("fully_ablated"),
                "restore_direct_attnvalue": From("fully_ablated"),
                "restore_direct_attnall": From("fully_ablated"),
                "restore_direct_attnall_linmlp": From("fully_ablated"),
                # Restoring the induction head -> MLP nonlinear paths
                "original": From("original"),
            },
        ),
        evaluate_at=RewireComputationPaths(
            worlds_axis="worlds",
            world_ordering=world_ordering,
            taking={
                # Ablating the induction head -> MLP linear paths
                "fully_ablated": From("fully_ablated"),
                "restore_direct": From("fully_ablated"),
                "restore_direct_attnvalue": From("fully_ablated"),
                "restore_direct_attnall": From("fully_ablated"),
                # Restoring the induction head -> MLP linear paths
                "restore_direct_attnall_linmlp": From("restore_direct_attnall_linmlp"),
                "original": From("original"),
            },
        ),
        target=mlp,
    ))
)

Let’s look at the changes to make sure we patched the correct parts of the model:

%%autovisualize None
pz.select(rewired_model).at_instances_of(RewireComputationPaths)

Looks right! Now we can run it:

rewired_logits = rewired_model(
    gemma.model_core.GemmaInputs.from_basic_segments(
        token_seq.broadcast_to(named_shape={"worlds": len(world_ordering)})
    )
)
correct_next_token = example_arg.tokens[{"seq": pz.slice[1:]}]
digit_token_ids = pz.nx.wrap(vocab.EncodeAsIds("0123456789")).tag("digits")
logits_per_digit = rewired_logits[{"seq": pz.slice[:-1]}][{"vocabulary": digit_token_ids}]
avg_digit_logit = logits_per_digit.untag("digits").mean()
logit_digit_deltas = (logits_per_digit - avg_digit_logit)
pz.ts.render_array(logit_digit_deltas, axis_item_labels={"worlds": world_ordering})
logits_correct_relative = (
    rewired_logits[{"seq": pz.slice[:-1]}][{"vocabulary": correct_next_token}]
    - avg_digit_logit
)
logits_correct_relative

Summarizing the average relative logits over the second repetition:

{
    worldname: float(
        logits_correct_relative[{"worlds": i, "seq": pz.slice[21:]}]
        .untag("seq").mean().unwrap()
    )
    for i, worldname in enumerate(world_ordering)
}
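The `pz.slice[21:]` above is worth unpacking. Assuming the layout used throughout (one BOS token followed by two 20-digit halves, 41 tokens in total), logits at position t predict token t + 1, so the kept positions predict tokens 22 through 40: every token of the second repetition except its first, which has no earlier context marking it as a repeat.

```python
# Sequence layout used throughout: one BOS token, then two 20-digit halves.
half = 20
tokens = ["<bos>"] + list("01976954310149754605" * 2)
seq_len = len(tokens)  # 41

# Logits at position t predict token t + 1; dropping the last position
# (seq[:-1]) leaves 40 prediction positions.
prediction_positions = list(range(seq_len - 1))

# The summary averages positions 21 onward of the shifted logits.
summary_positions = prediction_positions[21:]
targets = [t + 1 for t in summary_positions]
second_half = range(1 + half, seq_len)  # token indices 21..40
```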

To summarize:

  • Just restoring the direct path increases the (relative) logit of the correct answers a fair amount (by about 1.5), as we’ve seen before.

  • Additionally restoring the attention-value path yields another increase of about 1.1.

  • However, if we also restore the query-key circuit’s effect on the attention patterns, the average logit score actually decreases by about 0.9. This suggests that changes to later attention patterns modulate down the effect of the induction heads.

    • For this experiment’s ablation masks, we’re restoring all four heads, so this could have something to do with the interaction between block 20 and block 21 that we mentioned previously. Perhaps block 21’s heads compensate for block 20’s head being inactive, and copy less strongly if it’s active.

  • When we restore the linear path through the MLPs, we see a large increase in the logit score (about 3.7).

    • This suggests that the MLPs are very sensitive to the copied value in the ablated setting.

  • When we restore the nonlinear path as well, this drops down again by about 1.7.

    • Also, the logits become “smoother” across the sequence, varying less from token to token. For instance, the token at index 36 has an unusually small logit value in the ablated settings, but has a fairly normal predicted logit in the fully-restored condition.
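Because each setting enables a superset of the previous one’s paths, these step changes telescope: summed, they give the total gap between “original” and “fully_ablated”. A quick tally of the rounded numbers above (taken from the summary, not re-measured):

```python
import math

# Approximate per-step changes in the average relative logit, as read off
# the summary above (rounded).
step_deltas = {
    "restore_direct": 1.5,
    "restore_direct_attnvalue": 1.1,
    "restore_direct_attnall": -0.9,
    "restore_direct_attnall_linmlp": 3.7,
    "original": -1.7,
}

# Consecutive differences telescope: their sum is original - fully_ablated.
total = sum(step_deltas.values())
running = 0.0
for setting, delta in step_deltas.items():
    running += delta
    print(f"{setting:32s} {delta:+.1f}  (cumulative {running:+.1f})")
```

The net gain of roughly 3.7 is much smaller than the sum of the positive steps (about 6.3), which is another way of seeing the dampening discussed next.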

This suggests that both the attention-value circuits and linearized MLP paths are set up to amplify the values read by the induction heads, and then the query-key interactions and nonlinear MLP components are used to dampen the effect of these paths and calibrate their predictions, perhaps compensating for the redundancies between the four heads we found.

Activation patching and counterfactual inputs

With only a minor change to this batched rewiring setup, we can also perform activation patching (Meng et al. 2021): copying a single activation between two different inputs to see if we can edit the model’s behavior.

If these are induction heads, we should expect that they are raising the probability of copying the token they attend to. We can test this out by constructing two sequences:

  • Sequence 1 will be a repeated sequence like we’ve used so far.

  • Sequence 2 will be a different random sequence without repetition.

We can then run the model on sequence 1, but intervene on the value projections of our induction heads, so that they actually take the values from sequence 2. If we’re right about how this circuit works, we should expect that the model will attempt to copy the prefix of sequence 2 as the completion of sequence 1.

Let’s try it out. We’ll use a similar strategy to our rewiring before. Our rewiring will have three “worlds”:

  • In “original”, we’ll feed sequence 1 in normally.

  • In “nonrepeating”, we’ll feed sequence 2 in normally.

  • In “patched_induction_values”, we’ll feed in sequence 1, but we’ll patch in the value projections from sequence 2 for the induction heads only. (Note that, since we’re not sampling from the model, the model’s output predictions don’t make their way back to the inputs.)

The difference from our earlier rewirings: we won’t be knocking out any induction heads; we’ll leave them as-is. Instead, we’ll change our input so that it’s different in the different worlds.

Let’s start by preparing the sequences:

world_ordering = ("original", "nonrepeating", "patched_induction_values")
counterfactuals = [
    "01976954310149754605" + "01976954310149754605",  # <- Our running example so far. This one will be a reference.
    "67717010284911166217" + "06302739717444079179",  # <- A counterfactual non-repeating sequence.
    "01976954310149754605" + "01976954310149754605",  # <- The first example again, but we'll patch this one's activations.
]
all_toks = []
for cf_example in counterfactuals:
  subtoks = [vocab.bos_id()] + vocab.EncodeAsIds(cf_example)
  all_toks.append(subtoks)

counterfactuals_batch = pz.nx.wrap(
    jnp.array(all_toks).astype(jnp.int32)
).tag("worlds", "seq")  # <- Name it using the same "worlds" axis convention.

token_visualization.show_token_array(counterfactuals_batch, vocab)

Now let’s rewire the heads. To do that, we’ll need a matrix that tells us which world each head should read from. We’ll do it manually for each head this time (although we could also write a helper function like knock_out_heads if we were planning to try a bunch of different ablations).

block_20_induction_heads = pz.nx.wrap(jnp.zeros([16], dtype=jnp.bool_).at[13].set(True)).tag("heads")
block_21_induction_heads = pz.nx.wrap(jnp.zeros([16], dtype=jnp.bool_).at[np.array([1,2,5])].set(True)).tag("heads")
# Start with the original unmodified model checkpoint.
rewired_model = model

# Rewire the attention values in blocks 20 and 21:
for block_index, induction_head_mask in [
    (20, block_20_induction_heads),
    (21, block_21_induction_heads),
]:
  rewired_model = (
      pz.select(rewired_model)
      .at_instances_of(gemma.model_core.GemmaTransformerBlock)
      .assert_count_is(28)
      .pick_nth_selected(block_index)
      .at_instances_of(gemma.model_core.GemmaAttention)
      .at(lambda attn: attn.input_to_value.sublayers[-1])  # <- the value projections
      .assert_count_is(1)
      .insert_after(RewireComputationPaths(
          worlds_axis="worlds",
          world_ordering=world_ordering,
          taking={
              "original": From("original"),
              "nonrepeating": From("nonrepeating"),
              "patched_induction_values": (
                  # The induction heads read from "nonrepeating".
                  From("nonrepeating", weight=pz.nx.nmap(jnp.where)(induction_head_mask, 1., 0.)),
                  # Everything other than the induction heads takes values from "patched_induction_values".
                  From("patched_induction_values", weight=pz.nx.nmap(jnp.where)(induction_head_mask, 0., 1.)),
              ),
          },
      ))
  )

%%autovisualize None
pz.select(rewired_model).at_instances_of(RewireComputationPaths)
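Each `From(world, weight=...)` term can be read as one summand in a weighted combination of worlds. The NumPy sketch below mimics that reading for block 21, assuming (as the usage above suggests, not as a description of Penzai’s implementation) that each destination world receives the sum of its listed source worlds’ activations scaled by per-head weights. The `rewire` helper and the random activations are hypothetical.

```python
import numpy as np

world_ordering = ("original", "nonrepeating", "patched_induction_values")
num_heads = 16
induction_heads = np.zeros(num_heads, dtype=bool)
induction_heads[[1, 2, 5]] = True  # block 21's induction heads

# Hypothetical value activations for one block: [worlds, seq, heads, head_dim].
rng = np.random.default_rng(0)
values = rng.normal(size=(len(world_ordering), 7, num_heads, 4))

def rewire(values, taking):
    """Rebuild each world's activation as a weighted sum of source worlds.

    `taking` maps a destination world to a list of (source_world, weight)
    pairs; a weight may be a scalar or a per-head array.
    """
    out = np.zeros_like(values)
    for dest, sources in taking.items():
        d = world_ordering.index(dest)
        for src, weight in sources:
            w = np.broadcast_to(np.asarray(weight, dtype=values.dtype), (num_heads,))
            out[d] += w[None, :, None] * values[world_ordering.index(src)]
    return out

patched = rewire(values, {
    "original": [("original", 1.0)],
    "nonrepeating": [("nonrepeating", 1.0)],
    "patched_induction_values": [
        ("nonrepeating", np.where(induction_heads, 1.0, 0.0)),
        ("patched_induction_values", np.where(induction_heads, 0.0, 1.0)),
    ],
})
```

Because the two weights are complementary 0/1 masks, every head ends up reading from exactly one world.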

Now we can run it:

paired_logits = rewired_model(
    gemma.model_core.GemmaInputs.from_basic_segments(counterfactuals_batch)
)
# Let's look at the token probabilities of each digit:
all_probs = pz.nx.nmap(jax.nn.softmax)(paired_logits.untag("vocabulary")).tag("vocabulary")
digit_token_ids = pz.nx.wrap(vocab.EncodeAsIds("0123456789")).tag("digits")
digit_probs = all_probs[{"vocabulary": digit_token_ids}]

pz.ts.render_array(digit_probs, axis_item_labels={"worlds": world_ordering}, vmax=1)

In the “patched_induction_values” setting (the third facet), we see that the model is confidently predicting digits, but they are wrong! Let’s compare with the one-hot encodings of the inputs.

pz.ts.render_array(counterfactuals_batch == digit_token_ids, axis_item_labels={"worlds": world_ordering}, vmax=1)

As we predicted, the model outputs for the second half of the “patched_induction_values” condition (bottom right) are copying the digits from the first half of the “nonrepeating” input (middle left)! The only way this information could have been transferred is through our rewiring layers, so this means we’ve successfully intervened on the “source” of the model’s copying information.
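To turn this visual comparison into a number, one could check the alignment explicitly. The plain-NumPy sketch below uses -1 as a stand-in for the BOS token and a hypothetical `copy_rate` helper. It builds the same one-hot comparison by broadcasting and spells out the offset: a prediction at position t targets token t + 1, so copying the nonrepeating prefix means predicting the digit at index t + 1 - 20. As in the earlier summary, it skips the first position of the second half.

```python
import numpy as np

half = 20
nonrepeating = "67717010284911166217" + "06302739717444079179"
tokens = np.array([-1] + [int(ch) for ch in nonrepeating])  # -1 stands in for BOS
digits = np.arange(10)

# One-hot via broadcasting, like `counterfactuals_batch == digit_token_ids`.
one_hot = tokens[:, None] == digits[None, :]  # [seq, digits]; BOS row is all False

positions = range(half + 1, 2 * half)  # 21..39

def expected_copy(t):
    """Digit predicted at position t if the model copies the nonrepeating prefix."""
    return tokens[t + 1 - half]

def copy_rate(predicted_digits):
    """Fraction of positions whose predicted digit matches the copied prefix.

    `predicted_digits[t]` could be, e.g., the argmax over digits of the
    patched world's probabilities at position t.
    """
    return float(np.mean([predicted_digits[t] == expected_copy(t) for t in positions]))

# Sanity check with a synthetic "perfect copier".
perfect = np.zeros(2 * half, dtype=int)
for t in positions:
    perfect[t] = expected_copy(t)
```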

We can just as easily try alternative variants. For instance, let’s try running the model on the “nonrepeating” sequence, but patch in the attention pattern from the repeating “original” sequence, and see if it tries to copy the prefix of the “nonrepeating” sequence even though there’s no reason to.

world_ordering = ("original", "nonrepeating", "patched_ind_attn_pattern")
counterfactuals = [
    "01976954310149754605" + "01976954310149754605",
    "67717010284911166217" + "06302739717444079179",
    "67717010284911166217" + "06302739717444079179",  # <- Patching the non-repeating sequence this time.
]
all_toks = []
for cf_example in counterfactuals:
  subtoks = [vocab.bos_id()] + vocab.EncodeAsIds(cf_example)
  all_toks.append(subtoks)

counterfactuals_batch = pz.nx.wrap(
    jnp.array(all_toks).astype(jnp.int32)
).tag("worlds", "seq")  # <- Name it using the same "worlds" axis convention.
# Start with the original unmodified model checkpoint.
rewired_model = model

# Rewire the attention patterns in blocks 20 and 21:
for block_index, induction_head_mask in [
    (20, block_20_induction_heads),
    (21, block_21_induction_heads),
]:
  rewired_model = (
      pz.select(rewired_model)
      .at_instances_of(gemma.model_core.GemmaTransformerBlock)
      .assert_count_is(28)
      .pick_nth_selected(block_index)
      .at_instances_of(gemma.model_core.GemmaAttention)
      .at(lambda attn: attn.query_key_to_attn.sublayers[-1])  # <- the softmax
      .assert_count_is(1)
      .insert_after(RewireComputationPaths(
          worlds_axis="worlds",
          world_ordering=world_ordering,
          taking={
              "original": From("original"),
              "nonrepeating": From("nonrepeating"),
              "patched_ind_attn_pattern": (
                  # The induction head patterns are patched from "original".
                  From("original", weight=pz.nx.nmap(jnp.where)(induction_head_mask, 1., 0.)),
                  # Every other pattern is kept from "patched_ind_attn_pattern".
                  From("patched_ind_attn_pattern", weight=pz.nx.nmap(jnp.where)(induction_head_mask, 0., 1.)),
              ),
          },
      ))
  )

%%autovisualize None
pz.select(rewired_model).at_instances_of(RewireComputationPaths)

Let’s see what we get:

paired_logits = rewired_model(
    gemma.model_core.GemmaInputs.from_basic_segments(counterfactuals_batch)
)
# Let's look at the token probabilities of each digit:
all_probs = pz.nx.nmap(jax.nn.softmax)(paired_logits.untag("vocabulary")).tag("vocabulary")
digit_token_ids = pz.nx.wrap(vocab.EncodeAsIds("0123456789")).tag("digits")
digit_probs = all_probs[{"vocabulary": digit_token_ids}]

pz.ts.render_array(digit_probs, axis_item_labels={"worlds": world_ordering}, vmax=1)

It looks like patching the attention patterns of the induction heads is not enough to get the model to start copying confidently. However, there’s still a very faint pattern of copied digits in the third condition if you look closely. We can take a difference in log-probs to better emphasize this:

all_log_probs = pz.nx.nmap(jax.nn.log_softmax)(paired_logits.untag("vocabulary")).tag("vocabulary")
digit_log_probs = all_log_probs[{"vocabulary": digit_token_ids}]
diffs = digit_log_probs[{"worlds": 2}] - digit_log_probs[{"worlds": 1}]

pz.show(
    "Relevant tokens to copy:",
    pz.ts.render_array(
        counterfactuals_batch[{"worlds": 2, "seq": pz.slice[:21]}] == digit_token_ids,
        vmax=1,
    )
)
pz.show(
    "Differences between 'patched_ind_attn_pattern' and 'nonrepeating' log probs:",
    pz.ts.render_array(diffs, vmax=0.07)
)
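Why log-probs? A difference of log-probabilities is the log of a probability ratio, so it measures relative change: a digit moving from 0.01 to 0.011 scores the same (about 0.095) as one moving from 0.1 to 0.11, whereas in raw probabilities the first change is ten times smaller. The NumPy sketch below uses a hand-rolled, numerically stable `log_softmax` helper (our own, standing in for `jax.nn.log_softmax`) on made-up logits with a faint nudge toward one digit per position.

```python
import numpy as np

def log_softmax(logits, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

rng = np.random.default_rng(0)
base_logits = rng.normal(size=(5, 10))  # e.g. "nonrepeating": [seq, digits]
nudged = [3, 1, 4, 1, 5]                # one hypothetical digit per position
patched_logits = base_logits.copy()
patched_logits[np.arange(5), nudged] += 0.05  # a faint nudge

diffs = log_softmax(patched_logits) - log_softmax(base_logits)
```

Each nudged digit gains log-probability while every other digit in that row loses a little, since the row must still normalize.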

What do these results mean? One conjecture is that there are separate copy-detection and copy-location circuits in this model:

  • Some circuit before block 20 is responsible for determining whether or not this sequence looks like it’s copying.

  • If this feature is active, these induction heads copy the token value, and it is amplified and used as the model’s prediction.

  • However, if this feature is inactive, the induction heads don’t copy the value they attend to.

    • Perhaps that value is zeroed out in the subspace of the input that they attend to.

    • Or, perhaps that value is copied into an output value subspace, but some later MLP layers scrub it out before it arrives at the final unembedding layer.

Let’s test these last two hypotheses by copying both the attention pattern and the values of the attention head, instead of just copying the attention pattern. If the zeroing-out is happening in the input subspace, we’d expect copying the full attention output to restore the copying behavior (but it would now copy the first half of the first sequence, instead of copying from the nonrepeating sequence). But if the zeroing-out is happening in MLPs later in the model, we’d expect copying the full attention output to not make much difference.

world_ordering = ("original", "nonrepeating", "patched_ind_attn_output")
# Start with the original unmodified model checkpoint.
rewired_model = model

# Rewire both the attention values and the attention patterns in blocks 20 and 21:
for block_index, induction_head_mask in [
    (20, block_20_induction_heads),
    (21, block_21_induction_heads),
]:
  rewirer_layer = RewireComputationPaths(
      worlds_axis="worlds",
      world_ordering=world_ordering,
      taking={
          "original": From("original"),
          "nonrepeating": From("nonrepeating"),
          "patched_ind_attn_output": (
              # The induction heads are patched from "original".
              From("original", weight=pz.nx.nmap(jnp.where)(induction_head_mask, 1., 0.)),
              # Every other value is kept from "patched_ind_attn_output".
              From("patched_ind_attn_output", weight=pz.nx.nmap(jnp.where)(induction_head_mask, 0., 1.)),
          ),
      },
  )
  rewired_model = (
      pz.select(rewired_model)
      .at_instances_of(gemma.model_core.GemmaTransformerBlock)
      .assert_count_is(28)
      .pick_nth_selected(block_index)
      .at_instances_of(gemma.model_core.GemmaAttention)
      .at(lambda attn: attn.input_to_value.sublayers[-1])  # <- the value projection
      .assert_count_is(1)
      .insert_after(rewirer_layer)
  )
  rewired_model = (
      pz.select(rewired_model)
      .at_instances_of(gemma.model_core.GemmaTransformerBlock)
      .assert_count_is(28)
      .pick_nth_selected(block_index)
      .at_instances_of(gemma.model_core.GemmaAttention)
      .at(lambda attn: attn.query_key_to_attn.sublayers[-1])  # <- the softmax
      .assert_count_is(1)
      .insert_after(rewirer_layer)
  )

%%autovisualize None
pz.select(rewired_model).at_instances_of(RewireComputationPaths)

paired_logits = rewired_model(
    gemma.model_core.GemmaInputs.from_basic_segments(counterfactuals_batch)
)
all_probs = pz.nx.nmap(jax.nn.softmax)(paired_logits.untag("vocabulary")).tag("vocabulary")
digit_token_ids = pz.nx.wrap(vocab.EncodeAsIds("0123456789")).tag("digits")
digit_probs = all_probs[{"vocabulary": digit_token_ids}]

pz.ts.render_array(digit_probs, axis_item_labels={"worlds": world_ordering}, vmax=1)

Looks like this still doesn’t produce copying behavior. So, the second hypothesis is more likely to be true: there are probably MLP layers that are scrubbing out the contributions of this attention head, or at least not amplifying it.

In fact, this is consistent with the behavior we observed in the previous section! We observed that the induction heads were being modulated in a nonlinear fashion by the MLP layers, and that those MLP layers were locally highly sensitive to the output of the induction heads.

Recap

Summarizing what we’ve found:

  • When the model is given random integer digit sequences, it looks like there are induction heads in blocks 5, 14, 20, and 21. By ablating them, we determined that preserving at least one head from blocks 20 and 21 is both necessary and sufficient for the model to confidently copy digits from its input.

  • Ablating individual linear and nonlinear paths through the model revealed that there are contributions through the direct output of the layer, through later attention heads, and through MLPs. However, there is a nonlinear modulating effect from both later attention patterns and the nonlinear MLP activations.

  • Running activation patching between counterfactual inputs showed that, if the model has decided to copy, we can change what it copies by patching in the induction head value projections. This is causal evidence that the model is using those value projections to make its prediction.

  • On the other hand, we can’t seem to make the model decide to copy by patching in the attention pattern of the induction heads, or by patching in both the attention pattern and the copied values. This suggests that the MLP layers may be playing a role in “gating” these induction heads, and determining whether those outputs make their way to the prediction.

Along the way, we’ve demonstrated how to:

  • Look at all sorts of high-dimensional named-axis arrays using Penzai’s autovisualizer and pz.ts.render_array, optionally adding useful annotations of our own to the tooltips,

  • Look at token sequences with the token_visualization tool,

  • Visualize the structure of the pretrained Gemma model,

  • Inject new logic into that model using pz.select,

  • Use the named axis system to identify likely induction heads,

  • Write our own patching logic to knock out attention heads,

  • Run complex path-rewiring experiments in vectorized form by adding a “worlds” axis,

  • Linearize MLPs and layer norm blocks around their activations in different parallel “worlds”,

  • And combine these to perform causal interventions on different activations within the model and between different inputs.

Overall, this notebook demonstrates the flexibility of the Penzai toolkit, and how easy it is to compose different ablations, reroute various activation paths in a declarative way, and quickly iterate on our analysis without having to manage mutable state or manually cache different model activations. It’s also worth emphasizing again that this model is sharded across multiple accelerator devices using the power of the XLA compiler, and that the computation across our parallel counterfactual “worlds” is also fully vectorized and parallelizable.

There’s clearly more interesting behavior to explore about these circuits, but we’ll stop here and leave this exploration to the reader. Some interesting questions you might explore:

  • Can you design a set of rewirings that would identify whether the head in block 20 is indeed influencing the attention patterns in block 21?

  • Why is restoring any one of these heads sufficient to restore the overall copying behavior we’ve observed?

  • What are the other induction heads (in blocks 5 and 14) doing, given that they aren’t enough to implement the copying behavior on their own?

  • What part of the model is responsible for deciding when it should copy and when it shouldn’t copy? Can you figure out a minimal set of activations we can patch in from the repeating sequence to trick the model into copying the non-repeating sequence?

Note that the components KnockOutAttentionHeads, RewireComputationPaths, and LinearizeAndAdjust used in this notebook are also available in Penzai under penzai.toolshed.model_rewiring, so you can use them in your own notebooks without copying their definitions. But don’t feel limited by those layers! Penzai makes it easy to add custom logic to the models, so you can perform whatever modification you want. (If you want to try something that involves mutable state, you might also consider reading the separate data-effects tutorial.)