KVCachingAttention
- class penzai.deprecated.v1.nn.attention.KVCachingAttention
Bases: Layer
Key/value caching variant of Attention.
KVCachingAttention is a drop-in replacement for Attention, but adds key/value caching logic using Penzai’s effect system. This means that a model initially configured for training can be quickly adapted to do inference without making the training logic more complicated.
- Variables:
input_to_query (layer_base.LayerLike) – A layer that maps the input to an array of queries, usually taken from the original Attention layer.
input_to_key (layer_base.LayerLike) – A layer that maps the input to an array of keys, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.
input_to_value (layer_base.LayerLike) – A layer that maps the input to an array of values, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.
query_key_to_attn (layer_base.LayerLike) – A layer that maps a tuple of (queries, keys) to attention weights, usually taken from the original Attention layer. The key input will contain the full key cache, rather than the slice produced for the current token.
attn_value_to_output (layer_base.LayerLike) – A layer that maps a tuple of (attention weights, values) to a final output, usually taken from the original Attention layer. The value input will contain the full value cache, rather than the slice produced for the current token.
sequence_axis (str) – The axis along which to do key/value caching. Should be an axis name that appears in the output of the input_to_key and input_to_value sublayers.
kv_cache_end_index (side_input.SideInputEffect[jax.Array]) – A side input that identifies the current dynamic size of the key/value caches, i.e. the number of positions that have been populated with entries. Should be populated by a scalar integer array.
kv_cache (local_state.LocalStateEffect[tuple[named_axes.NamedArray, named_axes.NamedArray]]) – A state effect variable that stores a tuple of key and value caches. This will be initialized when this layer is constructed, and will be updated as it runs.
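The roles of kv_cache and kv_cache_end_index amount to writing each step's keys and values into preallocated, fixed-length buffers at the current end index. The sketch below illustrates only that bookkeeping, under stated assumptions: it uses plain positional jax.numpy arrays with the sequence axis at position 0, whereas Penzai operates on NamedArrays through its effect system, and the helper name write_to_cache is hypothetical. The value cache would be updated the same way.

```python
import jax
import jax.numpy as jnp


def write_to_cache(cache, new_entries, end_index):
  """Hypothetical helper (not part of Penzai): writes `new_entries` into
  `cache` starting at `end_index` along axis 0 (the assumed sequence axis)."""
  # dynamic_update_slice_in_dim accepts a traced scalar start index, so the
  # cache keeps a fixed shape and the update works under jax.jit.
  return jax.lax.dynamic_update_slice_in_dim(cache, new_entries, end_index, axis=0)


# A key cache of length 8 with a feature axis of size 4; two slots populated.
key_cache = jnp.zeros((8, 4), dtype=jnp.float32)
key_cache = key_cache.at[:2].set(1.0)
end_index = jnp.array(2)  # number of populated positions before this step

new_keys = jnp.full((1, 4), 2.0, dtype=jnp.float32)  # keys for one new token
key_cache = write_to_cache(key_cache, new_keys, end_index)
end_index = end_index + new_keys.shape[0]  # the cache now holds 3 entries
```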
Methods
__init__(input_to_query, input_to_key, ...)
from_uncached(original, sequence_axis, ...) – Builds a caching attention from an uncached attention.
__call__(x) – Runs the caching attention computation and updates the K/V cache state.
Attributes
input_to_query
input_to_key
input_to_value
query_key_to_attn
attn_value_to_output
sequence_axis
kv_cache_end_index
kv_cache
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
input_structure() – Returns the input structure of this layer.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
output_structure() – Returns the output structure of this layer.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray) -> named_axes.NamedArray
Runs the caching attention computation and updates the K/V cache state.
When called, self.kv_cache_end_index should be filled with a scalar integer identifying the current size of the cache (before inserting this token), and self.kv_cache should be a LocalState that contains the current state.
- Parameters:
x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.
- Returns:
The final output of the attn_value_to_output sublayer.
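The __call__ contract (the end index is the cache size before the new token is inserted, and query_key_to_attn receives the full key cache) can be illustrated with a self-contained single-head sketch. This is an assumption-laden stand-in, not Penzai's implementation: it uses positional jax.numpy arrays instead of NamedArrays, passes the caches and end index explicitly instead of through side-input and state effects, and excludes unpopulated slots with a causal mask over absolute positions. The function name cached_attention_step is hypothetical.

```python
import jax
import jax.numpy as jnp


def cached_attention_step(q, k, v, key_cache, value_cache, start_index):
  """Hypothetical sketch: attention for new tokens against fixed-size caches.

  q, k, v:                [num_new, d]   projections for the new tokens
  key_cache, value_cache: [cache_len, d]
  start_index:            cache size *before* inserting the new tokens
  Returns (output, updated_key_cache, updated_value_cache).
  """
  # Store this step's keys and values at the current end of the caches.
  key_cache = jax.lax.dynamic_update_slice_in_dim(key_cache, k, start_index, axis=0)
  value_cache = jax.lax.dynamic_update_slice_in_dim(value_cache, v, start_index, axis=0)

  # Queries attend over the *full* key cache, not just this step's keys.
  scores = q @ key_cache.T / jnp.sqrt(q.shape[-1])  # [num_new, cache_len]

  # Causal mask over absolute positions. Unpopulated slots lie after every
  # query position, so this also hides them.
  query_pos = start_index + jnp.arange(q.shape[0])
  key_pos = jnp.arange(key_cache.shape[0])
  scores = jnp.where(key_pos[None, :] <= query_pos[:, None], scores, -jnp.inf)

  weights = jax.nn.softmax(scores, axis=-1)
  return weights @ value_cache, key_cache, value_cache


# Decode 5 tokens one at a time using caches of length 8.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
num_tokens, d, cache_len = 5, 4, 8
q = jax.random.normal(kq, (num_tokens, d))
k = jax.random.normal(kk, (num_tokens, d))
v = jax.random.normal(kv, (num_tokens, d))

key_cache = jnp.zeros((cache_len, d))
value_cache = jnp.zeros((cache_len, d))
outputs = []
for t in range(num_tokens):
  # t is the number of cached entries before this token is inserted.
  out, key_cache, value_cache = cached_attention_step(
      q[t:t + 1], k[t:t + 1], v[t:t + 1], key_cache, value_cache, jnp.array(t))
  outputs.append(out)
incremental = jnp.concatenate(outputs, axis=0)  # [num_tokens, d]
```

Decoding token by token this way should reproduce ordinary causal attention over the same five tokens. Keeping cache_len fixed keeps every array shape static, which is what allows a step like this to be compiled once with jax.jit and reused for every position.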
- classmethod from_uncached(original: Attention, sequence_axis: str, cache_len: int, cached_axes: dict[str, int], cache_end_index_tag: side_input.Tag, state_category: local_state.Category, cache_dtype: jax.typing.DTypeLike = jax.numpy.float32) -> KVCachingAttention
Builds a caching attention from an uncached attention.
- Parameters:
original – The original attention layer that this block should replace.
sequence_axis – The axis along which keys and values should be cached. Should be present in the output of the input_to_key and input_to_value sublayers.
cache_len – Length of the cache; used to populate the initial state.
cached_axes – Axis names and sizes for all other axes of the key and value arrays (e.g. for batch, heads, and the projected embeddings). These are used to initialize the cache.
cache_end_index_tag – Side input tag for the cache position side input. This should be used to identify the side inputs that should receive the cache position information, and should (usually) be provided to the pz.de.WithSideInputsFromInputTuple handler that actually provides this side input.
state_category – Category for the local state. This should be used to identify the state variables that correspond to key-value caches in the model, and should (usually) be provided to the pz.de.handle_local_states call that functionalizes the state effect.
cache_dtype – Dtype for the data to store in the cache. Should match the dtype of the key and value arrays.
- Returns:
A KVCachingAttention instance that behaves like the original Attention layer, but updates key-value caches iteratively, using new side input and state effect requests.
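The cache_len, cached_axes, and cache_dtype parameters together determine the size and dtype of the initial caches. A minimal sketch of that bookkeeping, assuming zero initialization and assuming the sequence axis is simply appended after the cached_axes (the actual axis layout inside Penzai's NamedArrays may differ); init_kv_cache and the axis names used here are hypothetical, and plain jnp arrays stand in for NamedArrays:

```python
import jax.numpy as jnp


def init_kv_cache(sequence_axis, cache_len, cached_axes, cache_dtype=jnp.float32):
  """Hypothetical sketch: builds zero-filled key and value caches.

  Returns (axis_names, key_cache, value_cache); the cache axes follow
  `axis_names`. Positional arrays stand in for Penzai's NamedArrays.
  """
  # cached_axes describes all *other* axes, so the sequence axis must not
  # appear in it (a check this sketch adds for illustration).
  if sequence_axis in cached_axes:
    raise ValueError(f"{sequence_axis!r} should not be listed in cached_axes")
  axis_sizes = {**cached_axes, sequence_axis: cache_len}  # assumed axis order
  axis_names = tuple(axis_sizes)
  shape = tuple(axis_sizes.values())
  key_cache = jnp.zeros(shape, dtype=cache_dtype)
  value_cache = jnp.zeros(shape, dtype=cache_dtype)
  return axis_names, key_cache, value_cache


names, kc, vc = init_kv_cache(
    sequence_axis="seq",
    cache_len=128,
    cached_axes={"batch": 2, "heads": 4, "projection": 16},
    cache_dtype=jnp.bfloat16,
)
# names == ("batch", "heads", "projection", "seq"); kc.shape == (2, 4, 16, 128)
```

With these sizes each cache holds 2 × 4 × 16 × 128 = 16,384 bfloat16 values (32 KiB), and the key and value caches together take twice that. Because memory grows linearly with cache_len, it is usually set to the longest sequence expected at inference time.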