advance_one_token
- penzai.deprecated.v1.example_models.gemma.simple_decoding_loop.advance_one_token(model: sampling_mode.GemmaKVCachingTransformer, state: SamplingState, next_token: jax.Array) → tuple[pz.nx.NamedArray, SamplingState]
Advances a sampling state by one token.
This can be used to feed newly sampled tokens through the model one at a time, producing new log-probs that can be used to sample the following token.
- Parameters:
model – The converted model we are running inference with.
state – The current sampling state.
next_token – The next token to feed. Should not have a “seq” axis.
- Returns:
A tuple (next_log_probs, sampling_state), where next_log_probs are the log-probabilities of the next token to sample, and sampling_state is a state that can be passed to future sampling calls.
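The following is a minimal, self-contained sketch of the one-token-at-a-time pattern that this function supports: feed a token, get log-probs, pick the next token from them, and thread the returned state into the next call. It does not use penzai. ToySamplingState, toy_advance_one_token, counting_model and greedy_decode are hypothetical stand-ins, the "cache" is only the token history rather than per-layer key/value caches, and batching and named axes are ignored.

```python
import math
from dataclasses import dataclass
from typing import Callable, Sequence

# Hypothetical stand-in for a sampling state. A real KV-caching model would
# store per-layer key/value caches here; this toy keeps the token history.
@dataclass(frozen=True)
class ToySamplingState:
    cache: tuple[int, ...]

# Hypothetical toy "model": maps the token history to unnormalized scores
# (logits) over a small vocabulary.
ToyModel = Callable[[Sequence[int]], list[float]]


def log_softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def toy_advance_one_token(
    model: ToyModel, state: ToySamplingState, next_token: int
) -> tuple[list[float], ToySamplingState]:
    """Feeds one token and returns (next_log_probs, new_state)."""
    new_state = ToySamplingState(cache=state.cache + (next_token,))
    return log_softmax(model(new_state.cache)), new_state


def greedy_decode(
    model: ToyModel, state: ToySamplingState, first_token: int, num_steps: int
) -> tuple[list[int], ToySamplingState]:
    """Repeatedly advances by one token, choosing the argmax each step."""
    sampled = []
    token = first_token
    for _ in range(num_steps):
        log_probs, state = toy_advance_one_token(model, state, token)
        # Greedy choice; replace with categorical sampling for stochastic decoding.
        token = max(range(len(log_probs)), key=log_probs.__getitem__)
        sampled.append(token)
    return sampled, state


def counting_model(history: Sequence[int]) -> list[float]:
    """Toy model that strongly prefers (last token + 1) mod 5."""
    vocab_size = 5
    target = (history[-1] + 1) % vocab_size
    return [5.0 if i == target else 0.0 for i in range(vocab_size)]


tokens, final_state = greedy_decode(
    counting_model, ToySamplingState(cache=()), first_token=0, num_steps=4
)
print(tokens)             # [1, 2, 3, 4]
print(final_state.cache)  # (0, 1, 2, 3)
```

Note that the last sampled token (4) has not yet been fed back in, so it is absent from the returned state; the next call would pass it as next_token, mirroring how the returned sampling_state is handed to future calls.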