SamplingState

class penzai.example_models.gemma.simple_decoding_loop.SamplingState

Bases: Struct

State that manages the set of decoded tokens during sampling.

The purpose of this class is to keep the outputs, inputs, and key-value caches of the transformer in sync during decoding, even in the presence of padding tokens. This makes it possible to sample from batches of prompts which have different lengths.

Padding tokens are treated as “invisible” to the model: if they appear in the middle of the sequence, they are skipped over and masked out so that the model’s queries do not attend to them.
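One way to picture the "invisible padding" behavior is as an attention mask that combines the usual causal constraint with a per-key "is not padding" check. The sketch below is an illustration under that assumption, written with plain jax.numpy arrays of shape (batch, seq) instead of penzai named arrays; padding_aware_causal_mask is a hypothetical helper, not part of penzai.

```python
import jax.numpy as jnp


def padding_aware_causal_mask(tokens, pad_id):
    """Hypothetical helper: boolean mask of shape (batch, query, key).

    A query may attend to a key only if the key is not later in the
    sequence (causal) and the key is not a padding token.
    """
    seq_len = tokens.shape[-1]
    positions = jnp.arange(seq_len)
    # causal[q, k] is True when key k is at or before query q.
    causal = positions[:, None] >= positions[None, :]
    # key_is_real[b, 0, k] is True when token k of row b is not padding.
    key_is_real = (tokens != pad_id)[:, None, :]
    return causal[None, :, :] & key_is_real


# A padding token (ID 0 here, chosen for illustration) in the middle of a row.
mask = padding_aware_causal_mask(jnp.array([[5, 0, 6]]), pad_id=0)
```

For the row [5, 0, 6], the last query sees keys [True, False, True]: the padding token in the middle is skipped while both real tokens stay visible.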

Variables:
  • kv_cache_state (sampling_mode.GemmaKVCachingState) – The state of the KV caches.

  • previous_tokens (pz.nx.NamedArray) – An array of all of the outputs that have been written so far. This is also used to identify the effective sequence position of the next token by counting the non-padding tokens.

  • pad_id (int) – Token ID that indicates padding.
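The description of previous_tokens says the effective position of the next token comes from counting non-padding tokens. The sketch below shows that counting for a left-padded batch of prompts with different lengths, using plain jax.numpy arrays; next_token_position and token_positions are hypothetical helpers for illustration, not penzai functions.

```python
import jax.numpy as jnp


def next_token_position(previous_tokens, pad_id):
    """Effective position of the next token in each batch row.

    Padding is invisible, so only non-padding tokens are counted.
    """
    is_real = (previous_tokens != pad_id).astype(jnp.int32)
    return jnp.sum(is_real, axis=-1)


def token_positions(previous_tokens, pad_id):
    """Effective position of every slot; padding does not advance it."""
    is_real = (previous_tokens != pad_id).astype(jnp.int32)
    # Exclusive running count: number of real tokens before each slot.
    return jnp.cumsum(is_real, axis=-1) - is_real


# Two prompts of different lengths; pad ID 0 is an assumption for the example.
tokens = jnp.array([[5, 6, 7, 0], [0, 0, 8, 9]])
```

Here next_token_position(tokens, 0) gives [3, 2], and in the second row the real tokens 8 and 9 get positions 0 and 1 even though they sit in slots 2 and 3.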

Methods

__init__(kv_cache_state, previous_tokens, pad_id)

Attributes

kv_cache_state

previous_tokens

pad_id

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.
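The inherited tree_flatten and tree_unflatten methods implement JAX's pytree protocol, which is what lets a state object like this pass through jax.jit or jax.tree_util.tree_map. The sketch below illustrates that protocol with a hypothetical ToySamplingState class registered directly through jax.tree_util.register_pytree_node_class rather than built on penzai's Struct; treating pad_id as static auxiliary data is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToySamplingState:
    """Hypothetical stand-in with the same three fields as SamplingState.

    Not penzai's implementation; it only illustrates the pytree protocol.
    """

    def __init__(self, kv_cache_state, previous_tokens, pad_id):
        self.kv_cache_state = kv_cache_state
        self.previous_tokens = previous_tokens
        self.pad_id = pad_id

    def tree_flatten(self):
        # Array-valued fields become children that JAX can trace or map over.
        children = (self.kv_cache_state, self.previous_tokens)
        # Assumption: pad_id is a plain Python int kept as static aux data.
        aux_data = (self.pad_id,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        kv_cache_state, previous_tokens = children
        (pad_id,) = aux_data
        return cls(kv_cache_state, previous_tokens, pad_id)


state = ToySamplingState(
    kv_cache_state={"keys": jnp.zeros((1, 4))},  # hypothetical cache layout
    previous_tokens=jnp.array([[5, 6, 0, 0]]),
    pad_id=0,
)
leaves, treedef = jax.tree_util.tree_flatten(state)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
# tree_map only touches the array leaves; pad_id is carried through unchanged.
shifted = jax.tree_util.tree_map(lambda x: x + 1, state)
```

Keeping pad_id out of the leaves means it stays a concrete Python int under jax.jit instead of becoming a traced value.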