GemmaKVCachingTransformer

class penzai.example_models.gemma.sampling_mode.GemmaKVCachingTransformer

Bases: Layer

Top-level Gemma transformer in cached autoregressive sampling mode.

This class represents the sampling mode of the Gemma model and is designed to be built from an existing GemmaTransformer. To load it from a pretrained checkpoint, first load a GemmaTransformer, then call GemmaKVCachingTransformer.from_uncached, as sketched below.
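For example, the end-to-end workflow looks roughly like the sketch below. The checkpoint path and the orbax restore step are illustrative, and GemmaTransformer.from_pretrained is assumed to accept the restored parameter tree; adapt the loading step to however you obtain the checkpoint's parameters.

    import orbax.checkpoint
    from penzai.example_models.gemma import model_core, sampling_mode

    # Restore the pretrained parameter tree (illustrative path; any
    # mechanism that yields the checkpoint's parameters works here).
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    flax_params = checkpointer.restore("/path/to/gemma/checkpoint")

    # Build the uncached model first, then convert it to sampling mode.
    uncached = model_core.GemmaTransformer.from_pretrained(flax_params)
    sampler, initial_state = sampling_mode.GemmaKVCachingTransformer.from_uncached(
        uncached,
        cache_len=1024,           # maximum total tokens held per cache
        batch_axes={"batch": 4},  # batch axes that sampling will use
    )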

Variables:
  • config (model_core.GemmaTransformerConfig) – The configuration for the transformer.

  • body (pz.LayerLike) – The implementation of the transformer. This is usually a nested set of state and side-effect handlers wrapping the main sequence of transformer blocks, but it may be modified after the model is loaded (e.g., by patching).

Methods

__init__(config, body)

from_uncached(uncached, cache_len, batch_axes)

Transforms a GemmaTransformer into cached sampling mode.

input_structure()

output_structure()

__call__(inputs)

Processes a new subsequence of tokens and adds them to the K/V cache.

Attributes

config

body

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(inputs: GemmaKVCachingInputs) → tuple[pz.nx.NamedArray, GemmaKVCachingState]

Processes a new subsequence of tokens and adds them to the K/V cache.

Parameters:

inputs – Structure of input arguments, containing tokens, segment positions, an attention mask, and the current sampling state.

Returns:

A tuple (outputs, new_sampling_state), where outputs is the final matrix of logits from the embedding decoding layer, which (in the normal configuration) has named axes “seq” and “vocabulary”, and new_sampling_state is the sampling state with updated key/value caches.
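As a usage sketch, the loop below prefills the cache with a batch of prompt tokens and then decodes greedily one token at a time, threading the sampling state through each call. The GemmaKVCachingInputs.from_basic_subsegments helper and the named-array idioms are assumptions about penzai's API; the placeholder prompt and the max_new_tokens budget are illustrative.

    import jax.numpy as jnp
    from penzai import pz
    from penzai.example_models.gemma import sampling_mode

    # `sampler` and `initial_state` come from
    # GemmaKVCachingTransformer.from_uncached (see below).
    state = initial_state
    max_new_tokens = 32  # illustrative sampling budget

    # Prefill with the whole prompt; later iterations feed back a
    # length-1 subsequence per batch element.
    tokens = pz.nx.wrap(
        jnp.zeros((4, 16), dtype=jnp.int32)  # placeholder prompt batch
    ).tag("batch", "seq")

    for _ in range(max_new_tokens):
        inputs = sampling_mode.GemmaKVCachingInputs.from_basic_subsegments(
            tokens, state
        )
        logits, state = sampler(inputs)
        # Greedy next token from the last position of this subsequence.
        next_token = pz.nx.nmap(jnp.argmax)(
            logits[{"seq": -1}].untag("vocabulary")
        )
        # Re-add a length-1 "seq" axis for the next incremental call.
        tokens = pz.nx.nmap(lambda x: x[None])(next_token).tag("seq")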

classmethod from_uncached(uncached: model_core.GemmaTransformer, cache_len: int, batch_axes: dict[str, int]) → tuple[GemmaKVCachingTransformer, GemmaKVCachingState]

Transforms a GemmaTransformer into cached sampling mode.

This constructor hot-swaps all model_core.GemmaAttention layers in the original model to enable key-value caching, then installs new handlers to update their states appropriately. Note that any modifications made to the uncached model will carry over into the sampling-mode model.

Parameters:
  • uncached – The original GemmaTransformer model.

  • cache_len – Maximum sequence length for the key/value caches.

  • batch_axes – Names and sizes for the batch axes that will be used for sampling. Required for initializing the key/value caches.

Returns:

A tuple (sampler_model, initial_sampling_state), where sampler_model is a GemmaKVCachingTransformer and initial_sampling_state holds the initial, empty key/value caches.
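Because the key/value caches are preallocated, cache_len must cover the prompt plus every token you intend to generate, and the batch axes are fixed up front since the caches are allocated per batch element. A minimal sizing sketch, where prompt_len, max_new_tokens, and num_prompts are illustrative names:

    # Size the cache to hold the prompt plus the sampling budget.
    sampler, initial_state = sampling_mode.GemmaKVCachingTransformer.from_uncached(
        uncached,
        cache_len=prompt_len + max_new_tokens,
        batch_axes={"batch": num_prompts},
    )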