GemmaKVCachingTransformer
- class penzai.deprecated.v1.example_models.gemma.sampling_mode.GemmaKVCachingTransformer
Bases: Layer
Top-level Gemma transformer in cached autoregressive sampling mode.
This class represents the sampling mode of the Gemma model, and is designed to be loaded from an existing GemmaTransformer. If you want to load this from the pretrained checkpoint, first load a GemmaTransformer, then call GemmaKVCachingTransformer.from_uncached.
- Variables:
config (model_core.GemmaTransformerConfig) – The configuration for the transformer.
body (pz.LayerLike) – The implementation of the transformer. Usually a nested set of state and side-effect handlers wrapping the main sequence of transformer blocks, but may be modified after the model is loaded due to patching.
Methods
__init__(config, body)
from_uncached(uncached, cache_len, batch_axes) – Transforms a GemmaTransformer into cached sampling mode.
input_structure()
output_structure()
__call__(inputs) – Processes a new subsequence of tokens and adds them to the K/V cache.
Attributes
config
body
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(inputs: GemmaKVCachingInputs) → tuple[pz.nx.NamedArray, GemmaKVCachingState]
Processes a new subsequence of tokens and adds them to the K/V cache.
- Parameters:
inputs – Structure of input arguments, containing tokens, segment positions, an attention mask, and the current sampling state.
- Returns:
A tuple (outputs, new_sampling_state), where outputs is the final matrix of logits from the embedding decoding layer, which (in the normal configuration) will have axes “seq” and “vocabulary”, and new_sampling_state is the updated sampling state with the updated key-value caches.
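The call contract above (take a new subsequence plus the current state, return outputs plus an updated state) can be pictured with a toy sketch in plain Python. This is not the penzai implementation: `ToyCacheState`, `init_state`, and `step` are hypothetical names, and the "keys" and "values" are placeholders rather than real attention projections. It only illustrates how a fixed-length cache is filled incrementally and how the state is threaded through successive calls.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyCacheState:
    """Hypothetical stand-in for a sampling state (not a penzai class)."""
    keys: tuple         # length == cache_len; None marks an empty slot
    values: tuple       # same layout as keys
    end_index: int = 0  # number of slots already filled


def init_state(cache_len):
    """Create an empty cache with room for cache_len positions."""
    return ToyCacheState(keys=(None,) * cache_len, values=(None,) * cache_len)


def step(tokens, state):
    """Add one subsequence to the cache and return (outputs, new_state)."""
    start = state.end_index
    stop = start + len(tokens)
    if stop > len(state.keys):
        raise ValueError("subsequence does not fit in the cache")
    keys, values = list(state.keys), list(state.values)
    outputs = []
    for offset, tok in enumerate(tokens):
        pos = start + offset
        keys[pos] = ("k", tok)    # placeholder for a key projection
        values[pos] = ("v", tok)  # placeholder for a value projection
        # Each new position can see every cached position up to itself.
        outputs.append([k[1] for k in keys[: pos + 1]])
    # The old state is left untouched; a new one is returned instead.
    return outputs, ToyCacheState(tuple(keys), tuple(values), stop)


state = init_state(cache_len=8)
prompt_out, state = step([5, 7, 9], state)  # prefill with a prompt
next_out, state = step([11], state)         # then feed one new token
```

The second call only processes the single new token, yet it still sees the three cached positions from the first call, which is the point of keeping the key/value cache in the returned state.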
- classmethod from_uncached(uncached: model_core.GemmaTransformer, cache_len: int, batch_axes: dict[str, int]) → tuple[GemmaKVCachingTransformer, GemmaKVCachingState]
Transforms a GemmaTransformer into cached sampling mode.
This constructor hot-swaps all model_core.GemmaAttention layers in the original model to enable key-value caching, then installs new handlers to update their states appropriately. Note that any modifications to the uncached model will persist in the decoding mode.
- Parameters:
uncached – The original GemmaTransformer model.
cache_len – Maximum sequence length for the key/value caches.
batch_axes – Names and sizes for the batch axes that will be used for sampling. Required for initializing the key/value caches.
- Returns:
Tuple (sampler_model, initial_sampling_state), where sampler_model is a GemmaKVCachingTransformer, and initial_sampling_state holds the initial empty key/value caches.
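As a rough sketch of how the documented signature might be called, the helper below wraps an already-loaded GemmaTransformer. The helper name `to_sampling_mode`, its default values, and the batch axis name "batch" are assumptions made for illustration; only the module path and the `from_uncached` parameter names come from this page. Loading the uncached model itself is not shown.

```python
def to_sampling_mode(uncached_model, cache_len=64, batch_size=1):
    """Convert an already-loaded GemmaTransformer to cached sampling mode.

    Hypothetical convenience wrapper; defaults are arbitrary.
    """
    # Imported lazily so the helper can be defined without penzai installed.
    from penzai.deprecated.v1.example_models.gemma import sampling_mode

    sampler_model, initial_state = (
        sampling_mode.GemmaKVCachingTransformer.from_uncached(
            uncached_model,
            cache_len=cache_len,  # upper bound on prompt plus generated tokens
            batch_axes={"batch": batch_size},  # axis name is an assumption
        )
    )
    return sampler_model, initial_state
```

Because both values are returned together, the caller keeps the initial state alongside the sampler and passes it in with the first set of inputs.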