GemmaKVCachingAttention
- class penzai.deprecated.v1.example_models.gemma.sampling_mode.GemmaKVCachingAttention
Bases: KVCachingAttention
Gemma-specific configuration of the key-value-caching attention layer.
GemmaKVCachingAttention has the same runtime behavior as the base pz.nn.KVCachingAttention layer, but is specialized to the conventions of the Gemma model.
Methods
__init__(input_to_query, input_to_key, ...)
from_uncached(original, cache_len, cached_axes) – Builds a caching attention from an uncached attention.
Attributes
input_to_query
input_to_key
input_to_value
query_key_to_attn
attn_value_to_output
sequence_axis
kv_cache_end_index
kv_cache
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
input_structure() – Returns the input structure of this layer.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
output_structure() – Returns the output structure of this layer.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
__call__(x) – Runs the caching attention computation and updates the K/V cache state.
- classmethod from_uncached(original: model_core.GemmaAttention, cache_len: int, cached_axes: dict[str, int], cache_dtype: jax.typing.DTypeLike = jax.numpy.float32) -> GemmaKVCachingAttention
Builds a caching attention from an uncached attention.
- Parameters:
original – The original attention layer that this block should replace.
cache_len – Length of the cache; used to populate the initial state.
cached_axes – Axis names and sizes for all other axes of the key and value arrays (e.g. for batch, heads, and the projected embeddings). These are used to initialize the cache.
cache_dtype – Dtype for the data to store in the cache. Should match the dtype of the key and value arrays.
- Returns:
A GemmaKVCachingAttention instance that behaves like the original Attention layer, but updates key-value caches iteratively, using new side input and state effect requests.
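To make the roles of cache_len, cached_axes, and the cache end index more concrete, the following is a minimal conceptual sketch in plain Python. It does not use penzai and is not how the library stores its cache: ToyKVCache, init_cache, append, the default sequence axis name "seq", and the axis names "batch", "heads", and "projection" are all hypothetical names chosen for illustration. It only shows the general idea of preallocating cache_len slots along the sequence axis and filling them from a running end index.

```python
from dataclasses import dataclass


@dataclass
class ToyKVCache:
    """Conceptual stand-in for a KV cache; NOT penzai's actual data structure."""

    named_shape: dict   # axis name -> size, including the sequence axis
    keys: list          # one entry per sequence position; None until written
    values: list        # same layout as keys
    end_index: int = 0  # number of sequence positions filled so far


def init_cache(cache_len, cached_axes, sequence_axis="seq"):
    """Preallocates an empty cache from its length and the other axis sizes."""
    named_shape = {sequence_axis: cache_len, **cached_axes}
    return ToyKVCache(named_shape, [None] * cache_len, [None] * cache_len)


def append(cache, new_keys, new_values):
    """Writes new positions starting at end_index and returns an updated cache."""
    n = len(new_keys)
    if cache.end_index + n > len(cache.keys):
        raise ValueError("cache is full")
    start = cache.end_index
    keys = cache.keys[:start] + list(new_keys) + cache.keys[start + n:]
    values = cache.values[:start] + list(new_values) + cache.values[start + n:]
    return ToyKVCache(cache.named_shape, keys, values, start + n)


# Hypothetical axis names and sizes, used only for this illustration.
cache = init_cache(cache_len=4, cached_axes={"batch": 1, "heads": 2, "projection": 8})
cache = append(cache, ["k0", "k1"], ["v0", "v1"])  # a two-token prompt
cache = append(cache, ["k2"], ["v2"])              # one decoding step

print(cache.end_index)  # 3
print(cache.keys)       # ['k0', 'k1', 'k2', None]
```

In this sketch each update returns a new cache object rather than mutating the old one, which loosely mirrors the functional style of state updates in JAX. The strings stand in for key and value arrays, which in a real model would carry the non-sequence axes listed in cached_axes.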