GemmaKVCachingInputs#

class penzai.example_models.gemma.sampling_mode.GemmaKVCachingInputs[source]#

Bases: Struct

Input structure for the GemmaKVCachingTransformer.

Variables:
  • tokens (pz.nx.NamedArray) – Subsequence of the current tokens we are processing, as an integer named array with a “seq” axis and possibly batch axes. When pre-filling, the “seq” axis can have the length of the prompt. When sampling, the “seq” axis will usually have length 1.

  • positions (pz.nx.NamedArray) – Sequence of current token positions, as an integer named array with a “seq” axis and possibly batch axes, of the same sequence length as tokens. Should usually increase with each call to the transformer.

  • attention_mask (pz.nx.NamedArray) – Boolean attention mask with “seq” and “kv_seq” axes and possibly batch axes. The “seq” axis should match tokens and positions, and the “kv_seq” axis should match the cache_len of the sampling_state. Usually a slice of the causal mask.

  • sampling_state (GemmaKVCachingState) – Current sampling state, containing key-value caches.
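The axis contract among these fields can be sketched with plain arrays standing in for named arrays. The sizes below are hypothetical, chosen only to show which lengths must agree (in the real structure, “seq” and “kv_seq” are named axes on pz.nx.NamedArray values, and “kv_seq” must match the cache_len of the sampling state):

```python
import numpy as np

# Hypothetical sizes: 3 new tokens this step, a key/value cache of length 16.
seq = 3        # length of the "seq" axis on tokens and positions
kv_seq = 16    # length of the "kv_seq" axis; must equal the cache_len

tokens = np.array([5, 17, 2])       # integer token ids, one per "seq" slot
positions = np.array([4, 5, 6])     # absolute positions, same "seq" length
attention_mask = np.zeros((seq, kv_seq), dtype=bool)  # query x cache-slot

# tokens and positions share the "seq" axis; the mask pairs "seq" with "kv_seq".
assert tokens.shape == positions.shape == (seq,)
assert attention_mask.shape == (seq, kv_seq)
```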

Methods

__init__(tokens, positions, attention_mask, ...)

from_basic_subsegments(tokens, sampling_state)

Constructs a simple input structure for a batch of unpadded samples.

Attributes

tokens

positions

attention_mask

sampling_state

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

classmethod from_basic_subsegments(tokens: pz.nx.NamedArray, sampling_state: GemmaKVCachingState) GemmaKVCachingInputs[source]#

Constructs a simple input structure for a batch of unpadded samples.

This can be used to process inputs that do not need advanced position or attention mask handling, and which just consist of ordinary sequences that are not packed together or padded. It augments the tokens with a standard position array and causal attention mask, adjusted by the current cache offset.

Parameters:
  • tokens – Subsequence of tokens, as an integer named array with a “seq” axis and possibly batch axes. When pre-filling, the “seq” axis can have the length of the prompt. When sampling, the “seq” axis will usually have length 1.

  • sampling_state – Current sampling state, containing key-value caches.

Returns:

A full input structure containing the provided tokens, along with a simple incrementing position array and a causal mask, offset by the current sampling state.
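The position and mask construction described above can be sketched with plain arrays. This is an illustrative reimplementation, not penzai’s code: basic_positions_and_mask is a hypothetical helper, and it assumes cache slots are filled in order from position 0, so that a “kv_seq” index equals the absolute position it holds:

```python
import numpy as np

def basic_positions_and_mask(seq_len, cache_len, cache_offset):
    # Positions continue from the current cache offset: offset, offset+1, ...
    positions = cache_offset + np.arange(seq_len)
    # Slice of the causal mask: a query at absolute position p may attend
    # to cache slots holding positions <= p.
    kv_positions = np.arange(cache_len)
    mask = kv_positions[None, :] <= positions[:, None]
    return positions, mask

# Pre-fill a 4-token prompt into an empty cache of length 8:
prefill_pos, prefill_mask = basic_positions_and_mask(
    seq_len=4, cache_len=8, cache_offset=0)
# Then sample one token at a time, with the offset advancing each step:
step_pos, step_mask = basic_positions_and_mask(
    seq_len=1, cache_len=8, cache_offset=4)
```

During pre-fill each query row attends to a growing causal prefix; during sampling the single query row attends to every cache slot written so far plus itself.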