loraify_linears_in_selection


penzai.toolshed.lora.loraify_linears_in_selection(selection: pz.Selection[Any], rank: int, init_base_rng: jax.Array | None) → Any

Replaces Linear layers inside a selected part of a model with LoRA blocks.
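As a rough illustration of what a LoRA block computes, the sketch below follows the standard LoRA formulation: a frozen base weight plus a trainable product of two rank-`rank` factors. It is written in plain JAX. The names `lora_linear`, `a` and `b`, the factor shapes, and the zero initialization of `b` are assumptions taken from the common LoRA recipe. This is not penzai's `LowRankAdapter` implementation, whose parameter layout and initialization may differ.

```python
import jax
import jax.numpy as jnp


def lora_linear(x: jax.Array, w: jax.Array, a: jax.Array, b: jax.Array) -> jax.Array:
    """Hypothetical LoRA forward pass: x @ w + (x @ a) @ b.

    w: frozen base weight, shape [d_in, d_out]
    a: trainable down-projection, shape [d_in, rank]
    b: trainable up-projection, shape [rank, d_out]
    """
    base = x @ w
    # The low-rank update is applied as two small matmuls, so the full
    # [d_in, d_out] matrix a @ b is never materialized.
    update = (x @ a) @ b
    return base + update


d_in, d_out, rank = 16, 32, 4
k_w, k_a, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(k_w, (d_in, d_out))
a = jax.random.normal(k_a, (d_in, rank)) / jnp.sqrt(d_in)
# Assumption: zero-initializing one factor (as in the original LoRA paper)
# makes the adapted layer start out identical to the base layer.
b = jnp.zeros((rank, d_out))
x = jax.random.normal(k_x, (8, d_in))

y = lora_linear(x, w, a, b)
print(y.shape)  # (8, 32)
```

Because `b` starts at zero, `y` initially equals `x @ w`, so inserting the adapter does not change the model's outputs until the adapter is trained.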

This function should usually be called after freezing the model's existing weights, using something like:

pz.nn.at_instances_of(pz.nn.Parameter).apply(
    lambda param: pz.nn.FrozenParameter(param.value, param.name)
)

This function returns a copy of the model with new LoRA parameters added, but does not modify any existing parameters.
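The point of freezing first is that only the new LoRA parameters are trained. In plain JAX the same effect can be approximated by differentiating only with respect to the adapter factors and passing the base weight as an ordinary, non-differentiated argument. The sketch below assumes the hypothetical `loss_fn`, a dict-shaped `adapter`, and a plain SGD step. It does not use penzai's `FrozenParameter` machinery.

```python
import jax
import jax.numpy as jnp


def loss_fn(adapter, w, x, y_target):
    # `adapter` holds the trainable LoRA factors; `w` is the frozen base weight.
    a, b = adapter["a"], adapter["b"]
    y = x @ w + (x @ a) @ b
    return jnp.mean((y - y_target) ** 2)


d_in, d_out, rank = 8, 4, 2
k_w, k_a, k_x, k_t = jax.random.split(jax.random.PRNGKey(0), 4)
w = jax.random.normal(k_w, (d_in, d_out))
adapter = {
    "a": jax.random.normal(k_a, (d_in, rank)) / jnp.sqrt(d_in),
    "b": jnp.zeros((rank, d_out)),
}
x = jax.random.normal(k_x, (16, d_in))
y_target = jax.random.normal(k_t, (16, d_out))

# jax.grad differentiates only with respect to argument 0 (the adapter), so
# the base weight `w` gets no gradient and is never updated.
grads = jax.grad(loss_fn)(adapter, w, x, y_target)
new_adapter = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, adapter, grads)

print(sorted(grads.keys()))  # ['a', 'b']
```

Note that on the first step the gradient for `a` is exactly zero, because it is multiplied by `b`, which starts at zero. Only `b` moves initially.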

Parameters:
  • selection – A selection of a model that identifies the parts to which LoRA adaptation should be applied. Every Linear layer contained within the selected parts will be replaced.

  • rank – The rank of the LoRA blocks to insert.

  • init_base_rng – The base RNG to use for initializing the LoRA parameters.
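To make the two numeric arguments concrete, the sketch below shows two things. First, how `rank` controls the number of new parameters per layer (`rank * (d_in + d_out)` for the two factors, compared with `d_in * d_out` for the full weight). Second, one common way to derive distinct, reproducible per-layer keys from a single base RNG. The helpers `init_lora_factors` and `lora_param_count`, and the use of `jax.random.fold_in` with a layer index, are hypothetical. This document does not describe how penzai actually splits `init_base_rng`.

```python
import jax
import jax.numpy as jnp


def init_lora_factors(base_rng, layer_index, d_in, d_out, rank):
    """Hypothetical initializer that derives a per-layer key from one base RNG."""
    # fold_in gives each layer a distinct but reproducible key.
    layer_rng = jax.random.fold_in(base_rng, layer_index)
    a = jax.random.normal(layer_rng, (d_in, rank)) / jnp.sqrt(d_in)
    b = jnp.zeros((rank, d_out))
    return a, b


def lora_param_count(d_in, d_out, rank):
    # Trainable parameters in the two low-rank factors of one layer.
    return rank * (d_in + d_out)


base_rng = jax.random.PRNGKey(42)
a0, _ = init_lora_factors(base_rng, 0, 1024, 1024, rank=8)
a1, _ = init_lora_factors(base_rng, 1, 1024, 1024, rank=8)
a0_again, _ = init_lora_factors(base_rng, 0, 1024, 1024, rank=8)

print(lora_param_count(1024, 1024, 8))  # 16384
print(1024 * 1024)  # 1048576
```

For a 1024×1024 layer at rank 8, the adapter adds about 1.6% as many parameters as the frozen weight.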

Returns:

A copy of the original full model (i.e. of selection.deselect()), but where each Linear layer inside the selected part is replaced with a new LowRankAdapter instance.
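The return value is the whole model, not just the selected part. A toy tree can illustrate this behaviour: linears inside the selection are wrapped, everything outside it is passed through unchanged, and the input model is not mutated. The classes `ToyLinear`, `ToyAdapter` and `ToyBlock` and the helper `loraify_selected` are hypothetical stand-ins, not penzai's types or selector API. In particular, `pz.Selection` supports arbitrary selections, whereas this sketch only selects a named top-level child.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ToyLinear:
    name: str


@dataclass(frozen=True)
class ToyAdapter:
    wrapped: ToyLinear
    rank: int


@dataclass(frozen=True)
class ToyBlock:
    name: str
    children: tuple[Any, ...]


def loraify_subtree(node: Any, rank: int) -> Any:
    """Rebuilds a subtree, wrapping every ToyLinear in a ToyAdapter."""
    if isinstance(node, ToyLinear):
        return ToyAdapter(wrapped=node, rank=rank)
    if isinstance(node, ToyBlock):
        return replace(node, children=tuple(loraify_subtree(c, rank) for c in node.children))
    return node


def loraify_selected(model: ToyBlock, selected_name: str, rank: int) -> ToyBlock:
    """Returns a new full model; only the named top-level child is rewritten."""
    new_children = tuple(
        loraify_subtree(c, rank) if isinstance(c, ToyBlock) and c.name == selected_name else c
        for c in model.children
    )
    return replace(model, children=new_children)


model = ToyBlock("model", (
    ToyBlock("attn", (ToyLinear("q"), ToyLinear("k"))),
    ToyBlock("mlp", (ToyLinear("up"),)),
))
new_model = loraify_selected(model, "attn", rank=4)
```

Here `new_model.children[0]` contains adapters around `q` and `k`. `new_model.children[1]` is the original, untouched `mlp` block, and `model` itself still holds plain linears.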