KVCachingAttention

class penzai.nn.attention.KVCachingAttention

Bases: Layer

Key/value caching variant of Attention.

KVCachingAttention is a drop-in replacement for Attention, but adds key/value caching logic using Penzai’s effect system. This means that a model initially configured for training can be quickly adapted to do inference without making the training logic more complicated.
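
For example, a model trained with ordinary Attention blocks can be converted for sampling by swapping each block for its caching variant via from_uncached (documented below). The following is a minimal sketch of that pattern, not taken from this page: the model, axis names, and sizes are hypothetical.

    import jax.numpy as jnp
    from penzai import pz

    # Hypothetical: `trained_model` contains pz.nn.Attention blocks;
    # the axis names and sizes below are made up for illustration.
    cached_model = (
        pz.select(trained_model)
        .at_instances_of(pz.nn.Attention)
        .apply(lambda attn: pz.nn.KVCachingAttention.from_uncached(
            attn,
            sequence_axis="seq",
            cache_len=1024,
            cached_axes={"batch": 1, "heads": 8, "projection": 64},
            cache_end_index_tag="cache_end_index",
            state_category="kv_cache",
            cache_dtype=jnp.float32,
        ))
    )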

Variables:
  • input_to_query (layer_base.LayerLike) – A layer that maps the input to an array of queries, usually taken from the original Attention layer.

  • input_to_key (layer_base.LayerLike) – A layer that maps the input to an array of keys, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.

  • input_to_value (layer_base.LayerLike) – A layer that maps the input to an array of values, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.

  • query_key_to_attn (layer_base.LayerLike) – A layer that maps a tuple of (queries, keys) to attention weights, usually taken from the original Attention layer. The key input will contain the full key cache, rather than the slice produced for the current token.

  • attn_value_to_output (layer_base.LayerLike) – A layer that maps a tuple of (attention weights, values) to a final output, usually taken from the original Attention layer. The value input will contain the full value cache, rather than the slice produced for the current token.

  • sequence_axis (str) – The axis along which to do key/value caching. Should be an axis name that appears in the output of the input_to_key and input_to_value sublayers.

  • kv_cache_end_index (side_input.SideInputEffect[jax.Array]) – A side input that identifies the current dynamic size of the key/value caches, i.e. the number of elements that have been populated with entries. Should be populated by a scalar integer array.

  • kv_cache (local_state.LocalStateEffect[tuple[named_axes.NamedArray, named_axes.NamedArray]]) – A state effect variable that stores a tuple of key and value caches. This will be initialized when this layer is constructed, and will be updated as it runs.
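
As an illustration of the cache structure (not from this page), the initial state is a tuple of zero-filled key and value NamedArrays whose named axes are the sequence axis plus the other cached axes; the axis names and sizes here are hypothetical:

    import jax.numpy as jnp
    from penzai import pz

    # Hypothetical sizes; the real ones come from the `cache_len` and
    # `cached_axes` arguments to `from_uncached` below.
    empty = pz.nx.wrap(
        jnp.zeros((1, 8, 64, 1024), dtype=jnp.float32)
    ).tag("batch", "heads", "projection", "seq")
    initial_kv_cache = (empty, empty)  # (key cache, value cache)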

Methods

__init__(input_to_query, input_to_key, ...)

from_uncached(original, sequence_axis, ...)

Builds a caching attention from an uncached attention.

__call__(x)

Runs the caching attention computation and updates the K/V cache state.

Attributes

input_to_query

input_to_key

input_to_value

query_key_to_attn

attn_value_to_output

sequence_axis

kv_cache_end_index

kv_cache

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(x: named_axes.NamedArray) → named_axes.NamedArray

Runs the caching attention computation and updates the K/V cache state.

When called, self.kv_cache_end_index should be filled with a scalar integer identifying the current size of the cache (before inserting this token), and self.kv_cache should be a LocalState that contains the current state.

Parameters:

x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.

Returns:

The final output of the attn_value_to_output sublayer.
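
Conceptually, the update performed here resembles the following plain-JAX sketch over positional arrays. This is not Penzai source code: the real layer works with NamedArrays and obtains the end index and caches through its effect requests, and the sublayer arguments stand in for the attributes listed above.

    import jax

    def caching_attention_step(x, k_cache, v_cache, end_index,
                               to_q, to_k, to_v, qk_to_attn, av_to_out):
        # Project the current token(s) to queries, keys, and values.
        q, new_k, new_v = to_q(x), to_k(x), to_v(x)
        # Write the new keys/values into the caches at `end_index`,
        # along the sequence axis (axis 0 in this sketch).
        k_cache = jax.lax.dynamic_update_slice_in_dim(
            k_cache, new_k, end_index, axis=0)
        v_cache = jax.lax.dynamic_update_slice_in_dim(
            v_cache, new_v, end_index, axis=0)
        # Attention runs against the full caches, not just the slice
        # produced for the current token.
        out = av_to_out((qk_to_attn((q, k_cache)), v_cache))
        return out, (k_cache, v_cache)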

classmethod from_uncached(original: Attention, sequence_axis: str, cache_len: int, cached_axes: dict[str, int], cache_end_index_tag: side_input.Tag, state_category: local_state.Category, cache_dtype: jax.typing.DTypeLike = jax.numpy.float32) → KVCachingAttention

Builds a caching attention from an uncached attention.

Parameters:
  • original – The original attention layer that this block should replace.

  • sequence_axis – The axis along which keys and values should be cached. Should be present in the output of the input_to_key and input_to_value sublayers.

  • cache_len – Length of the cache; used to populate the initial state.

  • cached_axes – Axis names and sizes for all other axes of the key and value arrays (e.g. for batch, heads, and the projected embeddings). These are used to initialize the cache.

  • cache_end_index_tag – Side input tag for the cache position side input. This should be used to identify the side inputs that should receive the cache position information, and should (usually) be provided to the pz.de.WithSideInputsFromInputTuple handler that actually provides this side input.

  • state_category – Category for the local state. This should be used to identify the state variables that correspond to key-value caches in the model, and should (usually) be provided to the pz.de.handle_local_states call that functionalizes the state effect.

  • cache_dtype – Dtype for the data to store in the cache. Should match the dtype of the key and value arrays.

Returns:

A KVCachingAttention instance that behaves like the original Attention layer, but updates key-value caches iteratively, using new side input and state effect requests.
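
A converted model still carries unhandled side-input and state effect requests, which must be discharged by the handlers named above before the model can be called. The sketch below is an assumption about the call conventions, with only the handler names taken from this page; consult the data-effects documentation for the exact signatures.

    # Assumed call conventions: `cached_model` is the converted model from
    # the sketch at the top of this page, `token` a NamedArray input, and
    # `position` a scalar integer array for the cache end index.
    inference_model, initial_state = pz.de.handle_local_states(
        pz.de.WithSideInputsFromInputTuple.handling(
            cached_model, tags=["cache_end_index"]),
        category="kv_cache",
    )
    # The nesting of the input tuple here is also an assumption.
    output, new_state = inference_model(((token, position), initial_state))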