LowRankAdapter

class penzai.toolshed.lora.LowRankAdapter

Bases: Sequential

A LoRA (low-rank adaptation) block for parameter-efficient fine-tuning, designed to replace a Linear layer.
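As a rough illustration of what such a block computes, the sketch below uses plain JAX arrays rather than penzai layers. It assumes the standard LoRA formulation: the frozen weight W is kept, and a trainable update is factored as A @ B with a small inner dimension r (the rank). The function `lora_forward` and all shapes are hypothetical and are not part of penzai's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper showing the standard LoRA computation (not penzai's API).
def lora_forward(x, w, a, b):
    """Frozen base projection plus a trainable rank-r update.

    x: [batch, d_in]; w: [d_in, d_out] (frozen, pretrained weight)
    a: [d_in, r];     b: [r, d_out]     (trainable low-rank factors)
    """
    base = x @ w          # output of the original Linear layer
    update = (x @ a) @ b  # rank-r correction; the d_in x d_out matrix a @ b is never formed
    return base + update

d_in, d_out, rank = 64, 32, 4
k_x, k_w, k_a, k_b = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (8, d_in))
w = jax.random.normal(k_w, (d_in, d_out))
a = jax.random.normal(k_a, (d_in, rank))
b = jax.random.normal(k_b, (rank, d_out))

y = lora_forward(x, w, a, b)

# Trainable parameter counts: full fine-tuning vs. the low-rank factors.
full_params = d_in * d_out           # 2048
lora_params = rank * (d_in + d_out)  # 384
```

Only `a` and `b` would be trained; because their product has rank at most r, the number of trainable weights grows with r * (d_in + d_out) instead of d_in * d_out.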

Methods

__init__(sublayers)

from_linear(linear, name, init_base_rng, rank)

Builds a LoRA layer from a pz.nn.Linear layer.

Attributes

sublayers

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

__call__(value, **side_inputs)

Runs each of the sublayers in sequence.

classmethod from_linear(linear: pz.nn.Linear, name: str, init_base_rng: jax.Array | None, rank: int, lowrank_axis: str = 'lowrank') → LowRankAdapter

Builds a LoRA layer from a pz.nn.Linear layer.

Parameters:
  • linear – The linear layer to adapt.

  • name – Name for this layer’s parameters. It must be globally unique across all LoRA blocks; we recommend deriving it from the path to the original Linear layer being replaced, for example with jax.tree_util.keystr or pz.pretty_keystr.

  • init_base_rng – The base RNG to use for initializing model parameters.

  • rank – The rank of the low-rank adapter.

  • lowrank_axis – The name of the intermediate low-rank axis (of size rank) introduced by the adapter.
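To illustrate the naming recommendation for name, the sketch below derives one string per leaf from its key path using jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr. The nested dict `model_tree` is a hypothetical stand-in for a model whose leaves play the role of Linear layers; a real penzai model is a tree of layer objects, and pz.pretty_keystr would give a more readable string.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a model: leaves play the role of Linear layers.
model_tree = {
    "block_0": {"attn_q": jnp.zeros((4, 4)), "mlp_in": jnp.zeros((4, 8))},
    "block_1": {"attn_q": jnp.zeros((4, 4)), "mlp_in": jnp.zeros((4, 8))},
}

paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(model_tree)

# keystr turns each key path into a string such as "['block_0']['attn_q']".
# Every leaf has a distinct path, so the resulting names are globally unique.
lora_names = [jax.tree_util.keystr(path) for path, _ in paths_and_leaves]
```

Because the names come from tree paths, two adapters can only collide if they replace the same layer, which is what makes this a safe default.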

Returns:

A LoRA block with uninitialized parameters and the same initial behavior as linear.
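One common way to make an adapted layer start out identical to the original is to initialize one low-rank factor randomly and the other to zero, so the update A @ B is exactly zero before training. The sketch below shows that idea with plain JAX arrays; it is an assumption about the general technique, not a description of penzai's internal initializers, and `init_lora_factors`, `linear`, and `adapted` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Assumed initialization scheme (standard LoRA recipe): A random, B zero.
def init_lora_factors(rng, d_in, d_out, rank):
    a = jax.random.normal(rng, (d_in, rank)) / jnp.sqrt(d_in)
    b = jnp.zeros((rank, d_out))  # zero factor makes the initial update vanish
    return a, b

def linear(x, w):
    return x @ w

def adapted(x, w, a, b):
    # Base output plus low-rank update; equals linear(x, w) while b == 0.
    return x @ w + (x @ a) @ b

k_x, k_w, k_a = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (5, 16))
w = jax.random.normal(k_w, (16, 8))
a, b = init_lora_factors(k_a, 16, 8, rank=2)
```

Since A is still random, gradients with respect to B are nonzero from the first step, so training can move the adapter away from the identity-preserving starting point.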