temperature_sample_pyloop


penzai.example_models.gemma.simple_decoding_loop.temperature_sample_pyloop(model: sampling_mode.GemmaKVCachingTransformer, initial_cache_state: sampling_mode.GemmaKVCachingState, prompt: pz.nx.NamedArray, rng: jax.Array, pad_id: int, temperature: float = 1.0, max_sampling_steps: int | None = None) → pz.nx.NamedArray

Runs temperature sampling in a Python for loop.
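As a rough illustration of what "temperature sampling in a Python for loop" means, the sketch below runs one sampling step per loop iteration, splitting the PRNG key each time and drawing the next token from softmax(logits / temperature). It is a minimal sketch under stated assumptions, not the penzai implementation: `toy_logits_fn` and `temperature_sample_toy` are hypothetical names, the toy "model" recomputes logits from the full token sequence instead of using a key-value cache, padding is ignored, and plain `jax.Array`s stand in for `pz.nx.NamedArray`.

```python
import jax
import jax.numpy as jnp


def sample_next_token(rng: jax.Array, logits: jax.Array, temperature: float) -> jax.Array:
    """Draws one token per batch element from softmax(logits / temperature)."""
    # jax.random.categorical samples proportionally to exp(logits), so dividing
    # by the temperature sharpens (T < 1) or flattens (T > 1) the distribution.
    return jax.random.categorical(rng, logits / temperature, axis=-1)


def toy_logits_fn(tokens: jax.Array, vocab_size: int = 8) -> jax.Array:
    """Hypothetical stand-in for the model: strongly favors (last_token + 1) mod vocab."""
    last = tokens[:, -1]
    return 5.0 * jax.nn.one_hot((last + 1) % vocab_size, vocab_size)


def temperature_sample_toy(
    prompt: jax.Array,
    rng: jax.Array,
    temperature: float = 1.0,
    max_sampling_steps: int = 4,
) -> jax.Array:
    """Hypothetical Python-loop sampler; returns only the newly sampled tokens."""
    tokens = prompt  # [batch, seq]
    sampled = []
    for _ in range(max_sampling_steps):
        # Use a fresh subkey at every step so each draw is independent.
        rng, step_rng = jax.random.split(rng)
        logits = toy_logits_fn(tokens)  # [batch, vocab]
        next_token = sample_next_token(step_rng, logits, temperature)  # [batch]
        sampled.append(next_token)
        # Without a KV cache, the whole sequence is fed back in on the next step.
        tokens = jnp.concatenate([tokens, next_token[:, None]], axis=1)
    return jnp.stack(sampled, axis=1)  # [batch, max_sampling_steps]


prompt = jnp.array([[0], [3]], dtype=jnp.int32)
out = temperature_sample_toy(prompt, jax.random.PRNGKey(0), temperature=1e-3)
```

At a very low temperature the toy model's preferred token dominates, so the loop becomes effectively greedy; at higher temperatures the other tokens are drawn more often. Keeping the loop in Python makes each step easy to inspect, at the cost of dispatching one step at a time rather than compiling the whole loop.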

Parameters:
  • model – The converted model we are running inference with.

  • initial_cache_state – The initial cache state created while converting the model.

  • prompt – A named array of prompt tokens. Must have a “seq” axis along which the tokens for each batch element are arranged. Should usually start with the beginning-of-sequence token. This function assumes (but does not check) that all non-padding tokens precede all padding tokens.

  • rng – JAX PRNGKey to use for sampling.

  • pad_id – Token ID that corresponds to padding.

  • temperature – Sampling temperature. Lower values make sampling closer to greedy; higher values make it more random.

  • max_sampling_steps – Maximum number of sampling steps to run. If None, sampling continues until the key-value cache is full.
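Because the function does not verify that padding in `prompt` is trailing, a caller may want to check it beforehand. The sketch below is one way to do that on a plain `[batch, seq]` integer array; `padding_is_trailing` is a hypothetical helper, not part of penzai, and converting from `pz.nx.NamedArray` to a positional array is left out.

```python
import jax.numpy as jnp


def padding_is_trailing(tokens: jnp.ndarray, pad_id: int) -> bool:
    """Hypothetical check: True if no non-padding token follows a padding token."""
    is_pad = tokens == pad_id  # [batch, seq] booleans
    # seen_pad[b, i] is True once row b has contained padding at or before i.
    seen_pad = jnp.cumsum(is_pad, axis=-1) > 0
    # Padding is trailing exactly when every position after the first pad is a pad.
    return bool(jnp.all(is_pad == seen_pad))


ok = padding_is_trailing(jnp.array([[2, 5, 0, 0], [2, 0, 0, 0]]), pad_id=0)
bad = padding_is_trailing(jnp.array([[2, 0, 5, 0]]), pad_id=0)
```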

Returns:

A named array of continuations of the prompt.