LowRankAdapter

class penzai.toolshed.lora.LowRankAdapter

Bases: Sequential

A LoRA (low-rank adaptation) block for parameter-efficient fine-tuning, intended to replace a Linear layer.
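For orientation, here is the standard LoRA formulation. It is an assumption that penzai follows it exactly, since this page does not spell out the parameterization. The original weight W stays frozen and a trainable rank-r correction is added:

```latex
y = W x + B A x, \qquad
A \in \mathbb{R}^{r \times d_{\text{in}}}, \quad
B \in \mathbb{R}^{d_{\text{out}} \times r}, \quad
r \ll \min(d_{\text{in}}, d_{\text{out}})
```

If B starts at zero, then BAx = 0 and the block initially computes exactly Wx. The adapter trains r(d_in + d_out) parameters instead of the d_in · d_out entries of W. For example, d_in = d_out = 4096 with r = 8 gives 65,536 trainable parameters versus 16,777,216.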

Methods

__init__(sublayers)

from_linear(linear, rank, name[, lowrank_axis])

Builds a LoRA layer from a pz.nn.Linear layer.

Attributes

sublayers

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

__call__(value)

Runs each of the sublayers in sequence.
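Because LowRankAdapter is a Sequential, calling it passes the input through sublayers one after another, and each output becomes the next input. The following is a minimal plain-Python sketch of that contract. The function run_sequential is hypothetical and is not part of penzai.

```python
from typing import Any, Callable, Sequence


def run_sequential(sublayers: Sequence[Callable[[Any], Any]], value: Any) -> Any:
    """Hypothetical stand-in for Sequential.__call__: apply sublayers in order."""
    for layer in sublayers:
        # Each sublayer consumes the previous sublayer's output.
        value = layer(value)
    return value


# Order matters: (3 + 1) * 2 == 8, whereas (3 * 2) + 1 == 7.
result = run_sequential([lambda v: v + 1, lambda v: v * 2], 3)
```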

classmethod from_linear(linear: pz.nn.Linear, rank: int, name: str, lowrank_axis: str = 'lowrank') → LowRankAdapter

Builds a LoRA layer from a pz.nn.Linear layer.

Parameters:
  • linear – The linear layer to adapt.

  • rank – The rank of the low-rank adapter.

  • name – Prefix for this block’s parameter names. It must be unique across all LoRA blocks in the model. We recommend deriving it from the path to the original Linear layer being replaced, using jax.tree_util.keystr or pz.pretty_keystr.

  • lowrank_axis – The name of the intermediate axis, of size rank, used by the low-rank adaptation.

Returns:

A LoRA block with uninitialized parameters and the same initial behavior as linear.
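The "same initial behavior" guarantee can be illustrated with a minimal NumPy sketch. It assumes the standard LoRA scheme: a random down-projection and a zero-initialized up-projection. The helpers lora_from_linear and lora_apply are hypothetical and are not penzai functions. The real block operates on named axes, where lowrank_axis names the rank-sized axis. This sketch uses positional arrays instead, and its name string stands in for one you would derive with jax.tree_util.keystr or pz.pretty_keystr.

```python
import numpy as np


def lora_from_linear(w: np.ndarray, rank: int, name: str, seed: int = 0) -> dict:
    """Hypothetical helper: LoRA factors for a frozen weight `w` of shape (d_in, d_out)."""
    d_in, d_out = w.shape
    rng = np.random.default_rng(seed)
    return {
        # Down-projection A, shape (d_in, rank): small random values.
        f"{name}.lora_a": rng.normal(scale=1.0 / np.sqrt(d_in), size=(d_in, rank)),
        # Up-projection B, shape (rank, d_out): zeros, so the adapter starts as a no-op.
        f"{name}.lora_b": np.zeros((rank, d_out)),
    }


def lora_apply(w: np.ndarray, adapter: dict, name: str, x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: frozen linear path plus the low-rank path, summed."""
    a = adapter[f"{name}.lora_a"]
    b = adapter[f"{name}.lora_b"]
    return x @ w + (x @ a) @ b


# Usage: an 8 -> 4 linear layer adapted with rank 2.
rng = np.random.default_rng(1)
w = rng.normal(size=(8, 4))
x = rng.normal(size=(3, 8))
name = "blocks[0].mlp"  # unique per adapted layer, e.g. derived from its path
adapter = lora_from_linear(w, rank=2, name=name)
y = lora_apply(w, adapter, name, x)  # equals x @ w until B is trained
```

Zeroing both factors would also preserve the output. In that case, however, both gradients would vanish and training could never move either factor. With A random, B receives a nonzero gradient from the first step.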