GemmaKVCachingTransformer
- class penzai.deprecated.v1.example_models.gemma.sampling_mode.GemmaKVCachingTransformer
Bases: Layer
Top-level Gemma transformer in cached autoregressive sampling mode.
This class represents the sampling mode of the Gemma model, and is designed to be loaded from an existing GemmaTransformer. If you want to load this from the pretrained checkpoint, first load a GemmaTransformer, then call GemmaKVCachingTransformer.from_uncached.
- Variables:
config (model_core.GemmaTransformerConfig) – The configuration for the transformer.
body (pz.LayerLike) – The implementation of the transformer. Usually a nested set of state and side-effect handlers wrapping the main sequence of transformer blocks, but may be modified after the model is loaded due to patching.
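For example, the intended loading workflow looks roughly like the sketch below. The checkpoint-loading call and the from_pretrained constructor name are assumptions for illustration, and the cache length and batch axes are arbitrary:

    import orbax.checkpoint
    from penzai.deprecated.v1.example_models.gemma import model_core, sampling_mode

    # Restore pretrained weights (assumption: an orbax checkpoint of Gemma
    # parameters; any mapping of parameter names to arrays would do here).
    ckpt_params = orbax.checkpoint.PyTreeCheckpointer().restore("/path/to/gemma/ckpt")

    # Build the ordinary (uncached) transformer first...
    uncached = model_core.GemmaTransformer.from_pretrained(ckpt_params)

    # ...then hot-swap it into cached autoregressive sampling mode.
    sampler, initial_state = sampling_mode.GemmaKVCachingTransformer.from_uncached(
        uncached, cache_len=1024, batch_axes={"batch": 4},
    )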
Methods
__init__(config, body)
from_uncached(uncached, cache_len, batch_axes) – Transforms a GemmaTransformer into cached sampling mode.
input_structure()
output_structure()
__call__(inputs) – Processes a new subsequence of tokens and adds them to the K/V cache.
Attributes
config
body
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(inputs: GemmaKVCachingInputs) → tuple[pz.nx.NamedArray, GemmaKVCachingState]
Processes a new subsequence of tokens and adds them to the K/V cache.
- Parameters:
inputs – Structure of input arguments, containing tokens, segment positions, an attention mask, and the current sampling state.
- Returns:
A tuple (outputs, new_sampling_state), where outputs is the final matrix of logits from the embedding decoding layer, which (in the normal configuration) will have axes “seq” and “vocabulary”, and new_sampling_state is the updated sampling state, including the updated key/value caches.
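A minimal sketch of a single call, assuming sampler and initial_state come from from_uncached (see below), and assuming a GemmaKVCachingInputs.from_basic_subsegments convenience constructor that fills in segment positions and a causal attention mask for a simple left-aligned batch:

    import jax.numpy as jnp
    from penzai.deprecated.v1 import pz
    from penzai.deprecated.v1.example_models.gemma import sampling_mode

    # A batch of 4 prompts of 8 tokens each, as a named array.
    prompt = pz.nx.wrap(jnp.zeros((4, 8), dtype=jnp.int32)).tag("batch", "seq")

    # Build the input structure; positions and the attention mask are assumed
    # to be derived from the current sampling state by the helper.
    inputs = sampling_mode.GemmaKVCachingInputs.from_basic_subsegments(
        prompt, initial_state
    )

    logits, new_state = sampler(inputs)
    # `logits` has "seq" and "vocabulary" axes (plus the batch axes);
    # `new_state` holds the key/value caches updated with the prompt tokens.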
- classmethod from_uncached(uncached: model_core.GemmaTransformer, cache_len: int, batch_axes: dict[str, int]) → tuple[GemmaKVCachingTransformer, GemmaKVCachingState]
Transforms a GemmaTransformer into cached sampling mode.
This constructor hot-swaps all model_core.GemmaAttention layers in the original model to enable key-value caching, then installs new handlers to update their states appropriately. Note that any modifications to the uncached model will persist in the decoding mode.
- Parameters:
uncached – The original GemmaTransformer model.
cache_len – Maximum sequence length for the key/value caches.
batch_axes – Names and sizes for the batch axes that will be used for sampling. Required for initializing the key/value caches.
- Returns:
A tuple (sampler_model, initial_sampling_state), where sampler_model is a GemmaKVCachingTransformer, and initial_sampling_state holds the initial empty key/value caches.
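Putting the pieces together, a greedy token-by-token decoding loop might look like the sketch below. It reuses prompt, sampler, and initial_state from the examples above; dictionary-style indexing on named arrays and the from_basic_subsegments helper are assumptions rather than guaranteed API.

    import jax.numpy as jnp
    from penzai.deprecated.v1 import pz
    from penzai.deprecated.v1.example_models.gemma import sampling_mode

    # Prefill the caches with the prompt (as in the __call__ example above).
    inputs = sampling_mode.GemmaKVCachingInputs.from_basic_subsegments(prompt, initial_state)
    logits, state = sampler(inputs)

    sampled = []
    for _ in range(16):
        # Greedily pick the next token from the logits at the last position.
        last = logits[{"seq": -1}]
        next_token = pz.nx.nmap(jnp.argmax)(last.untag("vocabulary"))
        # Re-attach a length-1 "seq" axis so the token can be fed back in.
        next_token = pz.nx.nmap(lambda t: t[None])(next_token).tag("seq")
        sampled.append(next_token)
        # Run one more step, threading the updated key/value cache state.
        inputs = sampling_mode.GemmaKVCachingInputs.from_basic_subsegments(next_token, state)
        logits, state = sampler(inputs)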