prefill#
- penzai.deprecated.v1.example_models.gemma.simple_decoding_loop.prefill(model: sampling_mode.GemmaKVCachingTransformer, initial_cache_state: sampling_mode.GemmaKVCachingState, prompt: pz.nx.NamedArray, pad_id: int) tuple[pz.nx.NamedArray, SamplingState] [source]#
Prefills the key-value caches based on a prompt.
- Parameters:
model – The converted model we are running inference with.
initial_cache_state – The initial cache state created while converting the model.
prompt – A named array of prompt tokens. Must have a “seq” axis along which the tokens for each batch element are arranged. Should usually start with the beginning-of-sequence token. This function assumes (but does not check) that all non-padding tokens precede all padding tokens.
pad_id – Token ID that corresponds to padding.
- Returns:
A tuple `(next_log_probs, sampling_state)`, where `next_log_probs` are the log-probabilities of the next token to sample, and `sampling_state` is a state that can be passed to future sampling calls.
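Since `prefill` assumes (but does not check) that all non-padding tokens precede all padding tokens, a caller may want to validate prompts before running inference. The helper below is a hypothetical sketch, not part of penzai, that checks this invariant for one row of token IDs:

```python
def padding_is_trailing(tokens, pad_id):
    """Check that every non-padding token precedes all padding tokens.

    This mirrors the (unchecked) assumption made by prefill: once a
    pad_id appears in the sequence, only pad_id tokens may follow.
    """
    seen_pad = False
    for tok in tokens:
        if tok == pad_id:
            seen_pad = True
        elif seen_pad:
            # A real token appearing after padding violates the assumption.
            return False
    return True


# Trailing padding is fine; interior padding is not.
assert padding_is_trailing([2, 17, 43, 0, 0], pad_id=0)
assert not padding_is_trailing([2, 0, 43], pad_id=0)
```

In practice, prompts are typically left-aligned and right-padded to a common length before being wrapped into the named array, which satisfies this invariant by construction.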