KVCachingTransformerLM
- class penzai.models.transformer.sampling_mode.KVCachingTransformerLM
Bases: Layer
Top-level transformer in (stateful) cached autoregressive sampling mode.
This class represents the sampling mode of the model and manages the sampling state. It is designed to be loaded from an existing TransformerLM: to load it from a pretrained checkpoint, first load a TransformerLM, then call KVCachingTransformerLM.from_uncached.
This class tracks token positions and advances them automatically based on the tokens it has processed so far.
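The position bookkeeping described above can be pictured with a toy, pure-Python model. This is a minimal sketch under stated assumptions: `ToySamplerState` and `advance` are hypothetical names, not penzai APIs, and the sketch ignores padding and the transformer computation entirely. It only shows how a single end-of-cache counter lets each call assign consecutive positions to its new tokens.

```python
from dataclasses import dataclass


@dataclass
class ToySamplerState:
    """Hypothetical stand-in for the sampler's cache bookkeeping (not penzai API)."""

    cache_len: int
    cache_end_index: int = 0

    def advance(self, num_new_tokens: int) -> list[int]:
        """Returns positions for the new tokens and moves the cache end forward."""
        start = self.cache_end_index
        end = start + num_new_tokens
        if end > self.cache_len:
            raise ValueError(
                f"cache overflow: need {end} slots, cache_len is {self.cache_len}"
            )
        # New tokens occupy the slots right after everything seen so far.
        positions = list(range(start, end))
        self.cache_end_index = end
        return positions


state = ToySamplerState(cache_len=8)
print(state.advance(3))  # prompt of 3 tokens -> [0, 1, 2]
print(state.advance(1))  # one sampled token -> [3]
print(state.cache_end_index)  # 4
```

Because the counter lives in the state rather than being passed in by the caller, repeated calls (a prompt, then one token at a time) never need to be told where they are in the sequence.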
- Variables:
body (pz.nn.Layer) – The implementation of the transformer. Usually a nested set of state and side-effect handlers wrapping the main sequence of transformer blocks, but it may be modified by patching after the model is loaded.
cache_end_index (pz.StateVariable[int]) – A variable containing the current end index of the caches.
previous_tokens (pz.StateVariable[pz.nx.NamedArray]) – A variable containing all previously-seen tokens.
metadata (model_parts.TransformerMetadata) – The configuration for the transformer.
cache_len (int) – The length of the internal key-value caches.
batch_axes (dict[str, int]) – The batch axes of the internal key-value caches.
pad_id (int) – Token ID that indicates padding.
Methods
__init__(body, cache_end_index, ...)
from_uncached(uncached, cache_len, batch_axes) – Transforms a TransformerLM into cached sampling mode.
__call__(tokens, **extra_side_inputs) – Processes a new subsequence of tokens and adds them to the K/V cache.
Attributes
body
cache_end_index
previous_tokens
metadata
cache_len
batch_axes
pad_id
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(tokens: pz.nx.NamedArray, **extra_side_inputs: dict[Any, Any]) → pz.nx.NamedArray
Processes a new subsequence of tokens and adds them to the K/V cache.
When called, the internal variables tracking the key-value cache will be updated with the new state.
- Parameters:
tokens – Array of token IDs, as an integer named array with a “seq” axis and possibly batch axes. The batch axes must match the batch_axes attribute. Padding tokens are ignored.
**extra_side_inputs – Extra side inputs, which will be forwarded on to the body. The “token_positions”, “kv_token_positions”, and “cache_end_index” inputs are added automatically and do not need to be provided.
- Returns:
Named array of logits from the embedding decoding layer. In the normal configuration it has axes “seq” and “vocabulary”, along with any batch axes.
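The docstring says padding tokens are ignored but does not spell out the mechanism. One common approach, assumed here and not confirmed as penzai's exact rule, is to exclude padding positions from the attention keys and to count positions over non-padding tokens only. The helpers `token_positions` and `attention_mask` below are hypothetical, and marking padding with position -1 is an arbitrary convention chosen for this sketch.

```python
PAD_ID = 0  # matches the default pad_id of from_uncached


def token_positions(tokens: list[int], pad_id: int = PAD_ID) -> list[int]:
    """Position of each token, counting only non-padding tokens (pads get -1)."""
    positions = []
    count = 0
    for tok in tokens:
        if tok == pad_id:
            positions.append(-1)  # hypothetical marker: padding has no position
        else:
            positions.append(count)
            count += 1
    return positions


def attention_mask(tokens: list[int], pad_id: int = PAD_ID) -> list[list[bool]]:
    """mask[q][k] is True when query q may attend to key k (causal, no pads)."""
    n = len(tokens)
    return [
        [k <= q and tokens[k] != pad_id for k in range(n)]
        for q in range(n)
    ]


# A left-padded prompt: two pads, then three real tokens.
prompt = [0, 0, 17, 42, 5]
print(token_positions(prompt))  # [-1, -1, 0, 1, 2]
for row in attention_mask(prompt):
    print([int(allowed) for allowed in row])
```

With this scheme the first real token gets position 0 regardless of how much left padding precedes it, so padded and unpadded prompts see the same positions.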
- classmethod from_uncached(uncached: model_parts.TransformerLM, cache_len: int, batch_axes: dict[str, int], pad_id: int = 0, variable_name_prefix: str = 'sampler') → KVCachingTransformerLM
Transforms a TransformerLM into cached sampling mode.
This constructor hot-swaps all pz.nn.Attention layers in the original model to enable key-value caching, then installs new handlers that update their states appropriately. Note that any modifications already made to the uncached model carry over to the decoding mode.
- Parameters:
uncached – The original TransformerLM model.
cache_len – Maximum sequence length for the key/value caches.
batch_axes – Names and sizes of the batch axes that will be used for sampling. Required to initialize the key/value caches.
pad_id – ID of the padding token.
variable_name_prefix – Prefix for the names of the cached-sampling state variables.
- Returns:
A KVCachingTransformerLM.
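from_uncached enables caching by swapping the attention layers so that each step reuses keys and values stored by earlier steps instead of recomputing them. The pure-Python sketch below illustrates why this is safe for causal attention: attending from each new token over a preallocated buffer of length `cache_len`, filled up to an end index, gives the same outputs as recomputing causal attention over the whole prefix. `ToyKVCache`, `attend`, and the single-head, unbatched setup are hypothetical simplifications; penzai's cached attention layers operate on named arrays with batch axes and are not reproduced here.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    return [
        sum(w * value[i] for w, value in zip(weights, values))
        for i in range(len(values[0]))
    ]


class ToyKVCache:
    """Hypothetical fixed-size key/value buffer (illustration only, not penzai API)."""

    def __init__(self, cache_len, dim):
        # Preallocate every slot up front, so the buffer has fixed length cache_len.
        self.keys = [[0.0] * dim for _ in range(cache_len)]
        self.values = [[0.0] * dim for _ in range(cache_len)]
        self.cache_end_index = 0

    def append_and_attend(self, query, key, value):
        """Writes one new key/value pair, then attends over the filled prefix."""
        if self.cache_end_index >= len(self.keys):
            raise ValueError("cache is full; increase cache_len")
        self.keys[self.cache_end_index] = key
        self.values[self.cache_end_index] = value
        self.cache_end_index += 1
        end = self.cache_end_index
        # Slots at or beyond `end` are still placeholders and must be excluded.
        return attend(query, self.keys[:end], self.values[:end])


def rand_vectors(n, dim):
    return [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]


random.seed(0)
DIM, STEPS = 4, 5
qs, ks, vs = rand_vectors(STEPS, DIM), rand_vectors(STEPS, DIM), rand_vectors(STEPS, DIM)

# Uncached: recompute causal attention over the full prefix at every step.
uncached_out = [attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(STEPS)]

# Cached: feed one token per step and reuse the stored keys/values.
cache = ToyKVCache(cache_len=8, dim=DIM)
cached_out = [cache.append_and_attend(qs[t], ks[t], vs[t]) for t in range(STEPS)]

assert all(
    math.isclose(a, b)
    for row_a, row_b in zip(uncached_out, cached_out)
    for a, b in zip(row_a, row_b)
)
print("cached outputs match uncached causal attention")
```

In this sketch the buffer is fixed-size, so the total number of tokens processed is bounded by `cache_len`, mirroring the documented role of `cache_len` as the maximum sequence length for the key/value caches.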