advance_one_token

penzai.example_models.gemma.simple_decoding_loop.advance_one_token(model: sampling_mode.GemmaKVCachingTransformer, state: SamplingState, next_token: jax.Array) → tuple[pz.nx.NamedArray, SamplingState]

Advances a sampling state by one token.

This can be used to feed sampled tokens through the model one at a time. Each call consumes a single token and produces log-probabilities for the next position, from which the next token can be sampled.

Parameters:
  • model – The converted model we are running inference with.

  • state – The current sampling state.

  • next_token – The next token to feed. Should not have a “seq” axis.

Returns:

A tuple (next_log_probs, sampling_state), where next_log_probs are the log-probabilities of the next token to sample, and sampling_state is the updated state to pass to subsequent sampling calls.