Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



LoRA From Scratch - Patching Pretrained Models in Penzai

Penzai is designed to make it easy to make targeted modifications to neural networks after they have been trained. In this notebook, we’ll show how to take Penzai’s reference implementation of the Gemma 7B open-weights transformer model, patch it to support Low-Rank Adaptation (LoRA; Hu et al. 2021), and train the new parameters on a toy problem with a hand-written loss function.

The goal of this notebook is to show how you could implement something like LoRA from scratch in fewer than a hundred lines of code. We start from a Penzai implementation of a model that doesn’t already support LoRA, and we never fork the existing implementation's source code or even modify the pretrained model’s configuration. We’ll define everything we need as we go and make changes to models interactively. In fact, our implementation will end up being completely modular; we’ll start by applying LoRA to a small MLP and then immediately be able to transfer our implementation to Gemma 7B.

Let’s get started!

Note

This tutorial uses the V2 neural network API, defined in pz.experimental.v2.

Setup

Before we can get started in earnest, we need to set up the environment.

Imports

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]

from __future__ import annotations
import os
import gc

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import optax
from jax.experimental import mesh_utils
import treescope
import penzai
from penzai import pz
import sentencepiece as spm
from penzai.models import transformer
from penzai.models import simple_mlp
from penzai.toolshed import token_visualization
from penzai.toolshed import basic_training
from penzai.toolshed import jit_wrapper

Setting up Penzai

For this tutorial, we’ll enable Treescope (Penzai’s companion pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment.

treescope.basic_interactive_setup(autovisualize_arrays=False)

Intro to Penzai’s declarative combinator design

We’ll start by giving a brief introduction to Penzai’s design conventions, and how they make it easy to insert adapters into pretrained models. Let’s begin by initializing a small MLP:

mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[8, 32, 32, 8],
)

Like most Penzai models and layers, this MLP takes named arrays as input and returns them as output. A named array is just a wrapped JAX array where a subset of its positional axes have been tagged with names. (See the named axes tutorial for more info on how to use Penzai’s named axis system.)
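To make the idea concrete, here is a toy stand-in for a named array written in plain JAX. `ToyNamedArray` and its `unwrap` method are hypothetical illustrations, not Penzai's `pz.nx.NamedArray` API. They only show how tagging axes with names lets you ask for axes by name rather than by position. Unlike Penzai's named arrays, which can leave some axes positional, this toy version names every axis.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class ToyNamedArray:
  """Hypothetical toy named array: an array plus a name for each axis."""
  data: jnp.ndarray
  names: tuple[str, ...]

  def unwrap(self, *order: str) -> jnp.ndarray:
    # Transpose so that the axes come out in the requested name order.
    perm = [self.names.index(name) for name in order]
    return jnp.transpose(self.data, perm)


x = ToyNamedArray(jnp.zeros((4, 2)), names=("batch", "features"))
print(x.unwrap("features", "batch").shape)  # (2, 4)
```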

We can call the MLP directly on an array of inputs to run it:

%%autovisualize
mlp(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

Penzai models are written in a declarative, combinator-based style. This means that the structure of the model directly matches the sequence of high-level operations that the model will run in its forward pass. Composite models, like our MLP, just hold onto their sublayers in a list and run these sublayers in order. Primitive layers, like Linear, hold on to their parameters as attributes instead of reading them from an external parameter dictionary.
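The combinator idea can be sketched in a few lines of plain Python and JAX. `ToySequential` and `ToyLinear` below are hypothetical stand-ins, not Penzai classes. The composite just runs its sublayers in order, and the primitive stores its weight as an attribute instead of looking it up in an external dictionary.

```python
from dataclasses import dataclass
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyLinear:
  """Hypothetical primitive layer: owns its weight as an attribute."""
  weight: jnp.ndarray  # shape [in_features, out_features]

  def __call__(self, x):
    return x @ self.weight


@dataclass(frozen=True)
class ToySequential:
  """Hypothetical composite layer: runs its sublayers one after another."""
  sublayers: tuple[Callable[[Any], Any], ...]

  def __call__(self, x):
    for layer in self.sublayers:
      x = layer(x)
    return x


key_1, key_2 = jax.random.split(jax.random.key(0))
mlp_sketch = ToySequential((
    ToyLinear(jax.random.normal(key_1, (8, 32))),
    jax.nn.relu,
    ToyLinear(jax.random.normal(key_2, (32, 8))),
))
out = mlp_sketch(jnp.arange(8, dtype=jnp.float32))
print(out.shape)  # (8,)
```

Because the structure *is* the forward pass, printing `mlp_sketch` already tells you everything it computes.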

We can see the sublayers by pretty-printing the model:

%%autovisualize
mlp

By convention, most of the “complicated” logic in Penzai model classes happens when we initialize them, using the .from_config method we called earlier. Once the model is built, the pretty-printed representation provides a full specification of everything the model does, and the parameters are stored as direct attributes on the layers that need them. A general design principle of Penzai is “what you see is what you get”; you should be able to learn everything you need to know about a model by printing it out.

In fact, you can click on a pretty-printed output and press r to add qualified names to the pretty-printed visualization (try it above!), which will tell you exactly what type each layer has. (If you remove the parameters first using pz.unbind_params, you can even copy and paste the pretty-printed output to rebuild the model structure!)

Note that many classes are annotated with “Sequential”, which means they are just an informatively-named sequence of other layers that run one after another. You can also “flatten” a model into a list of sublayers that run in sequence, discarding this extra information:

pz.nn.inline_groups(pz.nn.Sequential([mlp]), lambda _: True, lambda _: True)
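Flattening nested sequential groups is essentially a recursive splice that preserves execution order. Here is a hypothetical sketch of that idea, with nested tuples standing in for Penzai's `Sequential` classes. `flatten_sequential` is illustrative only and is not how `pz.nn.inline_groups` is implemented.

```python
def flatten_sequential(layers):
  """Recursively splices nested tuples of layers into one flat list."""
  flat = []
  for layer in layers:
    if isinstance(layer, tuple):
      # A nested group: splice its (flattened) contents in place.
      flat.extend(flatten_sequential(layer))
    else:
      flat.append(layer)
  return flat


nested = ("affine_0", ("relu", ("affine_1",)), "relu", "affine_2")
print(flatten_sequential(nested))
# ['affine_0', 'relu', 'affine_1', 'relu', 'affine_2']
```

Since the order of execution is unchanged, the flattened model computes the same function; only the grouping labels are lost.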

And you can freely add new logic as well, even if it wasn’t configured in the initial model. For instance, here’s how you could insert a new layer that prints out its intermediate activation:

@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(pz.nn.Layer):  # <- pz.nn.Layer is the base class of Penzai layers.

  def __call__(self, intermediate_value, **unused_side_inputs):
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value
patched = (
    pz.select(mlp)
    .at(lambda model: model.sublayers[2])
    .insert_after(DisplayIntermediateValue())
)
pz.select(patched).at_instances_of(DisplayIntermediateValue).show_selection()

patched is a copy of our model that includes our new layer, and it will run our new logic when the model is called:

%%autovisualize
patched(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

This ability makes it remarkably easy to implement adapters like LoRA!
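Conceptually, this kind of patch is a non-destructive splice into an immutable list of layers. The sketch below uses a hypothetical `ToySequential` and `insert_after` (not Penzai's selector API) to show that the original model is left untouched while a modified copy is returned.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToySequential:
  """Hypothetical immutable sequence of layers."""
  sublayers: tuple


def insert_after(model: ToySequential, index: int, new_layer) -> ToySequential:
  """Returns a copy of `model` with `new_layer` spliced in after `index`."""
  layers = model.sublayers
  return replace(
      model, sublayers=layers[: index + 1] + (new_layer,) + layers[index + 1 :]
  )


original = ToySequential(("affine_0", "relu", "affine_1"))
patched_sketch = insert_after(original, 1, "display")
print(patched_sketch.sublayers)  # ('affine_0', 'relu', 'display', 'affine_1')
print(original.sublayers)        # ('affine_0', 'relu', 'affine_1')
```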

Building a simple LoRA Layer in Penzai

Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning strategy that augments each linear operation in the model with a low-rank adapter. The original weight matrix is frozen, and two much smaller learnable matrices are used to perturb its output. Because these new parameters are kept separate from the original matrix, only they need gradients and optimizer state, so they can be updated in a compute- and memory-efficient way.
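To see why this is parameter-efficient, compare the size of a full d×d update with the size of the two low-rank factors. The sizes below (d = 4096, r = 16) are illustrative choices, not the dimensions of any particular model.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
  """Returns (#params in a full update, #params in the A and B factors)."""
  full = d_in * d_out
  low_rank = d_in * rank + rank * d_out
  return full, low_rank


full, low_rank = lora_param_counts(4096, 4096, rank=16)
print(full, low_rank, low_rank / full)  # 16777216 131072 0.0078125
```

Here the adapter has under 1% of the parameters of the matrix it adapts, and per-parameter optimizer state (such as Adam's moment estimates) shrinks by the same factor.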

The effective weight matrix can be decomposed like this:

 ┌────────────────┐       ┌─────┐                      
 │                │       │     │                      
 │                │       │  A: │   ┌────────────────┐
 │    W: d*d      │   +   │ d*r │ * │     B: r*d     │
 │                │       │     │   └────────────────┘
 │                │       │     │                      
 └────────────────┘       └─────┘                      

Here W is the original frozen weight matrix, A is a randomly-initialized matrix, and B is initialized to zero to ensure that the adapted model is equivalent to the original one at initialization.
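A quick numerical check of these claims, written in plain JAX with small illustrative sizes: because B starts at zero, the adapted weight W + AB equals W at initialization. Applying the two paths separately also gives the same result as applying the merged matrix. The update AB has rank at most r.

```python
import jax
import jax.numpy as jnp

d, r = 8, 2
key_w, key_a, key_b, key_x = jax.random.split(jax.random.key(0), 4)

W = jax.random.normal(key_w, (d, d))           # frozen pretrained weight
A = jax.random.normal(key_a, (d, r)) / d**0.5  # randomly initialized factor
B = jnp.zeros((r, d))                          # zero-initialized factor
x = jax.random.normal(key_x, (3, d))           # a batch of 3 inputs

# At initialization the adapter contributes nothing.
assert jnp.allclose(x @ W + (x @ A) @ B, x @ W)

# Once B has been trained (simulated here with random values), the two-path
# form still equals using the merged weight W + A @ B.
B_trained = jax.random.normal(key_b, (r, d))
two_path = x @ W + (x @ A) @ B_trained
merged = x @ (W + A @ B_trained)
print(jnp.allclose(two_path, merged, atol=1e-4))   # True
print(int(jnp.linalg.matrix_rank(A @ B_trained)))  # 2
```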

To enable LoRA, we’ll do three things for each linear layer in our model:

  • Freeze the original weight,

  • Initialize our low-rank matrices A and B,

  • And replace the original linear layer with the composition of W, A, and B.

Let’s try it out with a simple MLP like the one we built in the last section. We’ll just randomly initialize one for demonstration purposes; in a real LoRA adaptation setting we would generally load this from a pre-trained model checkpoint.

mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[2, 32, 32, 2],
)

Step 1: Freeze parameters

We’ll start by freezing all the parameters. Learnable parameters are identifiable because they are instances of pz.Parameter:

pz.select(mlp).at_instances_of(pz.Parameter).show_selection()

In this case, the parameters are also the JAX PyTree leaves of the model. This is because they are mutable objects, and are designed to be updated by optimizers.

jax.tree_util.tree_leaves(mlp)
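Mutable parameter objects show up as leaves because JAX treats any object whose type isn't registered as a pytree node as a leaf. The hypothetical `MutableBox` class below is a stand-in for `pz.Parameter`, not its actual implementation.

```python
import jax
import jax.numpy as jnp


class MutableBox:
  """Hypothetical mutable parameter holder (not registered as a pytree)."""

  def __init__(self, value):
    self.value = value


model_like = {"bias": jnp.zeros(3), "weight": MutableBox(jnp.ones((2, 3)))}
leaves = jax.tree_util.tree_leaves(model_like)
# JAX does not look inside the box; the box itself is one leaf.
print([type(leaf).__name__ for leaf in leaves])
```

An optimizer can then find every trainable value by collecting these boxes and updating `box.value` in place.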

If needed, we can extract these parameters using the function pz.unbind_params, which safely handles parameters that are shared between multiple places in the model:

mlp_with_slots, params = pz.unbind_params(mlp)
pz.show("mlp_with_slots:", mlp_with_slots)
pz.show("params:", params)
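Conceptually, unbinding swaps each parameter for a named placeholder and collects the values in a separate dictionary. A parameter that appears in two places is therefore only stored once. `MutableBox`, `Slot`, and `unbind_boxes` below are hypothetical names for this sketch; they are not how `pz.unbind_params` is implemented.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


class MutableBox:
  """Hypothetical named, mutable parameter holder (a pytree leaf)."""

  def __init__(self, name, value):
    self.name = name
    self.value = value


@dataclass(frozen=True)
class Slot:
  """Hypothetical placeholder recording which parameter belongs here."""
  name: str


def unbind_boxes(tree):
  """Replaces boxes with slots; returns (tree_with_slots, {name: box})."""
  toy_params = {}

  def visit(leaf):
    if isinstance(leaf, MutableBox):
      # A shared box always maps to the same name, so it is collected once.
      toy_params[leaf.name] = leaf
      return Slot(leaf.name)
    return leaf

  return jax.tree_util.tree_map(visit, tree), toy_params


shared = MutableBox("embedding", jnp.ones((4, 2)))
tree = {"encoder": shared, "decoder": shared, "bias": jnp.zeros(2)}
with_slots, toy_params = unbind_boxes(tree)
print(with_slots["encoder"], list(toy_params))  # Slot(name='embedding') ['embedding']
```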

In this case, however, we just need to “freeze” the parameters, which makes them immutable. We can do this using pz.freeze_params:

frozen_mlp = pz.freeze_params(mlp)
frozen_mlp

# No more parameters:
pz.select(frozen_mlp).at_instances_of(pz.Parameter).get_sequence()

# Leaves are now ordinary JAX arrays:
jax.tree_util.tree_leaves(frozen_mlp)
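Freezing can be sketched as a tree map that replaces each mutable holder with a snapshot of its current value. As before, `MutableBox` and `freeze_boxes` are hypothetical stand-ins, not Penzai's `pz.freeze_params`.

```python
import jax
import jax.numpy as jnp


class MutableBox:
  """Hypothetical mutable parameter holder (a pytree leaf)."""

  def __init__(self, value):
    self.value = value


def freeze_boxes(tree):
  """Returns a copy of `tree` with every MutableBox replaced by its value."""
  return jax.tree_util.tree_map(
      lambda leaf: leaf.value if isinstance(leaf, MutableBox) else leaf, tree
  )


box = MutableBox(jnp.ones(3))
frozen = freeze_boxes({"weight": box})
box.value = jnp.zeros(3)  # Mutating the original box afterwards...
print(frozen["weight"])   # ...does not affect the frozen copy: [1. 1. 1.]
```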

Step 2: Replace Linear layers with low-rank adapted versions

Next, we’ll replace the Linear layers with implementations of LoRA.

In essence, a LoRA block is a sum of two computation paths: one that uses the original linear layer, and one that uses a sequence of two linear operations. This pattern can be directly mapped to one of Penzai’s simple built-in combinators, BranchAndAddTogether. We can take each linear layer, like this one:

frozen_mlp.sublayers[0].sublayers[0]

And replace it with a block like this:

pz.nn.BranchAndAddTogether([
    # The original layer with frozen parameters:
    pz.nn.NamedGroup("Pretrained", [
        frozen_mlp.sublayers[0].sublayers[0],
    ]),
    # And a low-rank adapter:
    pz.nn.NamedGroup("Update", [
        pz.nn.Linear.from_config(
            name="LoRA-A",
            init_base_rng=jax.random.key(1),
            input_axes={"features": 2},
            output_axes={"lowrank": 2},
        ),
        pz.nn.Linear.from_config(
            name="LoRA-B",
            init_base_rng=jax.random.key(1),
            input_axes={"lowrank": 2},
            output_axes={"features_out": 32},
            initializer=pz.nn.zero_initializer,
        ),
    ]),
])

Note that the above code is a direct translation of a LoRA block into the structure of our model. The matrices A and B are represented as separate Linear blocks inside the overall combinator, and the order of execution is determined by the positions in the NamedGroup.
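In plain JAX terms, such a block computes the sum of its branches' outputs, where the "Update" branch is itself a sequence of two matrix multiplications. In the sketch below, `branch_and_add` and `sequence` are hypothetical helpers, and plain matrices replace Penzai's named-axis `Linear` layers. The sizes match the first layer of the [2, 32, 32, 2] MLP, with rank 2.

```python
import jax
import jax.numpy as jnp


def branch_and_add(*branches):
  """Hypothetical combinator: runs every branch on x and sums the outputs."""
  return lambda x: sum(branch(x) for branch in branches)


def sequence(*layers):
  """Hypothetical combinator: runs layers one after another."""
  def run(x):
    for layer in layers:
      x = layer(x)
    return x
  return run


key_w, key_a, key_x = jax.random.split(jax.random.key(1), 3)
W = jax.random.normal(key_w, (2, 32))  # frozen "Pretrained" weight
A = jax.random.normal(key_a, (2, 2))   # "LoRA-A": features -> lowrank
B = jnp.zeros((2, 32))                 # "LoRA-B": lowrank -> features_out

lora_block = branch_and_add(
    lambda x: x @ W,                             # "Pretrained" branch
    sequence(lambda x: x @ A, lambda h: h @ B),  # "Update" branch
)
x = jax.random.normal(key_x, (4, 2))
print(jnp.allclose(lora_block(x), x @ W))  # True, since B is zero
```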

To simplify the process of making this transformation at every Linear block, we can encapsulate it into a new Layer subclass. Since the computation can already be written as a combination of existing pieces, the idiomatic Penzai approach is to define our new Layer as a subclass of pz.nn.Sequential, so that it can be easily flattened if needed (like we did with the MLP). Sequential already defines the necessary attributes and __call__ method, so we just need to provide a named constructor:

@pz.pytree_dataclass(has_implicitly_inherited_fields=True)
class LowRankAdapter(pz.nn.Sequential):

  @classmethod
  def from_linear(
      cls,
      linear: pz.nn.Linear,
      name: str,
      init_base_rng: jax.Array | None,
      rank: int,
      lowrank_axis: str = "lowrank",
  ) -> 'LowRankAdapter':
    """Builds a LoRA layer from a Linear layer.

    Args:
      linear: The linear layer to adapt.
      name: Name for this layer's parameters. Must be globally unique across all
        LoRA blocks; we recommend using `jax.tree_util.keystr` or
        `pz.pretty_keystr` and setting the name based on the path to the
        original Linear layer being replaced.
      init_base_rng: The base RNG to use for initializing model parameters.
      rank: The rank of the low-rank adapter.
      lowrank_axis: The axis name for low-rank adaptation.

    Returns:
      A LoRA block wrapping `linear`. Since `LoRA_B` is zero-initialized, the
      block initially behaves exactly like `linear`.
    """
    return cls([
        pz.nn.BranchAndAddTogether([
            pz.nn.NamedGroup("Pretrained", [linear]),
            pz.nn.NamedGroup(
                "Update",
                [
                    pz.nn.Linear.from_config(
                        name=f"{name}/LoRA_A",
                        init_base_rng=init_base_rng,
                        input_axes=linear.input_axes,
                        output_axes={lowrank_axis: rank},
                        parallel_axes=linear.parallel_axes,
                    ),
                    pz.nn.Linear.from_config(
                        name=f"{name}/LoRA_B",
                        init_base_rng=init_base_rng,
                        input_axes={lowrank_axis: rank},
                        output_axes=linear.output_axes,
                        parallel_axes=linear.parallel_axes,
                        initializer=pz.nn.zero_initializer,
                    ),
                ],
            ),
        ])
    ])

Note: Idiomatic Penzai layers generally avoid overriding __init__, since dataclasses take their attributes as parameters to __init__ and we want to ensure the output of the pretty-printer directly corresponds to code we could use to rebuild the model even if we’ve modified its attributes. When we have nontrivial construction logic, we’ll usually define it in a class method like from_linear or from_config instead.
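The "constructor as class method" convention can be illustrated with plain frozen dataclasses. `ToyLinear`, `ToyLoRA`, and `from_linear` here are hypothetical and only mirror the pattern. The dataclass fields are exactly the attributes you see when printing the object, so the default `__init__` can always rebuild it, while the non-trivial setup lives in the class method.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyLinear:
  weight: jnp.ndarray  # [in_features, out_features]

  def __call__(self, x):
    return x @ self.weight


@dataclass(frozen=True)
class ToyLoRA:
  pretrained: ToyLinear
  lora_a: ToyLinear
  lora_b: ToyLinear

  @classmethod
  def from_linear(cls, linear: ToyLinear, rank: int, key) -> "ToyLoRA":
    """Non-trivial construction logic lives here, not in __init__."""
    d_in, d_out = linear.weight.shape
    a = jax.random.normal(key, (d_in, rank)) / jnp.sqrt(d_in)
    return cls(
        pretrained=linear,
        lora_a=ToyLinear(a),
        lora_b=ToyLinear(jnp.zeros((rank, d_out))),
    )

  def __call__(self, x):
    return self.pretrained(x) + self.lora_b(self.lora_a(x))


base = ToyLinear(jnp.ones((2, 3)))
adapted = ToyLoRA.from_linear(base, rank=2, key=jax.random.key(0))
# The default dataclass __init__ can still rebuild an equivalent object:
rebuilt = ToyLoRA(adapted.pretrained, adapted.lora_a, adapted.lora_b)
print(jnp.allclose(rebuilt(jnp.ones(2)), base(jnp.ones(2))))  # True
```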

Layer constructors are generally responsible for ensuring their parameter names are unique within a model, and for initializing their parameters when constructed. For this reason, most layer constructors take arguments name and init_base_rng. (Note that the name is combined with the RNG when initializing each parameter, so we don’t need to manually split the RNGs.)
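One way to derive independent per-parameter randomness from a single base key and a unique name is to fold a stable hash of the name into the key. This is a hypothetical sketch: `param_key` is not a Penzai function, and Penzai's actual scheme may differ. It uses `zlib.crc32` because Python's built-in `hash` of a string is randomized per process, and masks the result to a non-negative 31-bit integer.

```python
import zlib

import jax


def param_key(base_key, name: str):
  """Hypothetical: derives a deterministic per-parameter key from its name."""
  name_hash = zlib.crc32(name.encode("utf-8")) & 0x7FFFFFFF
  return jax.random.fold_in(base_key, name_hash)


base_key = jax.random.key(1)
a1 = jax.random.normal(param_key(base_key, "LoRA-A.weights"), (3,))
a2 = jax.random.normal(param_key(base_key, "LoRA-A.weights"), (3,))
b = jax.random.normal(param_key(base_key, "LoRA-B.weights"), (3,))
print(bool((a1 == a2).all()), bool((a1 == b).all()))  # True False
```

Same name and base key give the same values; different names give different values, so one base key can safely be passed to every constructor.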

The next step is to write a helper function for inserting LoRA blocks into a model. We’ll use Penzai’s pretty_keystr function (a fancier version of jax.tree_util.keystr) to ensure each block has a unique name:

def loraify_all_linears(model, rank: int, init_base_rng):
  return (
      pz.select(model)
      .at_instances_of(pz.nn.Linear)
      .apply(
          lambda keypath, lin: LowRankAdapter.from_linear(
              lin,
              name="LoRA:" + pz.pretty_keystr(keypath, model),
              init_base_rng=init_base_rng,
              rank=rank,
          ),
          with_keypath=True,
      )
  )
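The same "find every instance, replace it, and name it by its path" pattern can be reproduced with JAX's own pytree utilities. This sketch uses hypothetical `ToyLinear`/`ToyLoRA` stand-ins, which are unregistered classes that JAX therefore treats as leaves. It builds unique names with `jax.tree_util.keystr`. It is not how `pz.select(...).apply` is implemented.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyLinear:
  weight: jnp.ndarray


@dataclass(frozen=True)
class ToyLoRA:
  name: str
  pretrained: ToyLinear
  rank: int


def loraify_all_toy_linears(model, rank: int):
  """Replaces every ToyLinear leaf with a ToyLoRA named after its path."""
  def replace_leaf(path, leaf):
    if isinstance(leaf, ToyLinear):
      return ToyLoRA("LoRA:" + jax.tree_util.keystr(path), leaf, rank)
    return leaf
  return jax.tree_util.tree_map_with_path(replace_leaf, model)


toy_model = {"layers": [ToyLinear(jnp.ones((2, 4))), ToyLinear(jnp.ones((4, 2)))]}
adapted_toy_model = loraify_all_toy_linears(toy_model, rank=2)
print([layer.name for layer in adapted_toy_model["layers"]])
# ["LoRA:['layers'][0]", "LoRA:['layers'][1]"]
```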

Now we can run it on our MLP:

loraified_mlp = loraify_all_linears(
    frozen_mlp, rank=2, init_base_rng=jax.random.key(42)
)
loraified_mlp

You can directly check that this transformation is doing the right thing by expanding each Affine layer and making sure the LowRankAdapter looks right.

Note that loraified_mlp is a copy of frozen_mlp with the requested modifications. In Penzai, transformations of models always return new copies of the model, so you don’t have to worry about accidentally making an irreversible change.

Only the model structure is copied; the JAX arrays still share memory between the models, and any mutable parameters in the original model will also be shared with the new one. In this case, though, we froze the parameters of frozen_mlp first, so only the new parameters are mutable:

pz.select(loraified_mlp).at_instances_of(pz.Parameter).get_sequence()
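The "copy the structure, share the arrays" behavior is analogous to calling `dataclasses.replace` on an immutable dataclass. This is an illustrative analogy with a hypothetical `ToyModel` class, not a description of Penzai internals.

```python
from dataclasses import dataclass, replace

import jax.numpy as jnp


@dataclass(frozen=True)
class ToyModel:
  weight: jnp.ndarray
  label: str


original = ToyModel(weight=jnp.ones((1024, 1024)), label="frozen")
modified = replace(original, label="loraified")

print(modified is original)                # False: a new model object
print(modified.weight is original.weight)  # True: the array is not copied
```

Since the large arrays are shared rather than duplicated, making many modified copies of a big pretrained model stays cheap.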

Step 3: Training the LoRA weights

We can now train these adapter parameters using Penzai’s basic training loop helpers, or use a custom training loop for them. As a demonstration, we’ll train this model to implement XOR by only fitting the low-rank adapter parameters.

def loss_fn(model, rng, state, example_inputs, example_labels):
  assert state is None
  model_out = model(example_inputs)
  log_probs = jax.nn.log_softmax(
      model_out.unwrap("batch", "features"), axis=-1
  )
  losses = -log_probs * example_labels
  # Average the per-example cross-entropy over the batch.
  loss = jnp.sum(losses) / example_labels.shape[0]
  return loss, None, {"loss": loss}

trainer = basic_training.StatefulTrainer.build(
    model=loraified_mlp,
    optimizer_def=optax.adam(0.1),
    root_rng=jax.random.key(42),
    loss_fn=loss_fn
)
trainer

xor_inputs = pz.nx.wrap(
    jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32),
    "batch",
    "features",
)
xor_labels = jnp.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32)

for i in range(20):
  out = trainer.step(example_inputs=xor_inputs, example_labels=xor_labels)
  print(i, out)

The parameters in the model will be updated in place, allowing us to use the trained model:

%%autovisualize
loraified_mlp(xor_inputs)

%%autovisualize
pz.nx.nmap(jnp.argmax)(loraified_mlp(xor_inputs).untag("features"))

Looks like it worked!

(Note: If you prefer a “functional” training loop, you can extract an immutable version of your parameters by calling pz.unbind_params(loraified_mlp, freeze=True), update them yourself, then substitute the immutable parameters back in using pz.bind_variables.)
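To illustrate the functional style in isolation, here is a similar XOR experiment written directly in JAX, without Penzai. This is a self-contained sketch under simplifying assumptions:

- plain gradient descent replaces Adam;
- a randomly initialized 2→32→32→2 MLP without biases stands in for the frozen `mlp`;
- rank-2 adapters are attached to every layer.

Only the adapter pytree is passed to `jax.grad`, so the frozen weights never receive updates.

```python
import jax
import jax.numpy as jnp

sizes = [2, 32, 32, 2]
keys = jax.random.split(jax.random.key(0), 2 * (len(sizes) - 1))
layer_dims = list(zip(sizes[:-1], sizes[1:]))

# Frozen base weights (stand-ins for pretrained parameters).
frozen_ws = [
    jax.random.normal(keys[i], (d_in, d_out)) / jnp.sqrt(d_in)
    for i, (d_in, d_out) in enumerate(layer_dims)
]
# Trainable rank-2 adapters: A is random, B starts at zero.
rank = 2
adapters = [
    {
        "A": jax.random.normal(keys[len(layer_dims) + i], (d_in, rank)) / jnp.sqrt(d_in),
        "B": jnp.zeros((rank, d_out)),
    }
    for i, (d_in, d_out) in enumerate(layer_dims)
]


def forward(adapters, x):
  for i, (w, ad) in enumerate(zip(frozen_ws, adapters)):
    x = x @ w + (x @ ad["A"]) @ ad["B"]  # frozen path + low-rank update
    if i < len(frozen_ws) - 1:
      x = jax.nn.relu(x)
  return x


def xor_loss(adapters, inputs, labels):
  log_probs = jax.nn.log_softmax(forward(adapters, inputs), axis=-1)
  return -jnp.sum(log_probs * labels) / inputs.shape[0]


xor_x = jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32)
xor_y = jnp.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32)

# Gradients are taken only with respect to `adapters` (the first argument).
grad_fn = jax.jit(jax.value_and_grad(xor_loss))
initial_loss, _ = grad_fn(adapters, xor_x, xor_y)
for step in range(300):
  _, grads = grad_fn(adapters, xor_x, xor_y)
  adapters = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, adapters, grads)
final_loss = xor_loss(adapters, xor_x, xor_y)
print(float(initial_loss), float(final_loss))
```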

Adding LoRA to Gemma

Let’s now try adding LoRA to the Gemma 7B pretrained model. Because of Penzai’s compositional design, the implementation in the previous section will just work out of the box!

Loading Gemma

We’ll start by loading the weights from the Gemma checkpoint. We’ll use the 7B checkpoint for this tutorial, and shard it over our local devices using JAX’s automatic partitioning. (You can read more about JAX’s automatic distributed arrays on this JAX documentation page.)

If you prefer, you can also run this tutorial with the 2B checkpoint.

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don’t have an API key already, you can:

  1. Visit https://www.kaggle.com/ and create an account if needed.

  2. Go to your account settings, then the ‘API’ section.

  3. Click ‘Create new token’ to download your key.

Next, if you are running this notebook in Google Colab:

  1. Click the “key” symbol on the left toolbar to open the “Secrets” tab.

  2. Add two new secrets, named “KAGGLE_USERNAME” and “KAGGLE_KEY”, and set their values based on the API key you downloaded.

  3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

import kagglehub
try:
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  kagglehub.login()

If everything went well, you should see:

Kaggle credentials set.

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven’t done that yet, you can do so here:

https://www.kaggle.com/models/google/gemma/license/consent

(Make sure you choose to “Verify via Kaggle Account” with the same account you used to log in above!)

Once you’ve agreed to the terms, you can run the next cell to download the Gemma weights:

weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using orbax:

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
gemma_model = transformer.variants.gemma.gemma_from_pretrained_checkpoint(
    flat_params,
    upcast_activations_to_float32=True,
)
del flat_params
gc.collect()
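The restore arguments above shard every parameter along its last axis across all local devices. Below is a sketch of the same layout using `jax.sharding.NamedSharding` with a `Mesh` and `PartitionSpec`, the mesh-based API that newer JAX versions recommend over `PositionalSharding`. As with the code above, it assumes that each parameter's last dimension is divisible by the number of devices. `last_axis_sharding` and the `"devices"` axis name are illustrative choices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def last_axis_sharding(ndim: int, mesh: Mesh) -> NamedSharding:
  """Replicates all leading axes and splits the last axis over 'devices'."""
  return NamedSharding(mesh, PartitionSpec(*([None] * (ndim - 1)), "devices"))


toy_mesh = Mesh(np.array(jax.local_devices()), axis_names=("devices",))
# Choose the last dimension as a multiple of the device count.
param = jnp.ones((4, 8 * len(jax.local_devices())))
sharded = jax.device_put(param, last_axis_sharding(param.ndim, toy_mesh))
print(sharded.sharding.spec)
```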

Here’s what the Gemma model looks like:

%%autovisualize
gemma_model

Try clicking the triangle markers to explore the structure of Gemma and look at some of the parameters!

Converting Gemma

Now we can freeze its parameters and LoRA-ify its linear blocks in the same way that we did for the simple MLP.

The Penzai implementation of Gemma uses the same Linear layer class to implement all of the learnable operations, in both the MLP blocks and the attention blocks. So we’ll use a slightly modified helper function that lets us be more specific about which Linear layers we want to replace: it only replaces the Linear layers inside a given selection.

def loraify_linears_in_selection(
    selection, rank: int, init_base_rng: jax.Array | None,
):
  model = selection.deselect()
  return selection.at_instances_of(pz.nn.Linear).apply(
      lambda keypath, lin: LowRankAdapter.from_linear(
          lin,
          name="LoRA:" + pz.pretty_keystr(keypath, model),
          init_base_rng=init_base_rng,
          rank=rank,
      ),
      with_keypath=True,
  )
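The same idea, restricted to part of a model, can be sketched with JAX's pytree utilities by checking each leaf's path before replacing it. The nested dict, the `"attn"` key, and the `ToyLinear`/`ToyLoRA` classes are hypothetical stand-ins. In the Penzai version, the restriction instead comes from starting with a selection of `pz.nn.Attention` instances.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyLinear:
  weight: jnp.ndarray


@dataclass(frozen=True)
class ToyLoRA:
  name: str
  pretrained: ToyLinear


def loraify_linears_under(model, subtree_key: str):
  """Wraps only the ToyLinear leaves whose path passes through subtree_key."""
  def replace_leaf(path, leaf):
    in_subtree = any(
        isinstance(k, jax.tree_util.DictKey) and k.key == subtree_key
        for k in path
    )
    if isinstance(leaf, ToyLinear) and in_subtree:
      return ToyLoRA("LoRA:" + jax.tree_util.keystr(path), leaf)
    return leaf
  return jax.tree_util.tree_map_with_path(replace_leaf, model)


block = {
    "attn": {"query": ToyLinear(jnp.ones((4, 4))), "value": ToyLinear(jnp.ones((4, 4)))},
    "mlp": {"up": ToyLinear(jnp.ones((4, 16)))},
}
adapted_block = loraify_linears_under(block, "attn")
print(type(adapted_block["attn"]["query"]).__name__)  # ToyLoRA
print(type(adapted_block["mlp"]["up"]).__name__)      # ToyLinear
```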

Now we go through and apply each of the transformation steps:

# Step 1: Freeze the pretrained parameters.
frozen_gemma_model = pz.freeze_params(gemma_model)
# Step 2: LoRA-ify the Linear blocks. Following Hu et al. (2021), we'll only
# LoRA-ify the attention parameters.
loraified_gemma_model = loraify_linears_in_selection(
    pz.select(frozen_gemma_model).at_instances_of(pz.nn.Attention),
    rank=16,
    init_base_rng=jax.random.key(123),
)
# Step 3 (optional): Look at it to make sure the transformation looks right.
pz.select(loraified_gemma_model).at_instances_of(LowRankAdapter).show_selection()