temperature_sample_pyloop
- penzai.example_models.gemma.simple_decoding_loop.temperature_sample_pyloop(model: sampling_mode.GemmaKVCachingTransformer, initial_cache_state: sampling_mode.GemmaKVCachingState, prompt: pz.nx.NamedArray, rng: jax.Array, pad_id: int, temperature: float = 1.0, max_sampling_steps: int | None = None) → pz.nx.NamedArray
Runs temperature sampling in a Python for loop.
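As a rough illustration of what each iteration of such a loop does, here is a minimal, unbatched sketch of temperature sampling. It is not penzai's implementation. It uses plain Python lists and a seeded standard-library `random.Random` in place of `pz.nx.NamedArray` and a JAX PRNG key, and the hypothetical `next_token_logits` callable stands in for the model. Unlike the real function, which runs a KV-caching model, the sketch recomputes logits from the full token history at every step, and it has no padding handling.

```python
import math
import random
from typing import Callable, List, Sequence


def sample_with_temperature(logits: Sequence[float], temperature: float, rng: random.Random) -> int:
    """Draws one token ID from softmax(logits / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    # Subtracting the max before exponentiating avoids overflow; it does not
    # change the resulting softmax probabilities.
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    # random.choices treats the weights as relative, so no explicit normalization is needed.
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def temperature_sample_loop(
    next_token_logits: Callable[[List[int]], Sequence[float]],
    prompt: List[int],
    temperature: float,
    num_steps: int,
    seed: int = 0,
) -> List[int]:
    """Samples `num_steps` tokens after `prompt`, one Python loop iteration per token."""
    rng = random.Random(seed)
    tokens = list(prompt)
    continuation: List[int] = []
    for _ in range(num_steps):
        # Hypothetical model call: next-token logits given the history so far.
        logits = next_token_logits(tokens)
        token = sample_with_temperature(logits, temperature, rng)
        tokens.append(token)
        continuation.append(token)
    return continuation


# Toy stand-in for a model: strongly prefers (last token + 1) mod 5.
def toy_logits(tokens: List[int]) -> List[float]:
    nxt = (tokens[-1] + 1) % 5
    return [10.0 if i == nxt else 0.0 for i in range(5)]


print(temperature_sample_loop(toy_logits, [0], temperature=0.01, num_steps=5))
# → [1, 2, 3, 4, 0]
```

At a very low temperature the sketch behaves almost greedily, a temperature of 1.0 samples from the unmodified softmax of the logits, and higher temperatures flatten the distribution.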
- Parameters:
model – The converted KV-caching model to run inference with.
initial_cache_state – The initial cache state created while converting the model.
prompt – A named array of prompt tokens. Must have a “seq” axis along which the tokens for each batch element are arranged. Should usually start with the beginning-of-sequence token. This function assumes (but does not check) that all non-padding tokens precede all padding tokens.
rng – JAX PRNGKey to use for sampling.
pad_id – Token ID that corresponds to padding.
temperature – Temperature to sample at.
max_sampling_steps – Maximum number of sampling steps to run. If None, samples until the key-value cache is full.
- Returns:
A named array of continuations of the prompt.
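Because the function assumes right-padded prompts without checking, a caller may want to validate them first. The sketch below does this for a batch given as plain Python lists of token IDs and also returns each row's unpadded length. The helper name `prompt_lengths_if_right_padded` is hypothetical and not part of penzai, and converting a `pz.nx.NamedArray` into lists is left out.

```python
from typing import List, Sequence


def prompt_lengths_if_right_padded(batch: Sequence[Sequence[int]], pad_id: int) -> List[int]:
    """Returns the number of non-padding tokens in each row.

    Raises ValueError if any row has a non-padding token after a padding
    token, i.e. if the row is not right-padded.
    """
    lengths: List[int] = []
    for row_index, row in enumerate(batch):
        # Count the leading run of non-padding tokens.
        length = 0
        while length < len(row) and row[length] != pad_id:
            length += 1
        # Everything after that run must be padding.
        if any(tok != pad_id for tok in row[length:]):
            raise ValueError(f"row {row_index} has a non-padding token after padding")
        lengths.append(length)
    return lengths


print(prompt_lengths_if_right_padded([[2, 5, 7, 0, 0], [2, 9, 0, 0, 0]], pad_id=0))
# → [3, 2]
```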