KVCachingTransformerLM#

class penzai.models.transformer.sampling_mode.KVCachingTransformerLM[source]#

Bases: Layer

Top-level transformer in (stateful) cached autoregressive sampling mode.

This class represents the sampling mode of the model and manages the sampling state. It is designed to be built from an existing TransformerLM. If you want to load this from a pretrained checkpoint, first load a TransformerLM, then call KVCachingTransformerLM.from_uncached.

This class tracks token positions and automatically increments them based on the tokens it has processed so far.
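As a rough mental model, the stateful bookkeeping described above can be sketched in plain Python. This is an illustrative analogue, not Penzai's implementation; the class and method names below are hypothetical stand-ins for the documented variables (cache_len, cache_end_index, previous_tokens):

```python
# Illustrative sketch (NOT Penzai's actual implementation): a minimal
# stateful K/V-cache bookkeeper mirroring the variables documented above:
# a fixed cache capacity, an end index, and the tokens seen so far.
class ToyCacheState:
    def __init__(self, cache_len: int):
        self.cache_len = cache_len            # fixed cache capacity
        self.cache_end_index = 0              # analogue of pz.StateVariable[int]
        self.previous_tokens: list[int] = []  # analogue of previous_tokens

    def extend(self, tokens: list[int]) -> list[int]:
        """Append a new chunk of tokens and return their absolute positions."""
        start = self.cache_end_index
        positions = list(range(start, start + len(tokens)))
        assert positions[-1] < self.cache_len, "cache overflow"
        self.previous_tokens.extend(tokens)
        self.cache_end_index += len(tokens)   # auto-increment on every call
        return positions
```

Because the end index advances automatically on every call, callers of the real class do not need to pass token positions or the cache end index by hand.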

Variables:
  • body (pz.nn.Layer) – The implementation of the transformer. Usually a nested set of state and side-effect handlers wrapping the main sequence of transformer blocks, but may be modified after the model is loaded due to patching.

  • cache_end_index (pz.StateVariable[int]) – A variable containing the current end index of the caches.

  • previous_tokens (pz.StateVariable[pz.nx.NamedArray]) – A variable containing all previously-seen tokens.

  • metadata (model_parts.TransformerMetadata) – The configuration for the transformer.

  • cache_len (int) – The length of the internal key-value caches.

  • batch_axes (dict[str, int]) – The batch axes of the internal key-value caches.

  • pad_id (int) – Token ID that indicates padding.

Methods

__init__(body, cache_end_index, ...)

from_uncached(uncached, cache_len, batch_axes)

Transforms a Transformer into cached sampling mode.

__call__(tokens, **extra_side_inputs)

Processes a new subsequence of tokens and adds them to the K/V cache.

Attributes

body

cache_end_index

previous_tokens

metadata

cache_len

batch_axes

pad_id

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(tokens: pz.nx.NamedArray, **extra_side_inputs: dict[Any, Any]) pz.nx.NamedArray[source]#

Processes a new subsequence of tokens and adds them to the K/V cache.

When called, the internal variables tracking the key/value cache, the cache end index, and the previously-seen tokens are updated to reflect the new tokens.

Parameters:
  • tokens – Array of token IDs, as an integer named array with a “seq” axis and possibly batch axes. The batch axes must match the batch_axes attribute. Padding tokens are ignored.

  • **extra_side_inputs – Extra side inputs, which will be forwarded on to the body. The “token_positions”, “kv_token_positions”, and “cache_end_index” inputs will be added automatically and do not need to be provided.

Returns:

Named array of logits from the embedding decoding layer, which (in the normal configuration) will have axes “seq” and “vocabulary”.
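Because __call__ consumes a chunk of tokens and returns logits for that chunk, a typical greedy sampling loop feeds the prompt once, then one token at a time. The following self-contained sketch uses a deterministic stand-in for the model; toy_logits, greedy_decode, and all other names here are hypothetical, not part of Penzai:

```python
# Illustrative greedy decoding loop against a stand-in for
# KVCachingTransformerLM.__call__. In real use, the model call would be
# the caching transformer applied to a named array of token IDs; here
# `toy_logits` is a dummy that returns per-position logits over a small
# vocabulary, predicting (last token + 1) mod vocab_size.
def toy_logits(tokens: list[int], vocab_size: int = 4) -> list[list[float]]:
    out = []
    for t in tokens:
        row = [0.0] * vocab_size
        row[(t + 1) % vocab_size] = 1.0
        out.append(row)
    return out

def greedy_decode(prompt: list[int], steps: int) -> list[int]:
    generated = list(prompt)
    chunk = prompt                       # first call: the whole prompt
    for _ in range(steps):
        logits = toy_logits(chunk)       # real code: the caching model call
        last = logits[-1]                # logits at the final "seq" position
        next_tok = max(range(len(last)), key=last.__getitem__)
        generated.append(next_tok)
        chunk = [next_tok]               # subsequent calls: one token at a time
    return generated
```

Feeding only the newest token on each later call is what the K/V cache enables: earlier keys and values are already stored, so they do not need to be recomputed.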

classmethod from_uncached(uncached: model_parts.TransformerLM, cache_len: int, batch_axes: dict[str, int], pad_id: int = 0, variable_name_prefix: str = 'sampler') KVCachingTransformerLM[source]#

Transforms a Transformer into cached sampling mode.

This constructor hot-swaps all pz.nn.Attention layers in the original model to enable key-value caching, then installs new handlers to update their states appropriately. Note that any modifications made to the uncached model will carry over to the sampling-mode model.

Parameters:
  • uncached – The original TransformerLM model.

  • cache_len – Maximum sequence length for the key/value caches.

  • batch_axes – Names and sizes for the batch axes that will be used for sampling. Required for initializing the key/value caches.

  • pad_id – ID for the padding token.

  • variable_name_prefix – Prefix for cached sampling variable names.

Returns:

A KVCachingTransformerLM.
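The “hot-swap” idea behind this constructor can be illustrated with a small self-contained sketch: walk a nested model structure, replace every attention-like layer with a caching wrapper, and leave everything else untouched. This is not Penzai's implementation, and all classes below are hypothetical stand-ins:

```python
# Illustrative sketch of the layer hot-swap performed by from_uncached
# (NOT Penzai's implementation). `Attention` and `CachingAttention` are
# hypothetical stand-ins for the real layer types.
from dataclasses import dataclass

@dataclass
class Attention:
    name: str

@dataclass
class CachingAttention:
    name: str
    cache_len: int

def to_cached(layer, cache_len: int):
    """Recursively replace attention layers with caching equivalents."""
    if isinstance(layer, Attention):
        return CachingAttention(layer.name, cache_len)
    if isinstance(layer, list):
        return [to_cached(child, cache_len) for child in layer]
    return layer  # non-attention layers pass through unchanged
```

Because only the attention layers are swapped, any other changes made to the uncached model (for example, patched blocks) survive the conversion, matching the note above.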