KVCachingAttention
- class penzai.deprecated.v1.nn.attention.KVCachingAttention
Bases: Layer
Key/value caching variant of Attention.

KVCachingAttention is a drop-in replacement for Attention, but adds key/value caching logic using Penzai's effect system. This means that a model initially configured for training can be quickly adapted to do inference without making the training logic more complicated.

- Variables:
input_to_query (layer_base.LayerLike) – A layer that maps the input to an array of queries, usually taken from the original Attention layer.
input_to_key (layer_base.LayerLike) – A layer that maps the input to an array of keys, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.
input_to_value (layer_base.LayerLike) – A layer that maps the input to an array of values, usually taken from the original Attention layer. The output of this layer will additionally be stored in the stateful key/value cache.
query_key_to_attn (layer_base.LayerLike) – A layer that maps a tuple of (queries, keys) to attention weights, usually taken from the original Attention layer. The key input will contain the full key cache, rather than the slice produced for the current token.
attn_value_to_output (layer_base.LayerLike) – A layer that maps a tuple of (attention weights, values) to a final output, usually taken from the original Attention layer. The value input will contain the full value cache, rather than the slice produced for the current token.
sequence_axis (str) – The axis along which to do key/value caching. Should be an axis name that appears in the output of the input_to_key and input_to_value sublayers.
kv_cache_end_index (side_input.SideInputEffect[jax.Array]) – A side input that identifies the current dynamic size of the key/value caches, i.e. the number of elements that have been populated with entries. Should be populated by a scalar integer array.
kv_cache (local_state.LocalStateEffect[tuple[named_axes.NamedArray, named_axes.NamedArray]]) – A state effect variable that stores a tuple of key and value caches. This will be initialized when this layer is constructed, and will be updated as it runs.
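As an illustration of how these attributes fit together, here is a minimal construction sketch for a single layer. It assumes the deprecated v1 namespace is imported as from penzai.deprecated.v1 import pz; the Identity stand-in sublayers, the axis names and sizes ("seq", "batch", "projection", the cache length), and the tag and category strings are hypothetical placeholders, not values prescribed by the library:

```python
import jax.numpy as jnp
from penzai.deprecated.v1 import pz

# Degenerate stand-in for a trained attention layer, only to keep this sketch
# self-contained; a real model would use projection and dot-product sublayers.
attn = pz.nn.Attention(
    input_to_query=pz.nn.Identity(),
    input_to_key=pz.nn.Identity(),
    input_to_value=pz.nn.Identity(),
    query_key_to_attn=pz.nn.Identity(),
    attn_value_to_output=pz.nn.Identity(),
)

caching_attn = pz.nn.KVCachingAttention.from_uncached(
    attn,
    sequence_axis="seq",      # assumed axis name along which to cache
    cache_len=1024,           # fixed capacity of the key/value caches
    cached_axes={"batch": 1, "projection": 64},  # assumed sizes; a real layer
                                                 # would typically add a heads axis
    cache_end_index_tag="cache_end_index",  # matched by the side-input handler
    state_category="kv_cache",              # matched by handle_local_states
    cache_dtype=jnp.float32,  # should match the key/value dtype
)
```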
Methods
__init__(input_to_query, input_to_key, ...)
from_uncached(original, sequence_axis, ...) – Builds a caching attention from an uncached attention.
__call__(x) – Runs the caching attention computation and updates the K/V cache state.
Attributes
input_to_query
input_to_key
input_to_value
query_key_to_attn
attn_value_to_output
sequence_axis
kv_cache_end_index
kv_cache
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
input_structure() – Returns the input structure of this layer.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
output_structure() – Returns the output structure of this layer.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray) → named_axes.NamedArray
Runs the caching attention computation and updates the K/V cache state.
When called, self.kv_cache_end_index should be filled with a scalar integer identifying the current size of the cache (before inserting this token), and self.kv_cache should be a LocalState that contains the current state.
- Parameters:
x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.
- Returns:
The final output of the attn_value_to_output sublayer.
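Both effect requests must be handled before the layer can actually be called. Below is a minimal sketch of one decode step, continuing from the construction sketch above. The tuple calling convention shown (inner tuple feeds the side input, outer tuple threads the functionalized state) is an assumption based on the v1 data-effects helpers named in this page, and with Identity stand-in sublayers the attention output itself is meaningless; only the cache wiring is exercised:

```python
import jax.numpy as jnp
from penzai.deprecated.v1 import pz

# Provide the cache position as the second element of an input tuple.
with_positions = pz.de.WithSideInputsFromInputTuple.handling(
    caching_attn, tags=["cache_end_index"]
)

# Functionalize the kv_cache state; `initial_state` holds the freshly
# initialized key/value caches.
stateless_model, initial_state = pz.de.handle_local_states(
    with_positions, category="kv_cache"
)

# Hypothetical input slice for the current token, with the assumed axes.
token_slice = pz.nx.wrap(jnp.zeros((1, 1, 64))).tag("batch", "seq", "projection")

# One decode step at cache position 0 (assumed calling convention): returns
# the layer output together with the updated cache state.
output, next_state = stateless_model(
    ((token_slice, jnp.asarray(0)), initial_state)
)
```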
- classmethod from_uncached(original: Attention, sequence_axis: str, cache_len: int, cached_axes: dict[str, int], cache_end_index_tag: side_input.Tag, state_category: local_state.Category, cache_dtype: jax.typing.DTypeLike = jax.numpy.float32) → KVCachingAttention
Builds a caching attention from an uncached attention.
- Parameters:
original – The original attention layer that this block should replace.
sequence_axis – The axis along which keys and values should be cached. Should be present in the output of the input_to_key and input_to_value sublayers.
cache_len – Length of the cache; used to populate the initial state.
cached_axes – Axis names and sizes for all other axes of the key and value arrays (e.g. for batch, heads, and the projected embeddings). These are used to initialize the cache.
cache_end_index_tag – Tag for the cache-position side input. This should be used to identify the side inputs that should receive the cache position information, and should (usually) be provided to the pz.de.WithSideInputsFromInputTuple handler that actually provides this side input.
state_category – Category for the local state. This should be used to identify the state variables that correspond to key-value caches in the model, and should (usually) be provided to the pz.de.handle_local_states call that functionalizes the state effect.
cache_dtype – Dtype for the data to store in the cache. Should match the dtype of the key and value arrays.
- Returns:
A KVCachingAttention instance that behaves like the original Attention layer, but updates key-value caches iteratively, using new side input and state effect requests.
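To adapt an entire trained model rather than a single layer, each Attention layer can be swapped for its caching variant using Penzai's selectors. A minimal sketch, reusing the hypothetical tags, axis names, and sizes from the earlier examples; model stands for any trained model containing pz.nn.Attention layers:

```python
from penzai.deprecated.v1 import pz


def to_caching(attn: pz.nn.Attention) -> pz.nn.KVCachingAttention:
  # All axis names, sizes, tags, and categories here are illustrative
  # assumptions; they must match the model's actual key/value array axes.
  return pz.nn.KVCachingAttention.from_uncached(
      attn,
      sequence_axis="seq",
      cache_len=1024,
      cached_axes={"batch": 1, "heads": 8, "projection": 64},
      cache_end_index_tag="cache_end_index",
      state_category="kv_cache",
  )


# Hypothetical: `model` is a trained model; every Attention inside it is
# functionally replaced by a KVCachingAttention with the same sublayers.
caching_model = (
    pz.select(model)
    .at_instances_of(pz.nn.Attention)
    .apply(to_caching)
)
```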