temperature_sample_pyloop

penzai.models.transformer.simple_decoding_loop.temperature_sample_pyloop(model: sampling_mode.KVCachingTransformerLM, prompt: pz.nx.NamedArray, rng: jax.Array, temperature: float = 1.0, max_sampling_steps: int | None = None) -> pz.nx.NamedArray

Runs temperature sampling in a Python for loop.
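Temperature sampling draws each next token from softmax(logits / temperature). The sketch below shows one such sampling step in JAX. The helper names `temperature_probs` and `sample_with_temperature` are illustrative assumptions, not part of penzai, and the logits are toy values rather than real model output.

```python
import jax
import jax.numpy as jnp


def temperature_probs(logits: jax.Array, temperature: float) -> jax.Array:
  """Returns softmax(logits / temperature) over the last (vocabulary) axis."""
  return jax.nn.softmax(logits / temperature, axis=-1)


def sample_with_temperature(
    rng: jax.Array, logits: jax.Array, temperature: float = 1.0
) -> jax.Array:
  """Draws one token id per batch row from softmax(logits / temperature)."""
  # jax.random.categorical samples from the softmax of the logits it receives,
  # so scaling the logits by 1 / temperature is all that temperature sampling needs.
  return jax.random.categorical(rng, logits / temperature, axis=-1)


# Toy logits for a batch of 2 sequences over a 3-token vocabulary.
logits = jnp.array([[2.0, 1.0, 0.1], [0.0, 5.0, 0.0]])
tokens = sample_with_temperature(jax.random.PRNGKey(0), logits, temperature=0.7)
print(tokens.shape)  # (2,)
```

Temperatures above 1.0 flatten the distribution toward uniform, and temperatures close to 0 make sampling behave almost like taking the argmax of the logits.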

Parameters:
  • model – The model to run inference with, already converted to a key-value-caching sampling model (sampling_mode.KVCachingTransformerLM).

  • prompt – A named array of prompt tokens. Must have a “seq” axis along which the tokens for each batch element are arranged. Should always have at least one non-padding token along the “seq” axis (usually the beginning-of-sequence token).

  • rng – JAX PRNGKey to use for sampling.

  • temperature – Sampling temperature. Values above 1.0 flatten the next-token distribution; values below 1.0 sharpen it.

  • max_sampling_steps – Maximum number of sampling steps to run. If None, samples until the key-value cache is full.

Returns:

A named array of continuations of the prompt.