Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



LoRA From Scratch - Patching Pretrained Models in Penzai#

Penzai is designed to make it easy to make targeted modifications to neural networks after they have been trained. In this notebook, we’ll show how to take Penzai’s reference implementation of the Gemma 7B open-weights transformer model, patch it to support Low-Rank Adaptation (LoRA; Hu et al., 2021), and train the new parameters on a toy problem with a hand-written loss function.

The goal of this notebook is to show how you could implement something like LoRA from scratch in less than a hundred lines of code, starting from a Penzai implementation of a model that doesn’t support it already, and without having to fork the existing implementation source code or even modify the pretrained model’s configuration. We’ll define everything we need as we go and make changes to models interactively. In fact, our implementation will end up being completely modular; we’ll start by applying LoRA to a small MLP and then immediately be able to transfer our implementation to Gemma 7B.

Let’s get started!

Setup#

Before we can get started in earnest, we need to set up the environment.

Imports#

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import os
import gc
import collections

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import optax
from jax.experimental import mesh_utils
import penzai
from penzai import pz
import sentencepiece as spm
from penzai.example_models import gemma
from penzai.example_models import simple_mlp
from penzai.toolshed import basic_training
from penzai.toolshed import token_visualization
from penzai.toolshed import jit_wrapper

Setting up Penzai#

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment.

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()

Intro to Penzai’s declarative combinator design#

We’ll start by giving a brief introduction to Penzai’s design conventions, and how they make it easy to insert adapters into pretrained models. Let’s begin by initializing a small MLP:

mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([8, 32, 32, 8]),
    jax.random.key(0),
)

Like most Penzai models and layers, this MLP takes named arrays as input and returns them as output. A named array is just a wrapped JAX array where a subset of its positional axes have been tagged with names. (See the named axes tutorial for more info on how to use Penzai’s named axis system.)

We can call the MLP directly on an array of inputs to run it:

%%autovisualize
mlp(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

Penzai models are written in a declarative, combinator-based style. This means that the structure of the model directly matches the sequence of high-level operations that the model will run in its forward pass. Composite models, like our MLP, just hold onto their sublayers in a list and run these sublayers in order. Primitive layers, like Linear, hold on to their parameters as attributes instead of reading them from an external parameter dictionary.

We can see the sublayers by pretty-printing the model:

%%autovisualize
mlp

By convention, most of the “complicated” logic in Penzai model classes happens when we initialize them, using the .from_config method we called earlier. Once the model is built, the pretty-printed representation provides a full specification of everything the model does, and the parameters are stored as direct attributes on the layers that need them. A general design principle of Penzai is “what you see is what you get”; you should be able to learn everything you need to know about a model by printing it out.

In fact, you can click on a pretty-printed output, press r to add qualified names to the pretty-printed visualization (try it above!), and then copy and paste the entire pretty-printed code and execute it to make a copy of the model:

penzai.example_models.simple_mlp.MLP(
  sublayers=[
    penzai.nn.linear_and_affine.Affine(
      sublayers=[
        penzai.nn.linear_and_affine.LinearInPlace(
          sublayers=[
            penzai.nn.linear_and_affine.Linear(
              weights=penzai.nn.parameters.Parameter(
                value=penzai.core.named_axes.NamedArray(
                  named_axes=collections.OrderedDict({'features': 8, 'features_out': 32}),
                  data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(
                      original_repr='Array([[ 0.04725012, -0.23066601, ... ]], dtype=float32)',
                      original_type=jax.Array),
                ),
                name='Affine_0.Linear.weights',
              ),
              in_axis_names=('features',),
              out_axis_names=('features_out',),
            ),
            penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),
          ],
        ),
        penzai.nn.linear_and_affine.AddBias(
          bias=penzai.nn.parameters.Parameter(
            value=penzai.core.named_axes.NamedArray(
              named_axes=collections.OrderedDict({'features': 32}),
              data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(
                  original_repr='Array([0., 0., ... ], dtype=float32)',
                  original_type=jax.Array),
            ),
            name='Affine_0.AddBias.bias',
          ),
          new_axis_names=(),
        ),
      ],
    ),
    penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.nn.linear_and_affine.Affine(
      # ... same structure, with weights of shape {'features': 32, 'features_out': 32}
      # named 'Affine_1.Linear.weights' and a bias named 'Affine_1.AddBias.bias' ...
    ),
    penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.nn.linear_and_affine.Affine(
      # ... same structure, with weights of shape {'features': 32, 'features_out': 8}
      # named 'Affine_2.Linear.weights' and a bias named 'Affine_2.AddBias.bias' ...
    ),
  ],
)

We won’t usually do this in practice, because device arrays can’t be copy-pasted this way; the parameters will be replaced with placeholder objects. Instead, Penzai provides a sophisticated selector system (pz.select) that allows us to make targeted modifications to (copies of) models. The point here is that Penzai model objects aren’t “hiding” anything; they directly expose the structure of their computation as a data structure that can be manipulated.
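For example, because the model is just a nest of dataclasses and lists, you can reach into it with ordinary attribute and index access. Here’s a small sketch (not from the original notebook); the exact path follows the structure printed above:

first_linear = mlp.sublayers[0].sublayers[0].sublayers[0]  # MLP -> Affine -> LinearInPlace -> Linear
first_linear.weights.value  # A NamedArray with axes {'features': 8, 'features_out': 32}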

The specific types of Penzai models and composite layers are provided primarily for ease of manipulation and as a way to identify how each part of your model was built. But you can also “flatten” a model into a list of sublayers that run in sequence, discarding this extra information:

pz.nn.inline_groups(mlp, lambda _: True, lambda _: True)

And you can freely add new logic as well, even if it wasn’t configured in the initial model. For instance, here’s how you could insert a new layer that prints out its intermediate activation:

@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(pz.Layer):  # <- pz.Layer is the base class of Penzai layers.

  def __call__(self, intermediate_value):
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value
patched = (
    pz.select(mlp)
    .at(lambda model: model.sublayers[2])
    .insert_after(DisplayIntermediateValue())
)
pz.select(patched).at_instances_of(DisplayIntermediateValue).show_selection()

patched is a copy of our model that includes our new layer, and it will run our new logic when the model is called:

%%autovisualize
patched(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

This ability makes it remarkably easy to implement adapters like LoRA!

Building a simple LoRA Layer in Penzai#

Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning strategy that augments each linear operation in the model with a decomposed low-rank adapter. The original weight matrix is frozen, and two smaller learnable parameter matrices are used to perturb its output. Because these new parameters are kept separate from the original matrix, they can be trained and updated in a compute- and memory-efficient way.

The effective weight matrix can be decomposed like this:

 ┌────────────────┐       ┌─────┐                      
 │                │       │     │                      
 │                │       │  A: │   ┌────────────────┐
 │    W: d*d      │   +   │ d*r │ * │     B: r*d     │
 │                │       │     │   └────────────────┘
 │                │       │     │                      
 └────────────────┘       └─────┘                      

Here W is the original frozen weight matrix, A is a randomly-initialized matrix, and B is initialized to zero to ensure that the adapted model is equivalent to the original one at initialization.
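In plain array terms, the adapted layer computes the original output plus a low-rank correction, and because B starts at zero, that correction vanishes at initialization. Here’s a minimal sketch of the decomposition using raw JAX arrays (illustration only; the Penzai version we build below operates on named arrays instead):

# Minimal LoRA decomposition sketch with plain JAX arrays.
d, r = 8, 2
key_w, key_a, key_x = jax.random.split(jax.random.key(0), 3)
W = jax.random.normal(key_w, (d, d))   # frozen pretrained weight
A = jax.random.normal(key_a, (d, r))   # randomly-initialized adapter matrix
B = jnp.zeros((r, d))                  # adapter output matrix, initialized to zero

x = jax.random.normal(key_x, (d,))
adapted_out = x @ W + (x @ A) @ B      # equivalent to x @ (W + A @ B)
assert jnp.allclose(adapted_out, x @ W)  # identical to the original layer at initialization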

To enable LoRA, we’ll do three things for each linear layer in our model:

  • Freeze the original weight,

  • Initialize our low-rank matrices A and B,

  • And replace the original linear layer with a block that adds the output of the frozen W to the low-rank update computed by A and B.

Let’s try it out with a simple MLP like the one we built in the last section. We’ll just randomly initialize one for demonstration purposes; in a real LoRA adaptation setting we would generally load this from a pre-trained model checkpoint.

mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([2, 32, 32, 2]),
    jax.random.key(0),
)

Step 1: Freeze parameters#

We’ll start by freezing all the parameters. Learnable parameters are identifiable because they are instances of pz.nn.Parameter:

pz.select(mlp).at_instances_of(pz.nn.Parameter).show_selection()
pz.select(mlp).at_instances_of(pz.nn.Parameter).get_sequence()

We can freeze these parameters by replacing each Parameter with an equivalent FrozenParameter. This is directly tracked inside the structure of the model.

frozen_mlp = pz.select(mlp).at_instances_of(pz.nn.Parameter).apply(
    lambda param: pz.nn.FrozenParameter(param.value, param.name)
)
frozen_mlp
pz.select(frozen_mlp).at_instances_of(pz.nn.Parameter).get_sequence()

Step 2: Replace Linear layers with low-rank adapted versions#

Next, we’ll replace the Linear layers with implementations of LoRA.

In essence, a LoRA block is a sum of two computation paths: one that uses the original linear layer, and one that uses a sequence of two linear operations. This pattern can be directly mapped to one of Penzai’s simple built-in combinators, BranchAndAddTogether. We can take each linear layer, like this one:

frozen_mlp.sublayers[0].sublayers[0]

And replace it with a block like this:

pz.nn.BranchAndAddTogether([
    # The original layer with frozen parameters:
    pz.nn.NamedGroup("Pretrained", [
        frozen_mlp.sublayers[0].sublayers[0],
    ]),
    # And a low-rank adapter:
    pz.nn.NamedGroup("Update", [
        pz.nn.add_parameter_prefix(
            "LoRA-A",
            pz.nn.Linear.from_config(
                input_axes={"features": 8},
                output_axes={"lowrank": 2},
            ),
        ),
        pz.nn.add_parameter_prefix(
            "LoRA-B",
            pz.nn.Linear.from_config(
                input_axes={"lowrank": 2},
                output_axes={"features_out": 8},
                initializer=pz.nn.zero_initializer,
            ),
        ),
    ]),
])

Note that the above code is a direct translation of a LoRA block into the structure of our model. The matrices A and B are represented as separate Linear blocks inside the overall combinator, and the order of execution is determined by the positions in the NamedGroup.

To simplify the process of making this transformation at every Linear block, we can encapsulate it into a new Layer subclass. Since the computation can already be written as a combination of existing pieces, the idiomatic Penzai approach is to define our new Layer as a subclass of pz.nn.Sequential, so that it can be easily flattened (like we did with the MLP) if needed. Sequential already defines the necessary attributes and __call__ method, so we just need to provide a named initializer:

@pz.pytree_dataclass(has_implicitly_inherited_fields=True)
class LowRankAdapter(pz.nn.Sequential):

  @classmethod
  def from_linear(
      cls,
      linear: pz.nn.Linear,
      rank: int,
      name: str,
      lowrank_axis: str = "lowrank",
  ) -> 'LowRankAdapter':
    """Builds a LoRA layer from a Linear layer.

    Args:
      linear: The linear layer to adapt.
      rank: The rank of the low-rank adapter.
      name: Prefix for this block's parameters.
      lowrank_axis: The axis name for low-rank adaptation.

    Returns:
      A LoRA block with uninitialized parameters and the same initial
      behavior as `linear`.
    """
    return cls([
        pz.nn.BranchAndAddTogether([
            pz.nn.NamedGroup("Pretrained", [linear]),
            pz.nn.NamedGroup("Update", [
                pz.nn.add_parameter_prefix(
                    name + "/LoRA_A",
                    pz.nn.Linear.from_config(
                        input_axes=linear.input_axes,
                        output_axes={lowrank_axis: rank},
                        parallel_axes=linear.parallel_axes,
                    ),
                ),
                pz.nn.add_parameter_prefix(
                    name + "/LoRA_B",
                    pz.nn.Linear.from_config(
                        input_axes={lowrank_axis: rank},
                        output_axes=linear.output_axes,
                        parallel_axes=linear.parallel_axes,
                        initializer=pz.nn.zero_initializer,
                    ),
                ),
            ]),
        ])
    ])

Note: Idiomatic Penzai layers generally avoid overriding __init__, since dataclasses take their attributes as parameters to __init__ and we want to ensure the output of the pretty-printer directly corresponds to code we could use to rebuild the model even if we’ve modified its attributes. When we have nontrivial construction logic, we’ll usually define it in a class method like from_linear or from_config instead.

Also, in most Penzai layers, each layer is only responsible for ensuring its parameter names are locally unique, and parent layers add parameter prefixes using pz.nn.add_parameter_prefix at each level. In this case, however, we’re planning on inserting the LoRA blocks into an existing model, so the names must be globally unique. This is why from_linear takes a name as an argument but Linear.from_config does not.
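To see how these prefixes compose, here’s a small sketch (not from the original notebook) that builds one adapter branch with an explicit prefix, initializes it, and inspects the resulting parameter names. The exact rendered name format is an assumption; it should at least start with the prefix we supplied:

lora_a = pz.nn.initialize_parameters(
    pz.nn.add_parameter_prefix(
        "MyBlock/LoRA_A",
        pz.nn.Linear.from_config(
            input_axes={"features": 8},
            output_axes={"lowrank": 2},
        ),
    ),
    jax.random.key(0),
)
# List the parameter names; expect something like ['MyBlock/LoRA_A.weights']
# (the exact format may differ).
[param.name for param in pz.select(lora_a).at_instances_of(pz.nn.Parameter).get_sequence()]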

The next step is to write a helper function for inserting LoRA blocks into a model. We’ll use Penzai’s pretty_keystr function (a fancier version of jax.tree_util.keystr) to ensure each block has a unique name:

def loraify_all_linears(model, rank: int):
  return (
      pz.select(model)
      .at_instances_of(pz.nn.Linear)
      .apply(
          lambda keypath, lin: LowRankAdapter.from_linear(
              lin,
              rank=rank,
              name=pz.pretty_keystr(keypath, model),
          ),
          with_keypath=True,
      )
  )

Now we can run it on our MLP:

loraified_mlp_uninit = loraify_all_linears(frozen_mlp, rank=2)
loraified_mlp_uninit

You can directly check that this transformation is doing the right thing by expanding each Affine layer and making sure the LowRankAdapter looks right.

Note that loraified_mlp_uninit is a copy of frozen_mlp with the requested modifications. In Penzai, transformations of models always return new copies of the model, so you don’t have to worry about accidentally making an irreversible change.

(Only the model structure is copied; the JAX arrays still share memory between the models. But JAX arrays are immutable, so you don’t have to worry about those changing either; unless you explicitly delete or donate them, training loops always return new copies of your model with updated parameters.)

Step 3: Initializing and training the LoRA weights#

Finally, we can initialize and train the new weights we inserted into the model. To initialize them, we can use the standard Penzai parameter initialization helper function, which finds all UninitializedParameter instances and initializes them. In this case, the UninitializedParameters are the LoRA weights, and the FrozenParameters from the pretrained model are ignored.

%%autovisualize
loraified_mlp = pz.nn.initialize_parameters(loraified_mlp_uninit, jax.random.key(42))
loraified_mlp

Since we froze the “pretrained” parameters before we applied loraify_all_linears, the learnable parameters of our new model are just the new LoRA weights:

pz.select(loraified_mlp).at_instances_of(pz.nn.Parameter).get_sequence()

This means you can easily train them using Penzai’s basic training loop helpers, or write your own custom training loop for them. As a demonstration, we’ll train this model to implement XOR by only fitting the low-rank adapter parameters.

xor_inputs = pz.nx.wrap(
    jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32),
    "batch",
    "features",
)
xor_labels = jnp.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32)

def loss_fn(model, rng, state):
  assert state is None
  model_out = model(xor_inputs)
  log_probs = jax.nn.log_softmax(
      model_out.unwrap("batch", "features"), axis=-1
  )
  losses = -log_probs * xor_labels
  loss = jnp.sum(losses) / 4
  return loss, None, {"loss": loss}
train_step = basic_training.build_train_step_fn(loss_fn)
train_state = basic_training.TrainState.initial_state(
    model=loraified_mlp,
    optimizer_def=optax.adam(0.1),
    root_rng=jax.random.key(42),
)
train_state
for i in range(20):
  train_state, out = train_step(train_state)
  print(i, out)
train_state

TrainState is an optional utility that manages the optimizer states for us. It also partitions the model into learnable and nonlearnable parts, but we can combine them again by reading the computed property train_state.model:

%%autovisualize
train_state.model(xor_inputs)
%%autovisualize
pz.nx.nmap(jnp.argmax)(train_state.model(xor_inputs).untag("features"))

Looks like it worked!
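As a side note, because the only pz.nn.Parameter instances left in the model are the LoRA weights, it’s easy to pull them out on their own, for example if you wanted to save just the adapter. Here’s a small sketch (not part of the original notebook):

# Collect the trained LoRA parameters as a name -> NamedArray dictionary.
# The pretrained FrozenParameters are not selected here, matching the
# selection behavior we saw above.
lora_params = {
    param.name: param.value
    for param in pz.select(train_state.model)
    .at_instances_of(pz.nn.Parameter)
    .get_sequence()
}
sorted(lora_params.keys())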

Adding LoRA to Gemma#

Let’s now try adding LoRA to the Gemma 7B pretrained model. Because of Penzai’s compositional design, the implementation in the previous section will just work out of the box!

Loading Gemma#

We’ll start by loading the weights from the Gemma checkpoint. We’ll use the 7B checkpoint for this tutorial, and shard it over our local devices using JAX’s automatic partitioning. (You can read more about JAX’s automatic distributed arrays on this JAX documentation page.)

If you prefer, you can also run this tutorial with the 2B checkpoint.

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don’t have an API key already, you can:

  1. Visit https://www.kaggle.com/ and create an account if needed.

  2. Go to your account settings, then the ‘API’ section.

  3. Click ‘Create new token’ to download your key.

Next, if you are running this notebook in Google Colab:

  1. Click the “key” symbol on the left toolbar to open the “Secrets” tab.

  2. Add two new secrets, named “KAGGLE_USERNAME” and “KAGGLE_KEY”, and set their values based on the API key you downloaded.

  3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

import kagglehub
try:
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  kagglehub.login()

If everything went well, you should see:

Kaggle credentials set.

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven’t done that yet, you can do so here:

https://www.kaggle.com/models/google/gemma/license/consent

(Make sure you choose to “Verify via Kaggle Account” with the same account you used to log in above!)

Once you’ve agreed to the terms, you can run the next cell to download the Gemma weights:

weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using orbax:

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
gemma_model = gemma.model_core.GemmaTransformer.from_pretrained(
    flat_params,
    upcast_activations_to_float32=True,
)
del flat_params
gc.collect()

Here’s what the Gemma model looks like:

%%autovisualize
gemma_model

Try clicking the triangle markers to explore the structure of Gemma and look at some of the parameters!

Converting Gemma#

Now we can freeze its parameters and LoRA-ify its linear blocks in the same way that we did for the simple MLP.

The Penzai implementation of Gemma uses the same Linear layer class to implement all of the learnable weight matrices, in both the MLP blocks and the attention blocks. So we’ll use a slightly modified helper function that lets us be more specific about which Linear layers we want to replace.

def loraify_linears_in_selection(selection, rank: int):
  model = selection.deselect()
  return (
      selection
      .at_instances_of(pz.nn.Linear)
      .apply(
          lambda keypath, lin: LowRankAdapter.from_linear(
              lin,
              rank=rank,
              name=pz.pretty_keystr(keypath, model),
          ),
          with_keypath=True,
      )
  )

Now we go through and apply each of the transformation steps:

# Step 1: Freeze the pretrained parameters.
frozen_gemma_model = (
    pz.select(gemma_model)
    .at_instances_of(pz.nn.Parameter)
    .apply(
        lambda param: pz.nn.FrozenParameter(param.value, param.name)
    )
)
# Step 2: LoRA-ify the Linear blocks. Following Hu et al. (2021), we'll only
# LoRA-ify the attention parameters.
loraified_gemma_model_uninit = loraify_linears_in_selection(
    pz.select(frozen_gemma_model).at_instances_of(gemma.model_core.GemmaAttention),
    rank=16,
)
# Step 3: Initialize the new LoRA parameters.
loraified_gemma_model = pz.nn.initialize_parameters(
    loraified_gemma_model_uninit, jax.random.key(123)
)
# Step 4 (optional): Look at it to make sure the transformation looks right.
pz.select(loraified_gemma_model).at_instances_of(LowRankAdapter).show_selection()

If we wanted, we could have just as easily adapted the MLP layers, by changing

.at_instances_of(gemma.model_core.GemmaAttention)

to

.at_instances_of(gemma.model_core.GemmaFeedForward)

We could have also customized which blocks have LoRA parameters by using Penzai’s selector system (see the separate selectors tutorial for more details).

Fine-tuning Gemma with LoRA#

We can now fine-tune our LoRA-ified Gemma model, with full control over the training loop.

For this tutorial, we’ll just generate some synthetic data. Specifically, we’ll show the model some examples of evaluating a mysterious function, and train it to figure out what the function does. We won’t worry too much about the efficiency of the data pipeline, since our goal is just to show how LoRA fine-tuning could work.

def mystery_function(a, b):
  return a + b
def generate_example(np_rng):
  a, b = np_rng.choice(1000, size=(2,))
  c = mystery_function(a, b)
  return f">>> mystery_function({a}, {b})\n{c}"
def tokenize_batch(examples, pad_length=32, include_eos=True):
  padded_tokens = []
  for example in examples:
    example_tokens = [vocab.bos_id()] + vocab.EncodeAsIds(example)
    if include_eos:
      example_tokens = example_tokens + [vocab.eos_id()]
    assert len(example_tokens) <= pad_length
    example_tokens = example_tokens + [vocab.pad_id()] * (pad_length - len(example_tokens))
    padded_tokens.append(example_tokens)
  return pz.nx.wrap(jnp.array(padded_tokens)).tag("batch", "seq")

Penzai has some useful utilities for visualizing token arrays:

%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
np_rng = np.random.default_rng(123)
input_examples = tokenize_batch([generate_example(np_rng) for _ in range(20)])
input_examples
token_visualization.show_token_array(input_examples, vocab)

Let’s train our new parameters on this data:

def xent_loss_fn(model, rng, state, input_examples):
  del rng, state  # Unused.
  # Run the model on shifted examples.
  # `GemmaInputs.from_basic_segments` is responsible for building the causal
  # attention mask and setting up positional embeddings.
  outputs = model(gemma.model_core.GemmaInputs.from_basic_segments(
      input_examples[{"seq": pz.slice[:-1]}]
  ))
  # Compute log-probabilities along the "vocabulary" axis.
  all_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      outputs.untag("vocabulary")
  ).tag("vocabulary")
  # Index by the correct tokens.
  correct_next_tokens = input_examples[{"seq": pz.slice[1:]}]
  correct_log_probs = all_log_probs[{"vocabulary": correct_next_tokens}]
  # Mask padding tokens.
  correct_log_probs = pz.nx.nmap(jnp.where)(
      correct_next_tokens == vocab.pad_id(),
      0.0,
      correct_log_probs,
  )
  # Take averages.
  loss = -correct_log_probs.untag("batch", "seq").unwrap().mean()
  return loss, None, {"loss": loss}
train_step = basic_training.build_train_step_fn(xent_loss_fn, donate_params_and_state=True)
train_state = basic_training.TrainState.initial_state(
    model=loraified_gemma_model,
    optimizer_def=optax.adamw(5e-5, weight_decay=0.01),
    root_rng=jax.random.key(42),
)
np_rng = np.random.default_rng(123)
# Train on 200 batches of 16 examples -> 3,200 examples
# (For reference, there are 1000 * 1000 = 1,000,000 possible examples in the
# synthetic distribution we are using.)
print_steps = {*range(10), *range(10, 200, 10)}
while train_state.step < 200:
  input_examples = tokenize_batch([
      generate_example(np_rng) for _ in range(16)
  ])
  train_state, out = train_step(train_state, input_examples=input_examples)
  if train_state.step in print_steps:
    print(train_state.step, out)

To see what the model learned, we can pull out the model from the train state and look at its parameters. In this case, all of the learnable parameters were added by our LoRA adapter.

We’ll turn on the autovisualizer so that we can see the distribution of values in the arrays at a glance; try clicking on a few to expand their visualizations.

%%autovisualize
pz.select(train_state.model).at_instances_of(pz.nn.Parameter).get_sequence()

Recall that we initialized all of the “B” matrices to zero. So the fact that they are no longer zero indicates that the model has definitely learned something!

But has it learned what we wanted? Let’s try running it on a randomly sampled batch of examples.

%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
np_rng = np.random.default_rng(98765)
validation_examples = tokenize_batch([generate_example(np_rng) for _ in range(32)])
validation_examples
token_visualization.show_token_array(validation_examples, vocab)
outputs = train_state.model(gemma.model_core.GemmaInputs.from_basic_segments(
    validation_examples[{"seq": pz.slice[:-1]}]
))
# Compute log-probabilities along the "vocabulary" axis.
all_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
    outputs.untag("vocabulary")
).tag("vocabulary")

# Index by the correct tokens.
correct_next_tokens = validation_examples[{"seq": pz.slice[1:]}]
correct_log_probs = all_log_probs[{"vocabulary": correct_next_tokens}]

# Plot the probability of the correct digit.
# This uses the same renderer as %%autovisualize, but doesn't truncate the array
# and lets us mask out elements.
pz.ts.render_array(
    pz.nx.nmap(jnp.exp)(correct_log_probs),
    valid_mask=(correct_next_tokens != vocab.pad_id()),
)

We can see that the model is predicting the arguments to mystery_function with about 10% accuracy, which is reasonable because those digits are random. It also seems to be almost perfectly accurate on the answers, indicating that it has successfully fit the distribution.

Running inference on our LoRA-ified model#

Now that we’ve fine-tuned the model, we can convert it into decoding mode and sample from it.

In Penzai, autoregressive decoding is performed by a separate class GemmaKVCachingTransformer, instead of being an alternative mode of GemmaTransformer. This is an instance of a more general pattern in Penzai models: each model and layer does a single thing at runtime, instead of doing different things depending on what arguments you pass. In fact, idiomatic Penzai layers always define a single function __call__, and that function always takes a single argument (although that argument can be a dictionary or tuple if needed). This makes it easy to compose many layers together in a uniform way without having to worry about how to handle function arguments.
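For instance, here’s a toy illustration (not part of the Gemma implementation) of a layer that needs two pieces of input and therefore takes them as one tuple:

@pz.pytree_dataclass
class ScaleBy(pz.Layer):
  """Toy layer illustrating the single-argument __call__ convention."""

  def __call__(self, scale_and_value):
    # The single argument is a (scale, value) tuple.
    scale, value = scale_and_value
    return scale * value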

The decoding mode transformation is actually very similar to the LoRA adaptation transformation we defined above. Instead of replacing Linear blocks with new LowRankAdapter blocks (which have new parameters), this transformation replaces Attention blocks with KVCachingAttention blocks (which have new state variables).

Since the key-value caching for Gemma is itself implemented as a patching transformation, this means that key-value caching can be immediately applied to our final train_state.model even though we’ve already edited the model structure to add new adapted parameters. Our modifications don’t conflict with the attention block structure, so the modifications can be easily composed.

Here’s how we can enable decoding mode:

finetuned_inference_model, initial_inference_state = (
  gemma.sampling_mode.GemmaKVCachingTransformer.from_uncached(
      train_state.model,
      cache_len=64,
      batch_axes={"batch": 4},
  )
)

Let’s look inside to see the changes:

# You can use a function to pick out an initial node to expand in the
# visualization. (You can also copy such a function by clicking the grey copy
# icon at the end of each line.)
pz.select(finetuned_inference_model).at(
    (lambda root: root.body.body.body.body.body.sublayers[5].sublayers[0].delta.sublayers[1].kv_cache)
).show_value()

The LowRankAdapter blocks we inserted are still there in the model, but there have been a few other changes to the model structure:

  • The outermost class is now of a different type, GemmaKVCachingTransformer.

  • Inside it, there’s a new WithFunctionalLocalState wrapper, which is responsible for managing the key-value caches, and a new WithSideInputsFromInputTuple wrapper that manages the current decoding position.

  • Inside each of the transformer blocks, the GemmaAttention layers have been replaced with new GemmaKVCachingAttention layers that point back to these two new wrappers.

Now that we’ve converted the model, we can use some existing helper functions to sample from it. (We discuss the decoding mode and these helper functions in more detail in the separate “Gemma From Scratch” tutorial.) We’ll wrap our model in Jitted so that it JIT-compiles itself whenever it is called.

prompts = [
    ">>> mystery_function(123, 123)",
    ">>> mystery_function(101, 15)",
    ">>> mystery_function(999, 876)",
    ">>>", # Let the model write and solve its own problem
]
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
tokenized_prompts = tokenize_batch(prompts, 16, include_eos=False)
tokenized_prompts
samples = gemma.simple_decoding_loop.temperature_sample_pyloop(
    jit_wrapper.Jitted(finetuned_inference_model),
    initial_inference_state,
    prompt=tokenized_prompts,
    rng=jax.random.key(3),
    pad_id=vocab.pad_id(),
    max_sampling_steps=20,
)
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
pz.show(samples)
token_visualization.show_token_array(samples, vocab)

As desired, our fine-tuned model seems to have learned the behavior of mystery_function using the low-rank updates to its parameters.

Recap#

This notebook demonstrates how Penzai makes it easy to edit the structure of a pretrained model without requiring any changes to the original model implementation. Our LowRankAdapter class and associated utilities took less than a hundred lines of code, and were immediately compatible with the pretrained Gemma 7B model, including both training and sampling modes.

The definitions in this notebook are also available in penzai.toolshed.lora, and can be imported from there if you are interested in using Penzai to perform parameter-efficient fine-tuning.
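If you do use the toolshed version, usage should look roughly like the following. This is a hypothetical sketch: it assumes penzai.toolshed.lora exposes a loraify_linears_in_selection helper matching the one we defined above, so check the module itself for the exact names and signatures.

# Hypothetical usage sketch; the function name and signature are assumptions
# based on the helper defined in this notebook.
from penzai.toolshed import lora

loraified_from_toolshed = pz.nn.initialize_parameters(
    lora.loraify_linears_in_selection(
        pz.select(frozen_gemma_model).at_instances_of(
            gemma.model_core.GemmaAttention
        ),
        rank=16,
    ),
    jax.random.key(123),
)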

However, LoRA is just one example of what you can do with Penzai’s powerful patching and model rewriting utilities. The key-value caching transformation is another, which we discuss in the “Gemma From Scratch” tutorial. And these tools can also be used to study intermediate activations and perform targeted counterfactual interventions to specific layers in the model, which we discuss in the “Induction Heads” tutorial. Penzai is designed to simplify the general process of editing, visualizing, and analyzing pretrained models; the goal is not to implement every possible type of fine-tuning or patching, but instead to give you powerful general-purpose tools and then get out of your way.