Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



LoRA From Scratch - Patching Pretrained Models in Penzai

Penzai is designed to make it easy to make targeted modifications to neural networks after they have been trained. In this notebook, we’ll show how to take Penzai’s reference implementation of the Gemma 7B open-weights transformer model, patch it to support Low-Rank Adaptation (LoRA; Hu et al. 2021), and train the new parameters on a toy problem with a hand-written loss function.

The goal of this notebook is to show how you could implement something like LoRA from scratch in fewer than a hundred lines of code. We start from a Penzai implementation of a model that doesn’t already support it, and we never fork the existing implementation’s source code or even modify the pretrained model’s configuration. We’ll define everything we need as we go and make changes to models interactively. In fact, our implementation will end up being completely modular: we’ll start by applying LoRA to a small MLP and then immediately be able to transfer our implementation to Gemma 7B.

Let’s get started!

Setup

Before we can get started in earnest, we need to set up the environment.

Imports

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import os
import gc
import collections

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import optax
from jax.experimental import mesh_utils
import penzai
from penzai import pz
import sentencepiece as spm
from penzai.example_models import gemma
from penzai.example_models import simple_mlp
from penzai.toolshed import basic_training
from penzai.toolshed import token_visualization
from penzai.toolshed import jit_wrapper

Setting up Penzai

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment.

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()

Intro to Penzai’s declarative combinator design

We’ll start by giving a brief introduction to Penzai’s design conventions, and how they make it easy to insert adapters into pretrained models. Let’s begin by initializing a small MLP:

mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([8, 32, 32, 8]),
    jax.random.key(0),
)

Like most Penzai models and layers, this MLP takes named arrays as input and returns them as output. A named array is just a wrapped JAX array where a subset of its positional axes have been tagged with names. (See the named axes tutorial for more info on how to use Penzai’s named axis system.)

We can call the MLP directly on an array of inputs to run it:

%%autovisualize
mlp(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

Penzai models are written in a declarative, combinator-based style. This means that the structure of the model directly matches the sequence of high-level operations that the model will run in its forward pass. Composite models, like our MLP, just hold onto their sublayers in a list and run these sublayers in order. Primitive layers, like Linear, hold on to their parameters as attributes instead of reading them from an external parameter dictionary.
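The following minimal plain-Python sketch illustrates the combinator idea outside of Penzai. The classes ToyLinear, ToyRelu, and ToySequential are hypothetical stand-ins invented for this illustration and are not Penzai’s classes. The composite layer is just a tuple of sublayers run in order, and the primitive layer stores its weights as an attribute.

```python
from dataclasses import dataclass

import numpy as np


# Hypothetical toy classes (not Penzai's API) mirroring the combinator style.
@dataclass(frozen=True)
class ToyLinear:
  weights: np.ndarray  # shape (in_features, out_features), stored on the layer

  def __call__(self, x: np.ndarray) -> np.ndarray:
    return x @ self.weights


@dataclass(frozen=True)
class ToyRelu:
  def __call__(self, x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


@dataclass(frozen=True)
class ToySequential:
  sublayers: tuple  # run in order: this tuple *is* the forward pass

  def __call__(self, x):
    for layer in self.sublayers:
      x = layer(x)
    return x


rng = np.random.default_rng(0)
toy_mlp = ToySequential((
    ToyLinear(rng.normal(size=(8, 32))),
    ToyRelu(),
    ToyLinear(rng.normal(size=(32, 8))),
))
out = toy_mlp(np.arange(8, dtype=np.float32))
print(out.shape)  # (8,)
```

The forward pass is nothing more than the sublayer tuple, so inspecting or editing that tuple is the same as inspecting or editing the computation.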

We can see the sublayers by pretty-printing the model:

%%autovisualize
mlp

By convention, most of the “complicated” logic in Penzai model classes happens when we initialize them, using the .from_config method we called earlier. Once the model is built, the pretty-printed representation provides a full specification of everything the model does, and the parameters are stored as direct attributes on the layers that need them. A general design principle of Penzai is “what you see is what you get”; you should be able to learn everything you need to know about a model by printing it out.

In fact, you can click on a pretty-printed output, press r to add qualified names to the pretty-printed visualization (try it above!), and then copy and paste the entire pretty-printed code and execute it to make a copy of the model:

penzai.example_models.simple_mlp.MLP(
  sublayers=[
    penzai.nn.linear_and_affine.Affine(
      sublayers=[
        penzai.nn.linear_and_affine.LinearInPlace(
          sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8, 'features_out': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([[ 0.04725012, -0.23066601, -0.33130226,  0.14961348,  0.0566796 ,\n        -0.00751991,  0.38097695, -0.19894585,  0.10838961,  0.03168617,\n         0.18273951, -0.19517657,  0.10663664,  0.03312971, -0.18675321,\n        -0.18642329,  0.1921328 ,  0.10855018, -0.02933162,  0.28482816,\n         0.2975921 , -0.02662617,  0.2734555 , -0.19869885,  0.24985315,\n         0.05980473,  0.20454031, -0.14542156, -0.12331318, -0.09077168,\n        -0.08418605, -0.11115803],\n       [-0.06492946, -0.11429147,  0.30805257, -0.32192293,  0.3667109 ,\n         0.27841014, -0.24174255, -0.24910982, -0.27875087,  0.35143396,\n         0.07689688, -0.01147286, -0.3262384 ,  0.02630419, -0.3234371 ,\n        -0.27148998, -0.3160725 , -0.36660823,  0.34744027,  0.12888198,\n         0.23805176, -0.24385388,  0.2650048 ,  0.03819488, -0.00568975,\n        -0.0215622 ,  0.2193515 , -0.3253995 ,  0.09929291, -0.29280028,\n        -0.26032737, -0.20181446],\n       [ 0.01659251,  0.36555287, -0.38305447,  0.18675229,  0.03717915,\n         0.12696777, -0.25671378,  0.17499834, -0.22098203, -0.15922017,\n         0.2351897 ,  0.32522848,  0.07117841,  0.33022884, -0.06571785,\n         0.13955157, -0.0600304 , -0.19759853,  0.3037875 ,  0.30708078,\n        -0.07626879,  0.35707024,  0.19575489,  0.15175632,  0.10577198,\n         0.00364989,  0.18226433,  0.08367107,  0.29136872, -0.18551245,\n        -0.14171803, -0.08811028],\n       [ 0.38122043,  0.23902583,  0.17216133,  0.06905378,  0.16787706,\n        -0.08580669, -0.22893907, -0.34190187, -0.28671974,  0.23940544,\n        -0.16707611, -0.08652407, -0.23550698,  0.1630661 ,  0.11646476,\n         0.03909842, -0.14891088, -0.21910995,  0.18650232, -0.03763161,\n         0.26505992, -0.37263677,  0.01191313, -0.24201725, -0.04657881,\n        -0.36401856,  0.3515382 ,  0.2873264 ,  0.08209115,  0.14519285,\n         0.31111813, -0.00785317],\n       [-0.08459742,  0.25045004, -0.28495365,  0.33237663, -0.11983901,\n         0.27122143,  0.36026204, -0.24545845,  0.10550512,  0.30733794,\n         0.11679422,  0.21981671,  0.1890508 , -0.36632678,  0.1600768 ,\n         0.09827441,  0.08301278, -0.2713667 ,  0.36596537, -0.19401632,\n         0.08262127, -0.32878917,  0.33614868,  0.0187863 , -0.08504166,\n         0.20853564, -0.37342063,  0.04217026,  0.15900622,  0.05655929,\n        -0.3844944 , -0.24541588],\n       [ 0.29031137, -0.3365837 , -0.14176023, -0.15418899, -0.27481785,\n        -0.2589707 , -0.0712328 , -0.03633535, -0.02502658, -0.1892304 ,\n         0.05823506,  0.06479566,  0.3832009 ,  0.08127026, -0.08600669,\n        -0.1306198 ,  0.3590402 , -0.26903337, -0.16024597, -0.28321084,\n        -0.22990172, -0.3255212 ,  0.04306457, -0.15599836,  0.13166516,\n         0.18247025, -0.35739547,  0.23415357,  0.27890775, -0.04992555,\n        -0.14944552, -0.31249046],\n       [ 0.36270365, -0.35655767, -0.11624786,  0.07357443, -0.20815234,\n        -0.32495487, -0.32694384,  0.11253693,  0.34765404,  0.3648101 ,\n        -0.24689904,  0.06159897, -0.1283295 ,  0.19177036, -0.13034305,\n        -0.04315192, -0.18603453,  0.2254531 ,  0.12818351, -0.01486595,\n        -0.11669718, -0.27403483,  0.14894357, -0.09331072, -0.26205596,\n         0.18480024,  0.20129958, -0.13770516,  0.37675482, -0.31014258,\n        -0.37200913,  0.22289428],\n       [ 0.16453928, -0.21428025, -0.05870866,  0.19057752, -0.31916317,\n         0.0421687 ,  0.12571178, -0.12293893,  0.08633191, -0.38240284,\n         0.38143364, -0.3229353 ,  0.22557646,  0.05067461, -0.20315255,\n         0.24650854, -0.36949688, -0.01077681,  0.14819378, -0.22307158,\n         0.205777  ,  0.23953518,  0.1715395 , -0.23309591,  0.10060634,\n        -0.34637597,  0.11265402, -0.18391618, -0.03978348,  0.01059979,\n         0.09037849,  0.05671118]], dtype=float32)', original_id=20251670234640, original_type=jax.Array)), name='Affine_0.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',))],
        ),
        penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)', original_id=20251431144464, original_type=jax.Array)), name='Affine_0.AddBias.bias'), new_axis_names=()),
      ],
    ),
    penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.nn.linear_and_affine.Affine(
      sublayers=[penzai.nn.linear_and_affine.LinearInPlace(sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([[-0.14293982, -0.12958716,  0.1908599 , ...,  0.05771805,\n         0.20101726, -0.21264127],\n       [ 0.29049423,  0.04062944, -0.29115948, ..., -0.29757395,\n         0.25379792, -0.17268081],\n       [ 0.1237077 , -0.01195317,  0.26793247, ...,  0.09746177,\n         0.18667588, -0.2871105 ],\n       ...,\n       [-0.30499855,  0.21407223,  0.12929553, ...,  0.2603329 ,\n         0.12992881, -0.02856655],\n       [-0.13158074, -0.10787383,  0.27249086, ...,  0.2223304 ,\n         0.00778098,  0.10102441],\n       [-0.24134454, -0.0659571 ,  0.09705041, ...,  0.20444332,\n        -0.01505233, -0.09814899]], dtype=float32)', original_id=20251416478480, original_type=jax.Array)), name='Affine_1.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',))]), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)', original_id=20251219338000, original_type=jax.Array)), name='Affine_1.AddBias.bias'), new_axis_names=())],
    ),
    penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.nn.linear_and_affine.Affine(
      sublayers=[penzai.nn.linear_and_affine.LinearInPlace(sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 8}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([[-0.02120503, -0.26273215,  0.19671364, -0.25535676,  0.33353078,\n         0.34760946,  0.11154714, -0.36327165],\n       [-0.18882798,  0.29033083, -0.01784066, -0.35918185, -0.1183674 ,\n        -0.09689117, -0.22195095,  0.17173886],\n       [-0.17731062, -0.30207223, -0.25690556,  0.13724743,  0.18264246,\n         0.14667968, -0.09792759,  0.28671548],\n       [ 0.13883723, -0.36321005,  0.3602296 ,  0.29703382, -0.14947332,\n         0.38648838,  0.33616605, -0.2456858 ],\n       [ 0.29338026,  0.29596722, -0.26344863,  0.19266097,  0.01534889,\n        -0.21687248, -0.37927428, -0.15640356],\n       [-0.26717016,  0.19525719, -0.11175556, -0.21094975, -0.36075762,\n        -0.06405002,  0.28700718,  0.02164715],\n       [ 0.12978551, -0.11752943, -0.12858906,  0.21853136,  0.09607212,\n         0.31332853, -0.01224952, -0.28824407],\n       [-0.10415927, -0.0974042 ,  0.05993234, -0.3562885 ,  0.23716317,\n        -0.37986112, -0.24494588, -0.29328707],\n       [ 0.04717422, -0.04239095,  0.27582565, -0.37281165,  0.00247967,\n        -0.23975734, -0.0299767 ,  0.16125403],\n       [-0.07960048,  0.3762669 , -0.36568752,  0.23300837,  0.10682021,\n         0.17286457, -0.3787364 ,  0.24803638],\n       [-0.3089405 ,  0.10941845,  0.03270726, -0.3604047 ,  0.05435884,\n         0.07881735, -0.22836307,  0.33140016],\n       [-0.20255178, -0.37571028,  0.20740984, -0.02714456,  0.08295553,\n        -0.33837444,  0.26255652,  0.36039695],\n       [ 0.02880464,  0.1907816 ,  0.19101946,  0.05349049,  0.13013566,\n         0.33623788,  0.36152375,  0.26223242],\n       [-0.12129059, -0.05046343, -0.22875458, -0.22711188, -0.28795338,\n        -0.0417299 , -0.12364311, -0.20053242],\n       [ 0.15858358, -0.21735588,  0.2747909 ,  0.00585236, -0.2112107 ,\n        -0.26138678,  0.21196207,  0.13825993],\n       [ 0.02330455,  0.22529002,  0.11926983, -0.02966321,  0.14428146,\n         0.11442037, -0.33265772,  0.23587993],\n       [-0.37871405,  0.28270426, -0.21009018,  0.06479418, -0.36493284,\n        -0.3663771 , -0.3676921 , -0.17661622],\n       [ 0.06361667,  0.21918733,  0.24235171, -0.06552726,  0.27533627,\n         0.2639572 , -0.02761605,  0.34604597],\n       [-0.34795552, -0.02057168, -0.00230321, -0.20056584, -0.29356706,\n         0.37752953, -0.34542397,  0.35203296],\n       [-0.07879667, -0.03717786,  0.01729226, -0.15346181, -0.36752957,\n        -0.23917256, -0.3446656 ,  0.19485265],\n       [ 0.2728985 ,  0.24536206, -0.24228264, -0.06692472,  0.34564835,\n        -0.10513262,  0.2722488 , -0.3408531 ],\n       [-0.2712835 ,  0.15825975,  0.32759133, -0.20761152,  0.06780361,\n         0.20713052,  0.02732306, -0.15797636],\n       [ 0.02223683,  0.10032009, -0.05367341,  0.20788151, -0.08944707,\n        -0.37642628,  0.16834743,  0.17841148],\n       [ 0.03506837, -0.37967107,  0.06165724,  0.16286056, -0.27745268,\n        -0.27329233, -0.12189328,  0.36668006],\n       [-0.37223703, -0.00192019, -0.03584365,  0.31333244,  0.12743935,\n        -0.27959207, -0.18257874, -0.3266675 ],\n       [-0.17420247, -0.18477465,  0.1551072 , -0.03858141,  0.3605494 ,\n         0.16175184, -0.10405576, -0.25369927],\n       [ 0.19931318,  0.1297951 , -0.28167292, -0.04960836,  0.22335662,\n         0.14375217,  0.11561006,  0.14631467],\n       [ 0.3007057 , -0.30685687,  0.12430204, -0.33081028,  0.07697242,\n         0.00636965,  0.1470099 ,  0.33340123],\n       [-0.26456222,  0.19320115, -0.36342096, -0.36083212, -0.3115429 ,\n        -0.3427212 , -0.10350902, -0.35717598],\n       [ 0.12212099,  0.27373436,  0.03613165,  0.30182862,  0.12713408,\n         0.26485837, -0.27315864,  0.29892513],\n       [ 0.07317165,  0.26054114,  0.2656967 ,  0.05795342,  0.34420094,\n         0.15277314,  0.255256  ,  0.3382493 ],\n       [-0.09600905,  0.3696218 ,  0.36078495, -0.3213724 , -0.1559619 ,\n         0.1136618 , -0.11645275, -0.38116798]], dtype=float32)', original_id=20251403998480, original_type=jax.Array)), name='Affine_2.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',))]), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)', original_id=20251670240016, original_type=jax.Array)), name='Affine_2.AddBias.bias'), new_axis_names=())],
    ),
  ],
)

We won’t usually do this in practice, because device arrays can’t be copy-pasted this way; the parameters will be replaced with placeholder objects. Instead, Penzai provides a sophisticated selector system (pz.select) that allows us to make targeted modifications to (copies of) models. The point here is that Penzai model objects aren’t “hiding” anything; they directly expose the structure of their computation as a data structure that can be manipulated.

The specific types of Penzai models and composite layers are provided primarily for ease of manipulation and as a way to identify how each part of your model was built. But you can also “flatten” a model into a list of sublayers that run in sequence, discarding this extra information:

pz.nn.inline_groups(mlp, lambda _: True, lambda _: True)
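Flattening is safe because running a group of layers in sequence is the same as running its members inline. Below is a minimal sketch of that idea. It uses hypothetical Group and Leaf classes, which are not Penzai’s API, and their labels mirror the printed MLP structure above.

```python
from dataclasses import dataclass


# Hypothetical stand-ins (not Penzai classes): a named group of sublayers,
# and a leaf layer that just records its label.
@dataclass(frozen=True)
class Group:
  name: str
  sublayers: tuple


@dataclass(frozen=True)
class Leaf:
  label: str


def flatten(layer) -> list:
  """Recursively inlines every Group, keeping only leaf layers in order."""
  if isinstance(layer, Group):
    out = []
    for sub in layer.sublayers:
      out.extend(flatten(sub))
    return out
  return [layer]


def affine(name):
  return Group(name, (
      Group("LinearInPlace", (Leaf("Linear"), Leaf("RenameAxes"))),
      Leaf("AddBias"),
  ))


model = Group("MLP", (affine("Affine_0"), Leaf("relu"), affine("Affine_1")))
print([leaf.label for leaf in flatten(model)])
# ['Linear', 'RenameAxes', 'AddBias', 'relu', 'Linear', 'RenameAxes', 'AddBias']
```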

And you can freely add new logic as well, even if it wasn’t configured in the initial model. For instance, here’s how you could insert a new layer that prints out its intermediate activation:

@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(pz.Layer):  # <- pz.Layer is the base class of Penzai layers.

  def __call__(self, intermediate_value):
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value
patched = (
    pz.select(mlp)
    .at(lambda model: model.sublayers[2])
    .insert_after(DisplayIntermediateValue())
)
pz.select(patched).at_instances_of(DisplayIntermediateValue).show_selection()

patched is a copy of our model that includes our new layer, and it will run our new logic when the model is called:

%%autovisualize
patched(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

This ability makes it remarkably easy to implement adapters like LoRA!

Building a simple LoRA Layer in Penzai

Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning strategy that augments each linear operation in the model with a decomposed low-rank adapter. The original weight matrix is frozen, and two smaller learnable matrices are used to perturb its output. Because these parameters are kept separate from the original matrix, only they need gradients and optimizer state, which makes training compute- and memory-efficient.

The effective weight matrix can be decomposed like this:

 ┌────────────────┐       ┌─────┐                      
 │                │       │     │                      
 │                │       │  A: │   ┌────────────────┐
 │    W: d*d      │   +   │ d*r │ * │     B: r*d     │
 │                │       │     │   └────────────────┘
 │                │       │     │                      
 └────────────────┘       └─────┘                      

Here W is the original frozen weight matrix, A is a randomly-initialized matrix, and B is initialized to zero to ensure that the adapted model is equivalent to the original one at initialization.
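In symbols, with r the adapter rank, the diagram reads as follows. The second line is why zero-initializing B preserves the original behavior. The last line works through the parameter savings for illustrative sizes d = 4096 and r = 8, which were chosen only for the arithmetic and are not taken from this notebook.

```latex
W_{\text{eff}} = W + AB, \qquad W \in \mathbb{R}^{d \times d},\; A \in \mathbb{R}^{d \times r},\; B \in \mathbb{R}^{r \times d},\; \operatorname{rank}(AB) \le r
B = 0 \;\Longrightarrow\; AB = 0 \;\Longrightarrow\; W_{\text{eff}} = W \quad \text{(at initialization)}
\underbrace{d^2}_{\text{full fine-tuning}} \quad \text{vs.} \quad \underbrace{dr + rd = 2dr}_{\text{LoRA}}
d = 4096,\; r = 8: \quad d^2 = 16{,}777{,}216, \qquad 2dr = 65{,}536, \qquad \frac{d^2}{2dr} = \frac{d}{2r} = 256
```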

To enable LoRA, we’ll do three things for each linear layer in our model:

  • Freeze the original weight,

  • Initialize our low-rank matrices A and B,

  • And replace the original linear layer with the composition of W, A, and B.
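The three steps can be sketched for a single linear map with plain NumPy. This is an illustration under assumed sizes (d_in = d_out = 8, rank 2) and an assumed scaled-normal initializer for A, not the Penzai implementation. “Freezing” is modeled simply by marking W read-only; in practice it means W receives no gradient or optimizer state.

```python
import numpy as np

# Illustrative sizes only (assumptions, not taken from the notebook).
d_in, d_out, rank = 8, 8, 2
rng = np.random.default_rng(0)

# Step 1: the "pretrained" weight, frozen. Here we just make it read-only;
# in a real training loop it would simply be excluded from the optimizer.
W = rng.normal(size=(d_in, d_out))
W.setflags(write=False)

# Step 2: A is random and B is zero, so the adapter contributes 0 at first.
A = rng.normal(scale=1.0 / np.sqrt(d_in), size=(d_in, rank))
B = np.zeros((rank, d_out))


# Step 3: the adapted layer is the sum of the frozen path and the low-rank path.
def lora_linear(x):
  return x @ W + (x @ A) @ B


x = rng.normal(size=(4, d_in))  # a batch of 4 inputs
assert np.allclose(lora_linear(x), x @ W)  # identical to the original at init

full_params = d_in * d_out
lora_params = A.size + B.size
print(full_params, lora_params)  # 64 32
```

Computing (x @ A) @ B, rather than forming W + A @ B, avoids materializing a second d_in × d_out matrix. At this toy size the saving is only 2×, but the adapter’s size grows linearly in d for a fixed rank, while the full matrix grows quadratically.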

Let’s try it out with a simple MLP like the one we built in the last section. We’ll just randomly initialize one for demonstration purposes; in a real LoRA adaptation setting we would generally load this from a pre-trained model checkpoint.

mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([2, 32, 32, 2]),
    jax.random.key(0),
)

Step 1: Freeze parameters

We’ll start by freezing all the parameters. Learnable parameters are identifiable because they are instances of pz.nn.Parameter:

pz.select(mlp).at_instances_of(pz.nn.Parameter).show_selection()
pz.select(mlp).at_instances_of(pz.nn.Parameter).get_sequence()

We can freeze these parameters by replacing each Parameter with an equivalent FrozenParameter. This is directly tracked inside the structure of the model.

frozen_mlp = pz.select(mlp).at_instances_of(pz.nn.Parameter).apply(
    lambda param: pz.nn.FrozenParameter(param.value, param.name)
)
frozen_mlp
pz.select(frozen_mlp).at_instances_of(pz.nn.Parameter).get_sequence()
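Conceptually, learnability here is carried by the type of each leaf, so “which parameters are trainable?” becomes an isinstance filter. The sketch below shows that pattern with hypothetical Param and FrozenParam dataclasses, which are stand-ins and not Penzai’s classes.

```python
from dataclasses import dataclass

import numpy as np


# Hypothetical stand-ins for Parameter / FrozenParameter (not Penzai's API).
@dataclass(frozen=True)
class Param:
  value: np.ndarray
  name: str


@dataclass(frozen=True)
class FrozenParam:
  value: np.ndarray
  name: str


model = {
    "Affine_0": [Param(np.ones((2, 32)), "Affine_0.weights"),
                 Param(np.zeros(32), "Affine_0.bias")],
    "Affine_1": [Param(np.ones((32, 2)), "Affine_1.weights"),
                 Param(np.zeros(2), "Affine_1.bias")],
}


def freeze_all(tree):
  """Returns a new tree with every Param replaced by an equivalent FrozenParam."""
  return {
      key: [FrozenParam(p.value, p.name) if isinstance(p, Param) else p
            for p in layer]
      for key, layer in tree.items()
  }


def learnable_names(tree):
  return [p.name for layer in tree.values() for p in layer
          if isinstance(p, Param)]


frozen = freeze_all(model)
print(learnable_names(model))   # ['Affine_0.weights', 'Affine_0.bias', 'Affine_1.weights', 'Affine_1.bias']
print(learnable_names(frozen))  # []
```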

Step 2: Replace Linear layers with low-rank adapted versions

Next, we’ll replace the Linear layers with implementations of LoRA.

In essence, a LoRA block is a sum of two computation paths: one that uses the original linear layer, and one that uses a sequence of two linear operations. This pattern can be directly mapped to one of Penzai’s simple built-in combinators, BranchAndAddTogether. We can take each linear layer, like this one:

frozen_mlp.sublayers[0].sublayers[0].sublayers[0]

And replace it with a block like this:

pz.nn.BranchAndAddTogether([
    # The original layer with frozen parameters
    # (features: 2 -> features_out: 32 for our [2, 32, 32, 2] MLP):
    pz.nn.NamedGroup("Pretrained", [
        frozen_mlp.sublayers[0].sublayers[0].sublayers[0],
    ]),
    # And a low-rank adapter with matching input and output axes:
    pz.nn.NamedGroup("Update", [
        pz.nn.add_parameter_prefix(
            "LoRA-A",
            pz.nn.Linear.from_config(
                input_axes={"features": 2},
                output_axes={"lowrank": 2},
            ),
        ),
        pz.nn.add_parameter_prefix(
            "LoRA-B",
            pz.nn.Linear.from_config(
                input_axes={"lowrank": 2},
                output_axes={"features_out": 32},
                initializer=pz.nn.zero_initializer,
            ),
        ),
    ]),
])

Note that the above code is a direct translation of a LoRA block into the structure of our model. The matrices A and B are represented as separate Linear blocks inside the overall combinator, and the order of execution is determined by the positions in the NamedGroup.
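The semantics of the two combinators can be mimicked in a few lines. In this sketch, Seq and BranchAndAdd are hypothetical plain-Python classes, not Penzai’s implementations. The shapes follow the first layer of our [2, 32, 32, 2] MLP, and B is zero as in LoRA, so the whole block reproduces the pretrained branch.

```python
import numpy as np


# Hypothetical toy combinators (not Penzai's implementation).
class Seq:
  """Runs layers in list order, feeding each output into the next layer."""

  def __init__(self, layers):
    self.layers = layers

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x


class BranchAndAdd:
  """Runs every branch on the same input and sums the results."""

  def __init__(self, branches):
    self.branches = branches

  def __call__(self, x):
    return sum(branch(x) for branch in self.branches)


def linear(w):
  return lambda x: x @ w


rng = np.random.default_rng(0)
W = rng.normal(size=(2, 32))  # frozen "pretrained" weight
A = rng.normal(size=(2, 2))   # LoRA-A: features -> lowrank
B = np.zeros((2, 32))         # LoRA-B: lowrank -> features_out, zero-initialized

block = BranchAndAdd([
    Seq([linear(W)]),             # "Pretrained" branch
    Seq([linear(A), linear(B)]),  # "Update" branch: A, then B
])
x = rng.normal(size=(4, 2))
assert np.allclose(block(x), x @ W)  # B == 0, so the block matches W alone
```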

To simplify the process of making this transformation at every Linear block, we can encapsulate it in a new Layer subclass. Since the computation can already be written as a combination of existing pieces, the idiomatic Penzai approach is to define our new Layer as a subclass of pz.nn.Sequential, so that it can easily be flattened (as we did with the MLP) if needed. Sequential already defines the necessary attributes and __call__ method, so we just need to provide a named constructor classmethod:

@pz.pytree_dataclass(has_implicitly_inherited_fields=True)
class LowRankAdapter(pz.nn.Sequential):

  @classmethod
  def from_linear(
      cls,
      linear: pz.nn.Linear,
      rank: int,
      name: str,
      lowrank_axis: str = "lowrank",
  ) -> 'LowRankAdapter':
    """Builds a LoRA layer from a Linear layer.

    Args:
      linear: The linear layer to adapt.
      rank: The rank of the low-rank adapter.
      name: Prefix for this block's parameters.
      lowrank_axis: The axis name for low-rank adaptation.

    Returns:
      A LoRA block with uninitialized parameters and the same initial
      behavior as `linear`.
    """
    return cls([
        pz.nn.BranchAndAddTogether([
            pz.nn.NamedGroup("Pretrained", [linear]),
            pz.nn.NamedGroup("Update", [
                pz.nn.add_parameter_prefix(
                    name + "/LoRA_A",
                    pz.nn.Linear.from_config(
                        input_axes=linear.input_axes,
                        output_axes={lowrank_axis: rank},
                        parallel_axes=linear.parallel_axes,
                    ),
                ),
                pz.nn.add_parameter_prefix(
                    name + "/LoRA_B",
                    pz.nn.Linear.from_config(
                        input_axes={lowrank_axis: rank},
                        output_axes=linear.output_axes,
                        parallel_axes=linear.parallel_axes,
                        initializer=pz.nn.zero_initializer,
                    ),
                ),
            ]),
        ])
    ])

Note: Idiomatic Penzai layers generally avoid overriding __init__, since dataclasses take their attributes as parameters to __init__ and we want to ensure the output of the pretty-printer directly corresponds to code we could use to rebuild the model even if we’ve modified its attributes. When we have nontrivial construction logic, we’ll usually define it in a class method like from_linear or from_config instead.

Also, in most Penzai layers, each layer is only responsible for ensuring its parameter names are locally unique, and parent layers add parameter prefixes using pz.nn.add_parameter_prefix at each level. In this case, however, we’re inserting the LoRA blocks into an existing model, so their names must be globally unique. This is why from_linear takes a name as an argument but Linear.from_config does not.

The next step is to write a helper function for inserting LoRA blocks into a model. We’ll use Penzai’s pretty_keystr function (a fancier version of jax.tree_util.keystr) to ensure each block has a unique name:

def loraify_all_linears(model, rank: int):
  return (
      pz.select(model)
      .at_instances_of(pz.nn.Linear)
      .apply(
          lambda keypath, lin: LowRankAdapter.from_linear(
              lin,
              rank=rank,
              name=pz.pretty_keystr(keypath, model),
          ),
          with_keypath=True,
      )
  )
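A keypath gives a globally unique name because two Linear layers can be identical as values while occupying different positions in the model tree. The sketch below uses a hypothetical path walker over nested dicts and lists. It is a simplified stand-in for pz.pretty_keystr, whose exact output format may differ.

```python
# Hypothetical helper (not Penzai's pretty_keystr): yield a readable path
# string for every leaf of a nested dict/list structure.
def leaf_paths(tree, prefix=""):
  if isinstance(tree, dict):
    for key, sub in tree.items():
      yield from leaf_paths(sub, f"{prefix}.{key}" if prefix else str(key))
  elif isinstance(tree, (list, tuple)):
    for i, sub in enumerate(tree):
      yield from leaf_paths(sub, f"{prefix}[{i}]")
  else:
    yield prefix, tree


model = {"sublayers": [
    {"sublayers": ["Linear", "AddBias"]},
    "relu",
    {"sublayers": ["Linear", "AddBias"]},
]}
names = [path for path, leaf in leaf_paths(model) if leaf == "Linear"]
print(names)  # ['sublayers[0].sublayers[0]', 'sublayers[2].sublayers[0]']
```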

Now we can run it on our MLP:

loraified_mlp_uninit = loraify_all_linears(frozen_mlp, rank=2)
loraified_mlp_uninit

You can directly check that this transformation is doing the right thing by expanding each Affine layer and making sure the LowRankAdapter looks right.

Note that loraified_mlp_uninit is a copy of frozen_mlp with the requested modifications. In Penzai, transformations of models always return new copies of the model, so you don’t have to worry about accidentally making an irreversible change.

(Only the model structure is copied; the JAX arrays still share memory between the models. But JAX arrays are immutable, so you don’t have to worry about those changing either; unless you explicitly delete or donate them, training loops always return new copies of your model with updated parameters.)
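The copy-the-structure, share-the-arrays behavior resembles what Python’s frozen dataclasses give you. This sketch uses a hypothetical Layer class rather than Penzai code. Note that NumPy arrays, unlike JAX arrays, are mutable, so the sharing shown here is only safe because nothing writes to them.

```python
import dataclasses

import numpy as np


# Hypothetical toy layer illustrating "modify by copying" (not Penzai code).
@dataclasses.dataclass(frozen=True)
class Layer:
  name: str
  weights: np.ndarray


original = Layer("Linear", np.arange(6.0).reshape(2, 3))
renamed = dataclasses.replace(original, name="Pretrained/Linear")

print(original.name, "->", renamed.name)    # Linear -> Pretrained/Linear
print(renamed.weights is original.weights)  # True: the array is shared, not copied

try:
  original.name = "oops"  # frozen dataclasses reject in-place mutation
except dataclasses.FrozenInstanceError:
  print("cannot mutate in place")
```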

Step 3: Initializing and training the LoRA weights

Finally, we can initialize and train the new weights we inserted into the model. To initialize them, we can use Penzai’s standard parameter initialization helper, pz.nn.initialize_parameters, which finds all UninitializedParameter instances and initializes them. In this case, the UninitializedParameters are the LoRA weights, and the FrozenParameters from the pretrained model are left unchanged.

%%autovisualize
loraified_mlp = pz.nn.initialize_parameters(loraified_mlp_uninit, jax.random.key(42))
loraified_mlp
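The initialization step amounts to “replace every placeholder with a freshly sampled array, and leave everything else alone.” Below is a rough NumPy sketch of that idea. It uses a hypothetical UninitializedParam class and illustrative parameter names; the real Penzai helper takes a JAX random key and uses its own naming.

```python
import numpy as np


# Hypothetical placeholder type (not Penzai's UninitializedParameter).
class UninitializedParam:
  def __init__(self, shape, initializer):
    self.shape = shape
    self.initializer = initializer


def normal_init(rng, shape):
  return rng.normal(scale=1.0 / np.sqrt(shape[0]), size=shape)


def zeros_init(rng, shape):
  return np.zeros(shape)


def initialize_all(params: dict, seed: int) -> dict:
  """Returns a new dict where placeholders are sampled and arrays are kept."""
  rng = np.random.default_rng(seed)
  return {
      name: p.initializer(rng, p.shape) if isinstance(p, UninitializedParam) else p
      for name, p in params.items()
  }


frozen_w = np.ones((2, 32))
params = {
    "Affine_0.weights": frozen_w,  # pretrained, already an array
    "Affine_0/LoRA_A.weights": UninitializedParam((2, 2), normal_init),
    "Affine_0/LoRA_B.weights": UninitializedParam((2, 32), zeros_init),
}
initialized = initialize_all(params, seed=42)
print(initialized["Affine_0.weights"] is frozen_w)          # True
print(initialized["Affine_0/LoRA_A.weights"].shape)         # (2, 2)
print(np.all(initialized["Affine_0/LoRA_B.weights"] == 0))  # True
```

Because a single NumPy generator is consumed in iteration order, the sampled values in this sketch depend on dictionary order. Deriving a separate key per parameter would avoid that.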

Since we froze the “pretrained” parameters before we applied loraify_all_linears, the learnable parameters of our new model are just the new LoRA weights:

pz.select(loraified_mlp).at_instances_of(pz.nn.Parameter).get_sequence()

This means you can easily train them using Penzai’s basic training loop helpers, or write your own custom training loop for them. As a demonstration, we’ll train this model to implement XOR by only fitting the low-rank adapter parameters.

xor_inputs = pz.nx.wrap(
    jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32),
    "batch",
    "features",
)
xor_labels = jnp.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32)

def loss_fn(model, rng, state):
  assert state is None
  model_out = model(xor_inputs)
  log_probs = jax.nn.log_softmax(
      model_out.unwrap("batch", "features"), axis=-1
  )
  losses = -log_probs * xor_labels
  loss = jnp.sum(losses) / 4
  return loss, None, {"loss": loss}
train_step = basic_training.build_train_step_fn(loss_fn)
train_state = basic_training.TrainState.initial_state(
    model=loraified_mlp,
    optimizer_def=optax.adam(0.1),
    root_rng=jax.random.key(42),
)
train_state
for i in range(20):
  train_state, out = train_step(train_state)
  print(i, out)
train_state
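The following restates the same loss in NumPy as a sanity check on the values printed by the training loop; it is a sketch, not a replacement for the JAX version. The loss is the mean over the four examples of the cross-entropy between the softmax of the two output features and the one-hot labels. A model that predicts 50/50 everywhere therefore scores ln 2 ≈ 0.693.

```python
import numpy as np

# Renamed so it doesn't shadow the notebook's JAX `xor_labels`.
xor_labels_np = np.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=np.float32)


def log_softmax(logits, axis=-1):
  shifted = logits - logits.max(axis=axis, keepdims=True)  # numerical stability
  return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def xor_loss(logits):
  losses = -log_softmax(logits) * xor_labels_np
  return losses.sum() / 4  # average over the 4 examples


# All-zero logits are maximally uncertain: the loss equals ln 2.
print(xor_loss(np.zeros((4, 2))))  # ~0.6931
# Confident, correct logits drive the loss toward 0.
print(xor_loss(10.0 * (2 * xor_labels_np - 1)))
```

Loss values well below 0.693 thus indicate that the adapted model does better than chance on XOR.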

TrainState is an optional utility that manages the optimizer states for us. It also partitions the model into learnable and nonlearnable parts, but we can combine them again by reading the computed property train_state.model:

%%autovisualize
train_state.model(xor_inputs)
%%autovisualize
pz.nx.nmap(jnp.argmax)(train_state.model(xor_inputs).untag("features"))

Looks like it worked!

Adding LoRA to Gemma

Let’s now try adding LoRA to the Gemma 7B pretrained model. Because of Penzai’s compositional design, the implementation in the previous section will just work out of the box!

Loading Gemma

We’ll start by loading the weights from the Gemma checkpoint. We’ll use the 7B checkpoint for this tutorial, and shard it over our local devices using JAX’s automatic partitioning. (You can read more about this in the JAX documentation on distributed arrays and automatic parallelization.)

If you prefer, you can also run this tutorial with the 2B checkpoint.

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don’t have an API key already, you can:

  1. Visit https://www.kaggle.com/ and create an account if needed.

  2. Go to your account settings, then the ‘API’ section.

  3. Click ‘Create new token’ to download your key.

Next, if you are running this notebook in Google Colab:

  1. Click the “key” symbol on the left toolbar to open the “Secrets” tab.

  2. Add two new secrets, named “KAGGLE_USERNAME” and “KAGGLE_KEY”, and set their values based on the API key you downloaded.

  3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

import kagglehub
try:
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  kagglehub.login()

If everything went well, you should see:

Kaggle credentials set.

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven’t done that yet, you can do so here:

https://www.kaggle.com/models/google/gemma/license/consent

(Make sure you choose to “Verify via Kaggle Account” with the same account you used to log in above!)

Once you’ve agreed to the terms, you can run the next cell to download the Gemma weights:

weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using Orbax:

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
gemma_model = gemma.model_core.GemmaTransformer.from_pretrained(
    flat_params,
    upcast_activations_to_float32=True,
)
del flat_params
gc.collect()
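The restore_args above ask for each checkpoint array to be split across devices along its last axis only. As we read the reshape call, the grid (1,) * (ndim - 1) + (n_devices,) keeps every other axis whole. The pure-Python sketch below computes that grid and the resulting per-device shape. The helper names, the example shape (1024, 4096), and the use of 8 devices are all hypothetical, and the sketch assumes the last axis divides evenly.

```python
# Hypothetical helpers illustrating the sharding layout requested above.
def sharding_grid(shape, n_devices):
  """Number of pieces per axis: 1 everywhere except n_devices on the last axis."""
  return (1,) * (len(shape) - 1) + (n_devices,)


def per_device_shape(shape, n_devices):
  grid = sharding_grid(shape, n_devices)
  # This sketch assumes the last axis divides evenly across devices.
  assert shape[-1] % n_devices == 0, "sketch assumes an even split"
  return tuple(dim // parts for dim, parts in zip(shape, grid))


print(sharding_grid((1024, 4096), 8))     # (1, 8)
print(per_device_shape((1024, 4096), 8))  # (1024, 512)
```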

Here’s what the Gemma model looks like:

%%autovisualize
gemma_model