sampling_mode


Sampling-mode adapters for TransformerLM models.

This module contains the KV-caching sampling mode of the base TransformerLM model. This mode is intended to be hot-swapped in for the main TransformerLM implementation: you should generally start by loading a model_parts.TransformerLM and then convert it to a KVCachingTransformerLM using KVCachingTransformerLM.from_uncached.
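To illustrate why a KV cache gives the same result as the uncached model, here is a minimal single-head sketch in NumPy. It is not this module's implementation. The names causal_attention and KVCache are hypothetical, and a real TransformerLM keeps a cache per layer and per head. The sketch stores each token's key and value once, then attends from each new query over the cached prefix. It checks that this reproduces full causal attention computed over the whole sequence.

```python
import numpy as np


def causal_attention(q, k, v):
    """Full (uncached) causal attention; q, k, v have shape [seq, dim]."""
    seq, dim = q.shape
    scores = q @ k.T / np.sqrt(dim)                     # [seq, seq]
    mask = np.tril(np.ones((seq, seq), dtype=bool))     # token t sees tokens <= t
    scores = np.where(mask, scores, -np.inf)            # hide future tokens
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


class KVCache:
    """Hypothetical fixed-capacity key/value cache for a single attention head."""

    def __init__(self, max_len, dim):
        self.keys = np.zeros((max_len, dim))
        self.values = np.zeros((max_len, dim))
        self.end_index = 0  # number of filled slots (cf. "cache_end_index")

    def step(self, q, k, v):
        """Append one token's key/value, then attend over all cached tokens."""
        if self.end_index >= self.keys.shape[0]:
            raise ValueError("cache is full")
        self.keys[self.end_index] = k
        self.values[self.end_index] = v
        self.end_index += 1
        ks = self.keys[: self.end_index]
        vs = self.values[: self.end_index]
        scores = ks @ q / np.sqrt(q.shape[-1])          # [end_index]
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        return weights @ vs
```

Each cached step attends over only the filled prefix. It does not recompute keys and values for earlier tokens, which is the saving that motivates a separate sampling mode.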

The layers defined here follow the same conventions documented in the module docstring for model_parts. In addition:

  • Where applicable, “kv_token_positions” is the name of the side input that provides the position of each token for the purposes of positional embeddings.

  • Where applicable, “cache_end_index” is the name of the side input that identifies the current length of the key/value cache state.
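The two side inputs described above fit together roughly as follows. This is a minimal sketch under stated assumptions, not how this module computes them. It assumes there is no padding, so cache slot i holds the token at position i. It also assumes each step appends its new tokens starting at cache_end_index. The helper name side_inputs_for_step is hypothetical.

```python
import numpy as np


def side_inputs_for_step(cache_end_index, num_new_tokens, cache_len):
    """Hypothetical helper: positions and attention mask for one cached step.

    Assumes no padding, so cache slot i holds the token at position i.
    """
    next_end_index = cache_end_index + num_new_tokens
    if next_end_index > cache_len:
        raise ValueError("not enough room left in the cache")
    # Positions of the tokens being appended in this step.
    new_positions = cache_end_index + np.arange(num_new_tokens)
    # Position of the token stored in every cache slot (cf. "kv_token_positions").
    kv_token_positions = np.arange(cache_len)
    # A new token may attend to a slot only if that slot's position is not in
    # its future; slots past the new end of the cache are masked out as well.
    mask = kv_token_positions[None, :] <= new_positions[:, None]
    return new_positions, kv_token_positions, mask, next_end_index
```

For example, with cache_end_index = 3, two new tokens, and a 6-slot cache, the new tokens get positions 3 and 4. The cache end advances to 5.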

Classes

KVCachingTransformerLM

Top-level transformer in (stateful) cached autoregressive sampling mode.
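"Stateful" here means the model carries its cache between calls. A sampling loop therefore feeds the prompt once (prefill) and then only one new token per step (decode). The sketch below shows that loop pattern with a toy bigram "model". ToyCachedLM and greedy_sample are hypothetical and are not this module's API. The toy cache is just the list of tokens seen so far, where a real transformer would store per-layer keys and values.

```python
import numpy as np


class ToyCachedLM:
    """Hypothetical stand-in for a stateful cached LM (not this module's API)."""

    def __init__(self, logit_table, cache_len):
        self.logit_table = logit_table  # [vocab, vocab] bigram logits
        self.cache_len = cache_len
        self.cache = []  # len(self.cache) plays the role of cache_end_index

    def __call__(self, tokens):
        """Process new tokens, grow the cache, return logits [len(tokens), vocab]."""
        if len(self.cache) + len(tokens) > self.cache_len:
            raise ValueError("cache is full")
        self.cache.extend(tokens)  # state update: the cache grows
        return self.logit_table[np.asarray(tokens)]


def greedy_sample(model, prompt, num_steps):
    """Greedy decoding: prefill the prompt once, then feed one token per step."""
    logits = model(prompt)  # prefill
    out = []
    for _ in range(num_steps):
        next_token = int(np.argmax(logits[-1]))
        out.append(next_token)
        logits = model([next_token])  # decode: only the newest token
    return out
```

With a table mapping 0 to 1, 1 to 2, and 2 to 0, greedy_sample(model, [0], 4) yields [1, 2, 0, 1]. The model's cache then holds five tokens.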