KVCachingAttention
- class penzai.nn.attention.KVCachingAttention
Bases: Layer
Key/value caching variant of Attention.
KVCachingAttention is a drop-in replacement for Attention, but adds key/value caching logic using Penzai’s effect system. This means that a model initially configured for training can be quickly adapted to do inference without making the training logic more complicated.
- Variables:
input_to_query (layer_base.Layer) – A layer that maps the input to an array of queries, usually taken from the original Attention layer.
input_to_key (layer_base.Layer) – A layer that maps the input to an array of keys, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.
input_to_value (layer_base.Layer) – A layer that maps the input to an array of values, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.
query_key_to_attn (layer_base.Layer) – A layer that maps a tuple of (queries, keys) to attention weights, usually taken from the original Attention layer. The key input will contain the full key cache, rather than the slice produced for the current token.
attn_value_to_output (layer_base.Layer) – A layer that maps a tuple of (attention weights, values) to a final output, usually taken from the original Attention layer. The value input will contain the full value cache, rather than the slice produced for the current token.
sequence_axis (str) – The axis along which to do key/value caching. Should be an axis name that appears in the output of the input_to_key and input_to_value sublayers.
kv_cache_end_index_key (Hashable) – The key for the side input that identifies the current dynamic size of the key/value caches, i.e. the number of elements that have been populated with entries. The corresponding side input should be a scalar integer array.
kv_cache (variables.StateVariable[tuple[named_axes.NamedArray, named_axes.NamedArray]]) – A state effect variable that stores a tuple of key and value caches. This will be initialized when this layer is constructed, and will be updated as it runs.
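Because query_key_to_attn receives the whole key cache, including slots that have not been written yet, any masking of those slots has to happen inside that sublayer. The NumPy sketch below illustrates one way this works; it is not Penzai's implementation. It drops named axes and heads, and it assumes a causal mask on absolute token positions, where the i-th new token sits at position end_index + i. Under that assumption, unpopulated cache slots automatically receive zero weight.

```python
import numpy as np


def cached_causal_attention_weights(
    query: np.ndarray,      # [num_new, d]: queries for the tokens in this call
    key_cache: np.ndarray,  # [cache_len, d]: full key cache (already updated)
    end_index: int,         # cache size *before* this call's keys were inserted
) -> np.ndarray:
    """Attention weights of the new queries over the full key cache.

    Assumes a causal mask on absolute positions: a query at position p may
    attend to cache slots 0..p. Slots at or beyond end_index + num_new have
    not been written yet and are excluded by that same mask.
    """
    num_new, d = query.shape
    cache_len = key_cache.shape[0]
    logits = query @ key_cache.T / np.sqrt(d)   # [num_new, cache_len]
    query_pos = end_index + np.arange(num_new)  # absolute query positions
    key_pos = np.arange(cache_len)              # absolute cache slot positions
    mask = key_pos[None, :] <= query_pos[:, None]
    logits = np.where(mask, logits, -np.inf)
    # Slot 0 is always visible, so each row max is finite.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
key_cache = rng.normal(size=(8, 4))  # cache_len = 8
query = rng.normal(size=(2, 4))      # two new tokens at positions 3 and 4
weights = cached_causal_attention_weights(query, key_cache, end_index=3)
```

Here slots 5 to 7 of the cache are still empty, and every query row assigns them zero weight, so the full-cache input behaves like attention over only the populated prefix.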
Methods
__init__(input_to_query, input_to_key, ...)
from_uncached(original, sequence_axis, ...) – Builds a caching attention from an uncached attention.
__call__(x, **side_inputs) – Runs the caching attention computation and updates the K/V cache state.
Attributes
input_to_query
input_to_key
input_to_value
query_key_to_attn
attn_value_to_output
sequence_axis
kv_cache_end_index_key
kv_cache
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray, **side_inputs: Any) → named_axes.NamedArray
Runs the caching attention computation and updates the K/V cache state.
When called, the side inputs should contain, at the key given by self.kv_cache_end_index_key, a scalar integer identifying the current size of the cache (before the keys and values for this input are inserted), and self.kv_cache should be a StateVariable that holds the current key and value caches.
- Parameters:
x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.
**side_inputs – Side inputs for all sublayers. Should contain the key/value cache end index at the key indicated by this layer’s kv_cache_end_index_key attribute.
- Returns:
The final output of the attn_value_to_output sublayer.
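The state update performed by __call__ amounts to writing this call's keys and values into the caches, starting at the position given by the end-index side input. The sketch below shows that step with plain NumPy under several assumptions: arrays are positional rather than named, the side-input key "cache_end_index" is a hypothetical name standing in for whatever kv_cache_end_index_key holds, and the function returns new arrays instead of assigning to a StateVariable. A JAX version would typically express the write with jax.lax.dynamic_update_slice; unlike that function, which clamps out-of-range start indices, this sketch raises on overflow.

```python
from __future__ import annotations

import numpy as np


def update_kv_cache(
    kv_cache: tuple[np.ndarray, np.ndarray],
    new_keys: np.ndarray,
    new_values: np.ndarray,
    end_index: int,
    seq_axis: int = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """Returns copies of the caches with new entries written at end_index."""
    key_cache, value_cache = kv_cache
    num_new = new_keys.shape[seq_axis]
    if end_index + num_new > key_cache.shape[seq_axis]:
        raise ValueError("KV cache overflow: increase cache_len")
    # Select [end_index, end_index + num_new) along the sequence axis only.
    index = [slice(None)] * key_cache.ndim
    index[seq_axis] = slice(end_index, end_index + num_new)
    index = tuple(index)
    key_cache = key_cache.copy()
    value_cache = value_cache.copy()
    key_cache[index] = new_keys
    value_cache[index] = new_values
    return key_cache, value_cache


# Hypothetical side-input key; two slots (0 and 1) are already populated.
side_inputs = {"cache_end_index": 2}
cache = (np.zeros((6, 3)), np.zeros((6, 3)))  # cache_len = 6, feature size 3
new_k = np.ones((1, 3))
new_v = np.full((1, 3), 2.0)
cache = update_kv_cache(cache, new_k, new_v, side_inputs["cache_end_index"])
# The caller advances the end index by the number of tokens inserted.
next_end_index = side_inputs["cache_end_index"] + new_k.shape[0]  # 3
```

Keeping the end index outside the layer, as a side input, lets the caller decide when and how far to advance it between calls.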
- classmethod from_uncached(original: Attention, sequence_axis: str, cache_len: int, cached_axes: dict[str, int], cache_end_index_key: Hashable, cache_dtype: jax.typing.DTypeLike = jax.numpy.float32, cache_label: variables.VariableLabel | None = None, layerstack_axes: dict[named_axes.AxisName, int] | None = None) → KVCachingAttention
Builds a caching attention from an uncached attention.
- Parameters:
original – The original attention layer that this block should replace.
sequence_axis – The axis along which keys and values should be cached. Should be present in the output of the input_to_key and input_to_value sublayers.
cache_len – Length of the cache along sequence_axis, i.e. the maximum number of entries it can hold; used to populate the initial state.
cached_axes – Axis names and sizes for all other axes of the key and value arrays (e.g. for batch, heads, and the projected embeddings). These are used to initialize the cache.
cache_end_index_key – Key to use for the cache position side input.
cache_dtype – Dtype for the data to store in the cache. Should match the dtype of the key and value arrays.
cache_label – Optional label for the KV cache variable.
layerstack_axes – Stacked axes that are used inside a LayerStack combinator. Usually inferred from pz.nn.layerstack_axes_from_keypath.
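Together, cache_len, cached_axes, cache_dtype, and layerstack_axes determine the shape and dtype of the initial cache state. The NumPy sketch below illustrates that bookkeeping by representing named axes as an ordered list of names alongside a positional array. The axis names, the ordering (stacked axes first, sequence axis last), and the zero fill are assumptions of this sketch, not guarantees about Penzai's implementation.

```python
from __future__ import annotations

import numpy as np


def init_kv_cache(
    sequence_axis: str,
    cache_len: int,
    cached_axes: dict[str, int],
    cache_dtype=np.float32,
    layerstack_axes: dict[str, int] | None = None,
) -> tuple[list[str], tuple[np.ndarray, np.ndarray]]:
    """Builds zero-filled key and value caches plus their axis names."""
    axes = dict(layerstack_axes or {})  # assumed: stacked axes come first
    axes.update(cached_axes)            # e.g. batch, heads, projection
    axes[sequence_axis] = cache_len     # assumed: sequence axis comes last
    shape = tuple(axes.values())
    key_cache = np.zeros(shape, dtype=cache_dtype)
    value_cache = np.zeros(shape, dtype=cache_dtype)
    return list(axes), (key_cache, value_cache)


# Illustrative axis names and sizes only.
axis_names, (keys, values) = init_kv_cache(
    sequence_axis="seq",
    cache_len=128,
    cached_axes={"batch": 2, "heads": 4, "projection": 16},
)
stacked_names, (stacked_keys, _) = init_kv_cache(
    "seq", 8, {"heads": 2}, layerstack_axes={"blocks": 3}
)
```

Because the cache has a fixed length from the start, its shape never changes between calls; only the end-index side input records how much of it is in use.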
- Returns:
A KVCachingAttention instance that behaves like the original Attention layer, but updates its key/value caches iteratively, using the cache end-index side input and the kv_cache state variable.
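The claim that the cached layer behaves like the original Attention layer can be checked numerically in a toy setting. The sketch below is not Penzai's API: it uses single-head NumPy attention, hypothetical projection matrices w_q, w_k, and w_v standing in for the input_to_query, input_to_key, and input_to_value sublayers, and a causal mask on absolute positions. It prefills a three-token prompt, then decodes one token per call, advancing end_index by the number of tokens inserted so that it always counts the populated cache slots.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def causal_attention(q, k, v, q_pos, k_pos):
    """Single-head attention; query i sees key j iff k_pos[j] <= q_pos[i]."""
    logits = q @ k.T / np.sqrt(q.shape[-1])
    logits = np.where(k_pos[None, :] <= q_pos[:, None], logits, -np.inf)
    return softmax(logits) @ v


rng = np.random.default_rng(0)
seq_len, d_model, d_head, cache_len = 6, 8, 4, 16
w_q, w_k, w_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
x = rng.normal(size=(seq_len, d_model))

# Uncached: attend over the whole sequence at once.
pos = np.arange(seq_len)
full_out = causal_attention(x @ w_q, x @ w_k, x @ w_v, pos, pos)

# Cached: prefill a 3-token prompt, then feed one token per call.
key_cache = np.zeros((cache_len, d_head))
value_cache = np.zeros((cache_len, d_head))
end_index = 0  # number of populated cache slots
outputs = []
for chunk in [x[:3], x[3:4], x[4:5], x[5:6]]:
    n = chunk.shape[0]
    # Write this call's keys/values into the cache at end_index.
    key_cache[end_index:end_index + n] = chunk @ w_k
    value_cache[end_index:end_index + n] = chunk @ w_v
    # Attend over the full cache; the causal mask hides unwritten slots.
    q_pos = end_index + np.arange(n)
    outputs.append(causal_attention(
        chunk @ w_q, key_cache, value_cache, q_pos, np.arange(cache_len)))
    end_index += n

cached_out = np.concatenate(outputs)
```

The cached outputs match the uncached ones because the masked, unwritten slots contribute exactly zero weight, so each query sees the same keys and values in both computations.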