KVCachingTransformerLM
- class penzai.models.transformer.sampling_mode.KVCachingTransformerLM
Bases: Layer
Top-level transformer in (stateful) cached autoregressive sampling mode.
This class represents the sampling mode of the model and manages the sampling state. It is designed to be loaded from an existing TransformerLM. To load it from a pretrained checkpoint, first load a TransformerLM, then call KVCachingTransformerLM.from_uncached.
This class tracks token positions and automatically increments them based on the tokens it has generated so far.
- Variables:
body (pz.nn.Layer) – The implementation of the transformer. Usually a nested set of state and side-effect handlers wrapping the main sequence of transformer blocks, but may be modified after the model is loaded due to patching.
cache_end_index (pz.StateVariable[int]) – A variable containing the current end index of the caches.
previous_tokens (pz.StateVariable[pz.nx.NamedArray]) – A variable containing all previously-seen tokens.
metadata (model_parts.TransformerMetadata) – The configuration for the transformer.
cache_len (int) – The length of the internal key-value caches.
batch_axes (dict[str, int]) – The batch axes of the internal key-value caches.
pad_id (int) – Token ID that indicates padding.
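The state variables above can be pictured with a small plain-Python model. This is a hypothetical sketch, not penzai code. `init_toy_sampler_state` and its dict layout are illustrative only. The sketch assumes that unwritten slots in the token buffer hold `pad_id` and that the buffer's sequence axis is named "seq". It shows how `cache_len` and `batch_axes` together fix the batch and sequence extents that `previous_tokens` (and, analogously, the key-value caches) must cover, and that `cache_end_index` starts at zero.

```python
import itertools


def init_toy_sampler_state(cache_len: int, batch_axes: dict[str, int], pad_id: int = 0) -> dict:
    """Builds a toy version of the sampler's mutable state (hypothetical helper)."""
    # One fixed-length token buffer per combination of batch indices: for
    # batch_axes={"batch": 2, "sample": 3} this gives 6 buffers keyed by
    # (batch, sample) index tuples. With no batch axes there is a single
    # buffer keyed by the empty tuple.
    batch_indices = itertools.product(*(range(size) for size in batch_axes.values()))
    previous_tokens = {index: [pad_id] * cache_len for index in batch_indices}
    return {
        # Nothing has been written to the caches yet.
        "cache_end_index": 0,
        # Unwritten slots hold pad_id (an assumption of this sketch).
        "previous_tokens": previous_tokens,
        # Named shape of the token buffer: batch axes first, then the cached
        # sequence axis (the name "seq" is also an assumption).
        "named_shape": {**batch_axes, "seq": cache_len},
    }


state = init_toy_sampler_state(cache_len=8, batch_axes={"batch": 2, "sample": 3})
print(state["named_shape"])           # {'batch': 2, 'sample': 3, 'seq': 8}
print(len(state["previous_tokens"]))  # 6
```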
Methods
__init__(body, cache_end_index, ...)
from_uncached(uncached, cache_len, batch_axes) – Transforms a TransformerLM into cached sampling mode.
__call__(tokens, **extra_side_inputs) – Processes a new subsequence of tokens and adds them to the K/V cache.
Attributes
body
cache_end_index
previous_tokens
metadata
cache_len
batch_axes
pad_id
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(tokens: pz.nx.NamedArray, **extra_side_inputs: dict[Any, Any]) → pz.nx.NamedArray
Processes a new subsequence of tokens and adds them to the K/V cache.
When called, the internal variables tracking the key-value cache will be updated with the new state.
- Parameters:
tokens – Array of token IDs, as an integer named array with a “seq” axis and possibly batch axes. The batch axes must match the batch_axes attribute. Padding tokens are ignored.
**extra_side_inputs – Extra side inputs, which will be forwarded to the body. The “token_positions”, “kv_token_positions”, and “cache_end_index” inputs are added automatically and do not need to be provided.
- Returns:
Named array of logits from the embedding decoding layer. In the normal configuration it has axes “seq” and “vocabulary”, in addition to any batch axes.
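A single call can be modeled in plain Python for one unbatched sequence. This is a hypothetical sketch: `toy_call` is not a penzai function, and it does not reproduce the real layer's internals. It makes three assumptions:
- a token's position counts the non-padding tokens before it;
- padding and not-yet-written slots get position -1;
- the `cache_end_index` side input is the write offset before the update.

The sketch shows how the three automatically added side inputs could be derived from `previous_tokens` and `cache_end_index`, and how both state variables advance.

```python
def toy_call(previous_tokens: list[int], cache_end_index: int, tokens: list[int], pad_id: int = 0):
    """Hypothetical single-sequence model of one cached sampling step.

    Returns the updated token buffer, the new end index, and the side inputs
    that would be forwarded to the body.
    """
    cache_len = len(previous_tokens)
    end = cache_end_index + len(tokens)
    if end > cache_len:
        # This sketch fails loudly; how the real layer handles overflow is not
        # described in this documentation.
        raise ValueError("K/V cache is full")
    # Write the new subsequence into the fixed-length buffer at cache_end_index.
    updated = previous_tokens[:cache_end_index] + list(tokens) + previous_tokens[end:]
    # Assumption: positions count only non-padding tokens; padding and
    # unwritten slots get -1 so they can be ignored.
    kv_token_positions = []
    count = 0
    for i, tok in enumerate(updated):
        if i < end and tok != pad_id:
            kv_token_positions.append(count)
            count += 1
        else:
            kv_token_positions.append(-1)
    # The queries are the slice of the buffer that was just written.
    token_positions = kv_token_positions[cache_end_index:end]
    side_inputs = {
        "token_positions": token_positions,
        "kv_token_positions": kv_token_positions,
        "cache_end_index": cache_end_index,  # write offset (assumed semantics)
    }
    return updated, end, side_inputs


buffer = [0] * 6
# Prefill with a left-padded prompt; pad_id is 0.
buffer, end, side = toy_call(buffer, 0, [0, 7, 8])
print(side["token_positions"])  # [-1, 0, 1]
# Decode one more token: its position continues from the prompt.
buffer, end, side = toy_call(buffer, end, [9])
print(side["token_positions"], side["cache_end_index"], end)  # [2] 3 4
```

The caller passes only new tokens. Positions and the cache offset come from the stored state, which is the bookkeeping the class description says is handled automatically.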
- classmethod from_uncached(uncached: model_parts.TransformerLM, cache_len: int, batch_axes: dict[str, int], pad_id: int = 0, variable_name_prefix: str = 'sampler') → KVCachingTransformerLM
Transforms a TransformerLM into cached sampling mode.
This constructor hot-swaps all pz.nn.Attention layers in the original model to enable key-value caching, then installs new handlers to update their states appropriately. Note that any modifications to the uncached model will persist in the decoding mode.
- Parameters:
uncached – The original TransformerLM model.
cache_len – Maximum sequence length for the key/value caches.
batch_axes – Names and sizes for the batch axes that will be used for sampling. Required for initializing the key/value caches.
pad_id – ID for the padding token.
variable_name_prefix – Prefix for cached sampling variable names.
- Returns:
A KVCachingTransformerLM.
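The hot-swap step of from_uncached is a tree rewrite. The following plain-Python sketch illustrates the idea with hypothetical `Attention`, `Dense`, `Sequential`, and `CachingAttention` classes, none of which are penzai's. It does not use penzai's selection API, and it omits the state-updating handlers that from_uncached also installs. Because the rewrite replaces only attention nodes and reuses every other node, it also shows one way that modifications made to the uncached model carry over into sampling mode.

```python
import dataclasses


# Hypothetical stand-ins for layer types; these are not penzai classes.
@dataclasses.dataclass(frozen=True)
class Attention:
    name: str


@dataclasses.dataclass(frozen=True)
class Dense:
    name: str


@dataclasses.dataclass(frozen=True)
class Sequential:
    sublayers: tuple


@dataclasses.dataclass(frozen=True)
class CachingAttention:
    # Wraps the original attention layer and records the cache configuration.
    wrapped: Attention
    cache_len: int
    batch_axes: tuple  # (name, size) pairs, stored as a tuple so the frozen layer stays hashable


def swap_attention(layer, cache_len: int, batch_axes: dict[str, int]):
    """Returns a rebuilt tree in which every Attention is wrapped for caching."""
    if isinstance(layer, Attention):
        return CachingAttention(layer, cache_len, tuple(batch_axes.items()))
    if isinstance(layer, Sequential):
        return Sequential(
            tuple(swap_attention(sub, cache_len, batch_axes) for sub in layer.sublayers)
        )
    # Every other layer is reused unchanged, so edits previously made to the
    # uncached model (such as patched-in layers) carry over.
    return layer


model = Sequential((Attention("attn_0"), Dense("mlp_0"), Sequential((Attention("attn_1"),))))
cached = swap_attention(model, cache_len=16, batch_axes={"batch": 2})
print(cached.sublayers[0])
```

The original tree is left untouched. This is analogous to building a new sampling-mode model from an existing TransformerLM rather than mutating it in place.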