Copyright 2024 The Penzai Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Induction Heads in Gemma 2B
Note: This version of this tutorial uses the 2-billion parameter Gemma model, which has a smaller memory footprint. See also the “Induction Heads in Gemma 7B” tutorial.
One of Penzai’s primary goals is to support interpretability research on state-of-the-art models. In this notebook, we’ll use Penzai to try to find and intervene on induction heads (Elhage et al. 2021, Olsson et al. 2022) in the Gemma 2B open-weights model. We’ll be focusing on exploratory analysis and on Penzai’s tooling rather than on rigor; the goal is to show how you can use Penzai to quickly prototype ideas and generate hypotheses about network behavior (not to perfectly measure the presence of induction heads or exactly reproduce previous results).
Along the way, we’ll discuss:
How to use JAX’s sharding support to automatically shard the model over a cluster of TPUs,
How to use Penzai’s pretty-printer (Treescope) to quickly look at model weights and activations,
How to extract intermediate values and intermediate subcomputations from a larger model for detailed analysis, using either Penzai’s manual patching tool pz.select or Penzai’s data-effect system,
How to use Penzai’s named axis library to identify the characteristic patterns of induction heads,
And how to patch the Gemma model by intervening on intermediate subcomputations (in this case, the attention weights).
Let’s get started!
Setting up and loading the model
We’ll start by setting up the environment and loading the Gemma 2B model.
Imports
To run this notebook, you need a Python environment with penzai and its dependencies installed.
In Colab or Kaggle, you can install it using the following command:
try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
from typing import Any
import os
import dataclasses
import gc
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
from jax.experimental import mesh_utils
import sentencepiece as spm
# Allow using ~all GPU memory if using a Colab GPU kernel.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".98"
import treescope
import penzai
from penzai import pz
from penzai.models import transformer
Setting up Penzai
For this tutorial, we’ll enable Treescope (Penzai’s companion pretty-printer) as the default Colab pretty-printer. We’ll also turn on automatic visualization of JAX and NumPy arrays. This will make it easy to look at our models and their outputs.
treescope.basic_interactive_setup(autovisualize_arrays=True)
Loading Gemma
Next we’ll load the weights from the Gemma checkpoint.
We’ll use the 2B checkpoint for this tutorial. The 2B checkpoint can be loaded into either a “TPU v2” or “T4 GPU” Colab kernel.
Note: Colab’s “TPU v2” kernel, as well as Colab’s advanced GPU kernels (e.g. “A100”) and Kaggle’s TPU kernels, have enough memory to load the 7B Gemma checkpoint as well. If you want to follow along with the 7B model, you can instead load the “Induction Heads in Gemma 7B” tutorial. (Both tutorials cover the same material, but the locations of the induction heads and some aspects of the model predictions differ between the variants!)
You can download the Gemma checkpoints using a Kaggle account and an API key. If you don’t have an API key already, you can:
Visit https://www.kaggle.com/ and create an account if needed.
Go to your account settings, then the ‘API’ section.
Click ‘Create new token’ to download your key.
Next, if you are running this notebook in Google Colab:
Click the “key” symbol on the left toolbar to open the “Secrets” tab.
Add two new secrets, named “KAGGLE_USERNAME” and “KAGGLE_KEY”, and set their values based on the API key you downloaded.
Run the cell below and grant this notebook access to the secrets you just made.
If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.
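import kagglehub
try:
  # In Colab, read the credentials from the "Secrets" tab.
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  # Outside Colab, fall back to the interactive login widget.
  kagglehub.login()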
If everything went well, you should see:
Kaggle credentials set.
Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven’t done that yet, you can do so here:
https://www.kaggle.com/models/google/gemma/license/consent
(Make sure you choose to “Verify via Kaggle Account” with the same account you used to log in above!)
Once you’ve agreed to the terms, you can run the next cell to download the Gemma weights:
weights_dir = kagglehub.model_download('google/gemma/Flax/2b')
ckpt_path = os.path.join(weights_dir, '2b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')
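Before moving on, it can be worth confirming that the download produced the two paths built above. The following is a minimal sketch, not part of Penzai or kagglehub: the helper name missing_checkpoint_files is hypothetical, and it only assumes the layout implied by the cell above (a checkpoint subdirectory named after the variant plus a tokenizer.model file).

```python
import os


def missing_checkpoint_files(weights_dir, variant="2b"):
  """Returns the expected paths that do not exist under `weights_dir`.

  Hypothetical helper: the layout (a `variant` subdirectory and a
  `tokenizer.model` file) is assumed from the download cell above.
  """
  expected = [
      os.path.join(weights_dir, variant),            # checkpoint directory
      os.path.join(weights_dir, "tokenizer.model"),  # SentencePiece vocabulary
  ]
  return [path for path in expected if not os.path.exists(path)]
```

Calling missing_checkpoint_files(weights_dir) should return an empty list when both paths are present. A non-empty result lists what is missing and usually means the download was interrupted or the terms of use were not accepted.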
We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using orbax:
# Load the tokenizer vocabulary.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
# Read the checkpoint metadata, which describes each saved array.
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)
# Build a one-dimensional sharding over all local devices.
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
# Restore each parameter as a jax.Array, split along its last axis.
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
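To make the sharding.reshape(...) argument in the restore step concrete, here is a minimal pure-Python sketch of the shape it builds for each parameter. The helper names are hypothetical and the example shapes are illustrative only, not taken from the Gemma checkpoint. The sketch mirrors the tuple expression used above: every axis except the last gets a single partition, and the last axis is split across all devices.

```python
def last_axis_sharding_shape(param_shape, n_devices):
  """Mirrors `(1,) * (len(m.shape) - 1) + (n_devices,)` from the restore step.

  Hypothetical helper for illustration: returns how many partitions each
  axis of a parameter with shape `param_shape` is split into.
  """
  return (1,) * (len(param_shape) - 1) + (n_devices,)


def divides_evenly(param_shape, n_devices):
  """Checks whether the last axis splits into equal-sized pieces."""
  return param_shape[-1] % n_devices == 0


# Illustrative shape (an assumption, not read from the checkpoint):
print(last_axis_sharding_shape((2048, 16384), 8))  # (1, 8)
print(divides_evenly((2048, 16384), 8))            # True
```

Splitting only the last axis keeps the rule uniform across parameters of different rank. Uneven splits may not be supported by every sharding API, so a check like divides_evenly can be a useful sanity test when changing the device count.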
Let’s take a look! Since we’ve registered Treescope as the default pretty-printer and turned on array visualization, we can just output the arrays from Colab and see a rich visualization of their values.
Try clicking to explore the structure of the arrays below!
(Note: It may take a while for the array summaries to load the first time, because JAX has to compile the summarization code. You can still look at array shapes before they finish, and it should be faster to run the second time.)
flat_params