penzai.example_models.gemma.simple_decoding_loop.prefill(model: sampling_mode.GemmaKVCachingTransformer, initial_cache_state: sampling_mode.GemmaKVCachingState, prompt: pz.nx.NamedArray, pad_id: int) → tuple[pz.nx.NamedArray, SamplingState]

Prefills the key-value caches based on a prompt.

Parameters:
  • model – The converted model we are running inference with.

  • initial_cache_state – The initial cache state created while converting the model.

  • prompt – A named array of prompt tokens. Must have a “seq” axis along which the tokens for each batch element are arranged, and should usually start with the beginning-of-sequence token. This function assumes (but does not check) that all non-padding tokens precede all padding tokens.

  • pad_id – Token ID that corresponds to padding.


Returns:
A tuple (next_log_probs, sampling_state), where next_log_probs contains the log-probabilities of the next token to sample for each batch element, and sampling_state is a state that can be passed to subsequent sampling calls.
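To illustrate the padding convention this function assumes (all non-padding tokens before all padding tokens along the “seq” axis), here is a minimal sketch in plain NumPy showing how per-example prompt lengths can be recovered under that assumption; the token values and the pad_id of 0 are hypothetical:

```python
import numpy as np

pad_id = 0  # hypothetical padding token ID

# Two batch elements, right-padded along the "seq" axis.
prompt = np.array([
    [2, 15, 37, 0, 0],   # 3 real tokens, then padding
    [2, 8, 11, 42, 0],   # 4 real tokens, then padding
])

# Counting non-padding tokens gives each example's true length.
# This only works because all non-padding tokens precede all
# padding tokens, which is exactly what prefill assumes.
lengths = (prompt != pad_id).sum(axis=-1)
print(lengths)  # [3 4]
```

If a prompt interleaved padding with real tokens, this count would no longer correspond to a contiguous prefix, and the cached key-value entries written during prefill would not line up with the positions sampled afterwards.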