temperature_sample_pyloop
- penzai.models.transformer.simple_decoding_loop.temperature_sample_pyloop(model: sampling_mode.KVCachingTransformerLM, prompt: pz.nx.NamedArray, rng: jax.Array, temperature: float = 1.0, max_sampling_steps: int | None = None) → pz.nx.NamedArray
Runs temperature sampling in a Python for loop.
- Parameters:
model – The converted model we are running inference with.
prompt – A named array of prompt tokens. Must have a “seq” axis along which the tokens for each batch element are arranged. Should always have at least one non-padding token along the “seq” axis (usually the beginning-of-sequence token).
rng – JAX PRNGKey to use for sampling.
temperature – Temperature to sample at.
max_sampling_steps – Maximum number of sampling steps to run. If None, samples until filling up the key-value cache.
- Returns:
A named array of continuations of the prompt.
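The general shape of temperature sampling in a Python for loop can be sketched in plain JAX as follows. This is an illustrative sketch only, not penzai's implementation: `toy_next_token_logits` and `temperature_sample_loop` are hypothetical names, the toy "model" is a hand-written function, and the sketch uses plain positional arrays of shape [batch, seq] instead of named arrays and does no key-value caching. It assumes the conventional definition of temperature sampling, in which logits are divided by the temperature before drawing a token.

```python
import jax
import jax.numpy as jnp

VOCAB_SIZE = 8  # Hypothetical toy vocabulary size.


def toy_next_token_logits(tokens):
    """Hypothetical stand-in for a language model.

    Takes int tokens of shape [batch, seq] and returns next-token logits of
    shape [batch, VOCAB_SIZE], strongly favoring (last_token + 1) % VOCAB_SIZE.
    """
    favored = (tokens[:, -1] + 1) % VOCAB_SIZE
    return 5.0 * jax.nn.one_hot(favored, VOCAB_SIZE)


def temperature_sample_loop(next_logits_fn, prompt, rng, temperature=1.0,
                            max_sampling_steps=4):
    """Samples one token per step in an ordinary Python loop."""
    tokens = prompt
    for _ in range(max_sampling_steps):
        # Use a fresh subkey for every step so draws are independent.
        rng, step_key = jax.random.split(rng)
        logits = next_logits_fn(tokens)
        # Lower temperatures sharpen the distribution; higher ones flatten it.
        next_token = jax.random.categorical(
            step_key, logits / temperature, axis=-1)
        tokens = jnp.concatenate([tokens, next_token[:, None]], axis=1)
    # Return only the sampled continuation, without the prompt.
    return tokens[:, prompt.shape[1]:]


prompt = jnp.array([[0], [3]], dtype=jnp.int32)  # Two batch elements.
continuation = temperature_sample_loop(
    toy_next_token_logits, prompt, jax.random.PRNGKey(0),
    temperature=1.0, max_sampling_steps=4)
# A very low temperature makes sampling effectively greedy for this toy model.
cold = temperature_sample_loop(
    toy_next_token_logits, prompt, jax.random.PRNGKey(1),
    temperature=1e-3, max_sampling_steps=4)
```

In this sketch `max_sampling_steps` must be an explicit integer; the documented function additionally accepts None, in which case it samples until the key-value cache is full.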