KVCachingAttention
- class penzai.nn.attention.KVCachingAttention [source]
Bases: Layer
Key/value caching variant of Attention.
KVCachingAttention is a drop-in replacement for Attention, but adds key/value caching logic using Penzai’s effect system. This means that a model initially configured for training can be quickly adapted to do inference without making the training logic more complicated; a conversion sketch follows the variable list below.
- Variables:
input_to_query (layer_base.Layer) – A layer that maps the input to an array of queries, usually taken from the original Attention layer.
input_to_key (layer_base.Layer) – A layer that maps the input to an array of keys, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.
input_to_value (layer_base.Layer) – A layer that maps the input to an array of values, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.
query_key_to_attn (layer_base.Layer) – A layer that maps a tuple of (queries, keys) to attention weights, usually taken from the original Attention layer. The key input will contain the full key cache, rather than the slice produced for the current token.
attn_value_to_output (layer_base.Layer) – A layer that maps a tuple of (attention weights, values) to a final output, usually taken from the original Attention layer. The value input will contain the full value cache, rather than the slice produced for the current token.
sequence_axis (str) – The axis along which to do key/value caching. Should be an axis name that appears in the output of the input_to_key and input_to_value sublayers.
kv_cache_end_index_key (Hashable) – The key for the side input that identifies the current dynamic size of the key/value caches, i.e. the number of elements that have been populated with entries. The corresponding side input should be a scalar integer array.
kv_cache (variables.StateVariable[tuple[named_axes.NamedArray, named_axes.NamedArray]]) – A state effect variable that stores a tuple of key and value caches. This will be initialized when this layer is constructed, and will be updated as it runs.
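For example, a trained model’s attention blocks can be swapped for caching variants using Penzai’s selector utilities. The following is a minimal sketch, not a definitive recipe: `model`, the axis names "seq", "batch", "heads", and "projection", their sizes, and the "cache_end_index" side-input key are all illustrative assumptions about how the surrounding model was built.

```python
import jax.numpy as jnp
from penzai import pz

# `model` is assumed to be an existing Penzai model whose attention blocks
# are pz.nn.Attention instances (hypothetical; not constructed here).
caching_model = (
    pz.select(model)
    .at_instances_of(pz.nn.Attention)
    .apply(
        lambda attn: pz.nn.KVCachingAttention.from_uncached(
            attn,
            sequence_axis="seq",  # assumed sequence axis name
            cache_len=1024,       # maximum number of tokens the cache can hold
            cached_axes={"batch": 1, "heads": 8, "projection": 64},  # assumed
            cache_end_index_key="cache_end_index",  # illustrative key
            cache_dtype=jnp.float32,
        )
    )
)
```

Because each KVCachingAttention initializes its kv_cache state variable at construction time, the resulting caching_model needs no further setup before iterative decoding.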
Methods
__init__(input_to_query, input_to_key, ...)
from_uncached(original, sequence_axis, ...) – Builds a caching attention from an uncached attention.
__call__(x, **side_inputs) – Runs the caching attention computation and updates the K/V cache state.
Attributes
input_to_query
input_to_key
input_to_value
query_key_to_attn
attn_value_to_output
sequence_axis
kv_cache_end_index_key
kv_cache
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray, **side_inputs: Any) → named_axes.NamedArray [source]
Runs the caching attention computation and updates the K/V cache state.
When called, the side input at self.kv_cache_end_index_key should be filled with a scalar integer identifying the current size of the cache (before inserting this token), and self.kv_cache should be a StateVariable that contains the current cache contents.
- Parameters:
x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.
**side_inputs – Side inputs for all sublayers. Should contain the key-value cache end index at the key indicated by this layer’s kv_cache_end_index_key attribute.
- Returns:
The final output of the attn_value_to_output sublayer.
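As an illustration, one step of iterative decoding might look like the following sketch; x_t (a NamedArray for the newest token, with a length-1 "seq" axis), the position t, and the "cache_end_index" key are assumptions carried over from the conversion sketch above.

```python
import jax.numpy as jnp

# t counts the cache entries populated so far, i.e. the position at which
# this token's keys and values will be inserted. The attention weights are
# then computed against the full cache, not just the current slice.
out_t = caching_model(x_t, cache_end_index=jnp.asarray(t))
```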
- classmethod from_uncached(original: Attention, sequence_axis: str, cache_len: int, cached_axes: dict[str, int], cache_end_index_key: Hashable, cache_dtype: jax.typing.DTypeLike = jnp.float32, cache_label: variables.VariableLabel | None = None, layerstack_axes: dict[named_axes.AxisName, int] | None = None) → KVCachingAttention [source]
Builds a caching attention from an uncached attention.
- Parameters:
original – The original attention layer that this block should replace.
sequence_axis – The axis along which keys and values should be cached. Should be present in the output of the input_to_key and input_to_value sublayers.
cache_len – Length of the cache; used to populate the initial state.
cached_axes – Axis names and sizes for all other axes of the key and value arrays (e.g. for batch, heads, and the projected embeddings). These are used to initialize the cache.
cache_end_index_key – Key to use for the cache position side input.
cache_dtype – Dtype for the data to store in the cache. Should match the dtype of the key and value arrays.
cache_label – Optional label for the KV cache variable.
layerstack_axes – Stacked axes that are used inside a LayerStack combinator. Usually inferred from pz.nn.layerstack_axes_from_keypath.
- Returns:
A KVCachingAttention instance that behaves like the original Attention layer, but updates key-value caches iteratively, using new side input and state effect requests.
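One way to sanity-check a conversion is to run a full-sequence pass through both layers: with an empty cache (cache end index of 0) and a cache exactly as long as the input, inserting the whole sequence should fill the cache with the same keys and values the uncached layer computes, so the outputs should agree. The sketch below rests on assumptions: attn and x are hypothetical, the axis names and sizes are illustrative, and the attention sublayers are assumed to need no additional side inputs.

```python
import jax.numpy as jnp
from penzai import pz

caching_attn = pz.nn.KVCachingAttention.from_uncached(
    attn,  # hypothetical pz.nn.Attention instance
    sequence_axis="seq",
    cache_len=x.named_shape["seq"],  # cache exactly fits the input (assumed)
    cached_axes={"batch": 1, "heads": 8, "projection": 64},  # assumed
    cache_end_index_key="cache_end_index",
)
# Inserting the full sequence at position 0 populates the cache with exactly
# the keys/values the uncached layer would produce, so the two outputs should
# match up to numerical noise.
y_cached = caching_attn(x, cache_end_index=jnp.asarray(0))
y_uncached = attn(x)
```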