loraify_linears_in_selection
- penzai.toolshed.lora.loraify_linears_in_selection(selection: pz.Selection[Any], rank: int, init_base_rng: jax.Array | None) → Any
Replaces Linear layers inside a selected part of a model with LoRA blocks.
This function should usually be called after freezing the existing weights in the model, using something like:

    pz.select(model).at_instances_of(pz.nn.Parameter).apply(
        lambda param: pz.nn.FrozenParameter(param.value, param.name)
    )
This function returns a copy of the model with new LoRA parameters added, but does not modify any existing parameters.
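The NumPy sketch below makes the "new LoRA parameters" concrete using the standard LoRA computation. A frozen weight W is left untouched, and a rank-`rank` update x·A·B is added alongside it. `LowRankAdapterSketch` is a hypothetical class, not Penzai's `LowRankAdapter`. The initialization (scaled normal for A, zeros for B) follows the common LoRA recipe and may differ from Penzai's. A NumPy `Generator` stands in for the JAX key passed as `init_base_rng`.

```python
import numpy as np


class LowRankAdapterSketch:
    """Hypothetical stand-in for a LoRA block wrapping a frozen linear layer.

    Computes y = x @ W + (x @ A) @ B, where W is the frozen base weight,
    A has shape (in_dim, rank) and B has shape (rank, out_dim).
    """

    def __init__(self, frozen_w: np.ndarray, rank: int, rng: np.random.Generator):
        in_dim, out_dim = frozen_w.shape
        self.frozen_w = frozen_w  # kept as-is; never updated by the adapter
        # A gets a small random init (assumed scheme); B starts at zero, so
        # the adapter's contribution is exactly zero right after construction.
        self.a = rng.normal(scale=1.0 / np.sqrt(in_dim), size=(in_dim, rank))
        self.b = np.zeros((rank, out_dim))

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Frozen path plus low-rank update path.
        return x @ self.frozen_w + (x @ self.a) @ self.b

    def trainable_param_count(self) -> int:
        # Only A and B are new (trainable) parameters.
        return self.a.size + self.b.size


rng = np.random.default_rng(0)  # stand-in for a JAX PRNG key
w = rng.normal(size=(64, 32))
adapter = LowRankAdapterSketch(w, rank=4, rng=rng)
x = rng.normal(size=(8, 64))

# At initialization the adapted layer matches the frozen linear layer.
assert np.allclose(adapter(x), x @ w)
# rank * (in_dim + out_dim) = 4 * (64 + 32) = 384 new parameters,
# versus 64 * 32 = 2048 in the frozen weight.
print(adapter.trainable_param_count())  # 384
```

B starts at zero, so wrapping a layer this way leaves the model's outputs unchanged until training updates A and B. The number of new parameters grows as rank × (in_dim + out_dim) rather than in_dim × out_dim.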
- Parameters:
selection – A selection of a model that identifies the parts for which LoRA adaptation should be applied. Any Linear layers contained within the selected part will be replaced.
rank – The rank of the LoRA blocks to insert.
init_base_rng – The base RNG to use for initializing the LoRA parameters.
- Returns:
A copy of the original full model (i.e. of selection.deselect()), but where each Linear layer inside the selected part is replaced with a new LowRankAdapter instance.
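A small pure-Python sketch can illustrate the contract described above. The full model comes back, only Linear layers inside the selected part are swapped, and nothing in the original is mutated. Everything here is hypothetical: `Linear`, `LoRA`, and `Block` are stand-in frozen dataclasses rather than Penzai classes, and picking a child `Block` by name stands in for Penzai's selection API. The point is the shape of the transformation, not the library's implementation.

```python
from __future__ import annotations

import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Linear:  # hypothetical stand-in for a Linear layer
    name: str


@dataclasses.dataclass(frozen=True)
class LoRA:  # hypothetical stand-in for a LowRankAdapter wrapping a Linear
    base: Linear
    rank: int


@dataclasses.dataclass(frozen=True)
class Block:  # hypothetical container layer
    name: str
    children: tuple[Any, ...]


def loraify(node: Any, rank: int) -> Any:
    """Return a copy of `node` with every Linear replaced by a LoRA wrapper."""
    if isinstance(node, Linear):
        return LoRA(base=node, rank=rank)
    if isinstance(node, Block):
        return dataclasses.replace(
            node, children=tuple(loraify(c, rank) for c in node.children)
        )
    return node


def loraify_in_selected(model: Block, selected: str, rank: int) -> Block:
    """Apply `loraify` only inside the direct child Block named `selected`.

    The result is the full model (the analogue of selection.deselect()),
    rebuilt rather than mutated; unselected children are reused as-is.
    """
    return dataclasses.replace(
        model,
        children=tuple(
            loraify(c, rank) if isinstance(c, Block) and c.name == selected else c
            for c in model.children
        ),
    )


model = Block("model", (
    Block("attn", (Linear("q"), Linear("k"))),
    Block("mlp", (Linear("up"), Linear("down"))),
))
new_model = loraify_in_selected(model, "attn", rank=2)

print(new_model.children[0].children[0])  # LoRA(base=Linear(name='q'), rank=2)
print(new_model.children[1].children[0])  # Linear(name='up')
print(model.children[0].children[0])      # Linear(name='q'), original unchanged
```

Rebuilding with `dataclasses.replace` instead of mutating in place mirrors the copy-returning behavior. The original `model` stays usable, for example to compare outputs before and after adaptation.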