GemmaKVCachingAttention

class penzai.example_models.gemma.sampling_mode.GemmaKVCachingAttention

Bases: KVCachingAttention

Gemma-specific configuration of the key-value-caching attention layer.

GemmaKVCachingAttention has the same runtime behavior as the base pz.nn.KVCachingAttention layer, but is specialized to the conventions of the Gemma model.

Methods

__init__(input_to_query, input_to_key, ...)

from_uncached(original, cache_len, cached_axes)

Builds a caching attention from an uncached attention.

Attributes

input_to_query

input_to_key

input_to_value

query_key_to_attn

attn_value_to_output

sequence_axis

kv_cache_end_index

kv_cache

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(x)

Runs the caching attention computation and updates the K/V cache state.
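
To make the idea of "run attention and update the K/V cache" concrete, here is a conceptual sketch in plain NumPy for a single attention head. It is not penzai's implementation: the function name `cached_attention_step`, the array layouts, and the causal masking rule are assumptions chosen for illustration, and the real layer works with named axes, side inputs, and state effects instead of explicit arguments.

```python
import numpy as np


def cached_attention_step(q, k_new, v_new, k_cache, v_cache, end_index):
    """One decoding step for a single head (hypothetical helper).

    q, k_new, v_new: arrays of shape (new_len, dim) for the new tokens.
    k_cache, v_cache: arrays of shape (cache_len, dim).
    end_index: number of cache slots that are already filled.
    """
    new_len = k_new.shape[0]

    # Write the new keys and values into the next free cache slots.
    k_cache = k_cache.copy()
    v_cache = v_cache.copy()
    k_cache[end_index:end_index + new_len] = k_new
    v_cache[end_index:end_index + new_len] = v_new
    new_end = end_index + new_len

    # Score each query against every cache slot, then hide the slots
    # that lie after that query's own position (unfilled or future).
    scores = q @ k_cache.T / np.sqrt(q.shape[-1])
    slot = np.arange(k_cache.shape[0])[None, :]
    query_pos = end_index + np.arange(new_len)[:, None]
    scores = np.where(slot <= query_pos, scores, -np.inf)

    # Softmax over cache slots, then mix the cached values.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v_cache, k_cache, v_cache, new_end
```

Calling this helper once per generated token, and feeding the returned cache and end index back in, mirrors the iterative pattern described above: each call only projects the new tokens, while earlier keys and values are read back from the cache.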

classmethod from_uncached(original: model_core.GemmaAttention, cache_len: int, cached_axes: dict[str, int], cache_dtype: jax.typing.DTypeLike = jax.numpy.float32) -> GemmaKVCachingAttention

Builds a caching attention from an uncached attention.

Parameters:
  • original – The original attention layer that this block should replace.

  • cache_len – Length of the cache; used to populate the initial state.

  • cached_axes – Axis names and sizes for all other axes of the key and value arrays (e.g. for batch, heads, and the projected embeddings). These are used to initialize the cache.

  • cache_dtype – Dtype for the data to store in the cache. Should match the dtype of the key and value arrays.

Returns:

A GemmaKVCachingAttention instance that behaves like the original attention layer but updates its key-value cache incrementally, using new side-input and state-effect requests.
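
As a rough illustration of how cache_len, cached_axes, and cache_dtype could determine the initial cache, the NumPy sketch below builds zero-filled key and value buffers. The helper `init_kv_cache`, the `sequence_axis` default of "seq", and the axis names in the usage line are all hypothetical; penzai itself stores the cache as named arrays and may lay it out differently.

```python
import numpy as np


def init_kv_cache(cache_len, cached_axes, cache_dtype=np.float32,
                  sequence_axis="seq"):
    """Hypothetical helper: zero-filled K/V buffers and a start index."""
    # Combine the non-sequence axes with the sequence axis of size cache_len.
    axis_sizes = {**cached_axes, sequence_axis: cache_len}
    shape = tuple(axis_sizes.values())
    keys = np.zeros(shape, dtype=cache_dtype)
    values = np.zeros(shape, dtype=cache_dtype)
    # The end index starts at 0 because no positions are filled yet.
    return tuple(axis_sizes), keys, values, 0


# Illustrative axis names and sizes only.
axis_names, keys, values, end_index = init_kv_cache(
    cache_len=16,
    cached_axes={"batch": 2, "heads": 4, "projection": 8},
)
```

The buffers are allocated once at their full length, so each later step can overwrite slots in place instead of growing an array; this is also why cache_dtype should match the dtype of the keys and values that will be written.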