SamplingState
class penzai.deprecated.v1.example_models.gemma.simple_decoding_loop.SamplingState
Bases: Struct
State that manages the set of decoded tokens during sampling.
The purpose of this class is to keep the outputs, inputs, and key-value caches of the transformer in sync during decoding, even in the presence of padding tokens. This makes it possible to sample from batches of prompts which have different lengths.
Padding tokens are treated as “invisible” to the model; they will be skipped over and masked out from the model’s queries if they appear in the middle of the sequence.
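To illustrate what "invisible" padding means for attention, the sketch below builds a causal mask in which no query can attend to a padding key. It is a minimal sketch in plain Python and assumes token IDs arrive as nested lists with one row per batch element. The helper name `visible_attention_mask` is hypothetical and is not part of penzai, whose model operates on `pz.nx.NamedArray` values.

```python
from typing import List


def visible_attention_mask(tokens: List[List[int]], pad_id: int) -> List[List[List[bool]]]:
    """Hypothetical helper: a causal attention mask that hides padding keys.

    mask[b][q][k] is True when query position q in batch row b may attend to
    key position k. Two conditions must hold: k does not come after q
    (causality), and the token at position k is not padding.
    """
    masks = []
    for row in tokens:
        n = len(row)
        row_mask = [
            [k <= q and row[k] != pad_id for k in range(n)]
            for q in range(n)
        ]
        masks.append(row_mask)
    return masks


mask = visible_attention_mask([[5, 0, 7]], pad_id=0)
# The last query sees token 5 and itself, but skips the padding in between.
print(mask[0][2])  # [True, False, True]
```

Because the padding key is masked for every later query, a padding token in the middle of a prompt has no effect on the tokens that follow it.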
- Variables:
kv_cache_state (sampling_mode.GemmaKVCachingState) – The state of the KV caches.
previous_tokens (pz.nx.NamedArray) – An array of all of the outputs that have been written so far. This is also used to identify the effective sequence position of the next token by counting the non-padding tokens.
pad_id (int) – Token ID that indicates padding.
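The position bookkeeping described for `previous_tokens` can be sketched as follows. This sketch assumes one batch row is available as a plain list of token IDs, whereas the class itself stores a `pz.nx.NamedArray`. `effective_positions` is a hypothetical helper for illustration only, not a penzai function. It shows why counting non-padding tokens gives the effective sequence position of the next token.

```python
from typing import List, Tuple


def effective_positions(row: List[int], pad_id: int) -> Tuple[List[int], int]:
    """Hypothetical helper: return (per-token positions, next position) for one row.

    Each token's position is the number of non-padding tokens before it, so
    padding never advances the position counter. The next position equals the
    total count of non-padding tokens written so far.
    """
    positions = []
    count = 0
    for token in row:
        positions.append(count)
        if token != pad_id:
            count += 1
    return positions, count


# Two prompts of different lengths, one left-padded and one with interior padding.
previous_tokens = [[0, 0, 11, 12], [21, 0, 22, 23]]
print([effective_positions(r, pad_id=0)[1] for r in previous_tokens])  # [2, 3]
```

Both rows have the same buffer length, but each row's next token receives a position that depends only on its real (non-padding) tokens. This is what allows prompts of different lengths to be sampled in one batch.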
Methods
__init__(kv_cache_state, previous_tokens, pad_id)
Attributes
kv_cache_state
previous_tokens
pad_id
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.