model_parts


Core components of a Transformer language model.

Specific instantiations of the TransformerLM model will use the following axis naming conventions:

  • “seq” is the temporal axis of the token sequence, i.e. the axis along which the prompt tokens are laid out. In an attention matrix, it specifically refers to the query token(s) (the ones we are currently processing).

  • “embedding” is the axis for embedding vectors and the residual stream.

  • “projection” is the axis for query, key, and value head projection vectors, i.e. the axis along which query-key dot products are computed and the axis of the value vectors retrieved by each attention head.

  • “heads”, “head_groups”, and “query_heads” are axes for attention heads, depending on whether full multi-head, multi-query, or grouped-query attention is used.

    • In full multi-head attention, the “heads” axis appears in queries, keys, and values.

    • In multi-query attention, the “query_heads” axis appears in queries, and keys and values do not have a heads axis.

    • In grouped-query attention, the “head_groups” axis appears in queries, keys, and values, and the “query_heads” axis appears in queries only.

  • “kv_seq” is a temporary copy of the “seq” axis that represents the position of the keys and values in an attention matrix.

  • “neurons” is the axis for the neurons in the MLP blocks, which have an activation function applied elementwise and therefore have a privileged basis.
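To make the head-axis conventions concrete, here is a minimal NumPy sketch of how query-key attention logits could be computed for each attention variant. It is not Penzai's implementation: the function names are hypothetical, and because plain NumPy arrays have no axis names, each named axis is mapped to a positional dimension as noted in the comments. In every variant the contraction runs over “projection”, and the result carries one “seq” dimension (queries) and one “kv_seq” dimension (keys).

```python
import numpy as np

# Hypothetical helpers; einsum letters stand in for named axes:
#   s = "seq", k = "kv_seq", p = "projection",
#   h = "heads", g = "head_groups", q = "query_heads".

def mha_logits(queries, keys):
    # Full multi-head: "heads" appears in both queries and keys.
    # queries: [seq, heads, projection]; keys: [kv_seq, heads, projection]
    # returns: [heads, seq, kv_seq]
    return np.einsum("shp,khp->hsk", queries, keys)

def mqa_logits(queries, keys):
    # Multi-query: only queries have "query_heads"; keys have no heads axis.
    # queries: [seq, query_heads, projection]; keys: [kv_seq, projection]
    # returns: [query_heads, seq, kv_seq]
    return np.einsum("sqp,kp->qsk", queries, keys)

def gqa_logits(queries, keys):
    # Grouped-query: "head_groups" is shared; "query_heads" is query-only.
    # queries: [seq, head_groups, query_heads, projection]
    # keys:    [kv_seq, head_groups, projection]
    # returns: [head_groups, query_heads, seq, kv_seq]
    return np.einsum("sgqp,kgp->gqsk", queries, keys)
```

Grouped-query attention reduces to full multi-head attention when each key group is repeated once per query head, and to multi-query attention when there is a single head group.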

Additionally, they use the following side input names:

  • “token_positions” is the name of the side input that provides the position of each token for the purposes of positional embeddings and causal attention masking. A position of -1 indicates a padding token.
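A causal mask can be derived from “token_positions” alone. The NumPy sketch below assumes one plausible rule: a query may attend to a key only when neither is padding (position -1) and the key's position is not later than the query's. The function name `causal_mask` is hypothetical, and Penzai's actual masking code may differ in details such as batching. Because the rule compares positions rather than array indices, it handles right padding and left padding the same way.

```python
import numpy as np

def causal_mask(token_positions):
    """Boolean [seq, kv_seq] mask built from 1-D token positions (hypothetical helper).

    Assumes -1 marks padding, as described for the "token_positions" side input.
    """
    query_pos = token_positions[:, None]  # broadcast along "seq"
    key_pos = token_positions[None, :]    # broadcast along "kv_seq"
    not_future = key_pos <= query_pos     # causal: no attending to later tokens
    not_padding = (query_pos >= 0) & (key_pos >= 0)
    return not_future & not_padding
```

For example, with left padding `[-1, -1, 0, 1]`, the first real token can attend only to itself, and both padding rows are fully masked.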

The KV caching logic is defined in the separate module sampling_mode.

Classes

TransformerBlock

Informatively named Sequential subclass for the main transformer blocks.

TransformerFeedForward

Informatively named Sequential subclass for feedforward/MLP layers.

TransformerLM

Top-level transformer decoder wrapper.

TransformerMetadata

Common axis sizes and other information for transformer models.
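TransformerBlock and TransformerFeedForward are both described as Sequential subclasses, i.e. containers that apply their sublayers in order. The NumPy sketch below illustrates that idea for a feedforward block. `sequential` and `make_feedforward` are hypothetical helpers, not Penzai APIs, and the ReLU activation is an assumption (real models may use a different elementwise nonlinearity). Because the activation acts elementwise on the “neurons” axis, that axis has the privileged basis described in the axis conventions.

```python
import numpy as np

def sequential(*layers):
    """Return a function that applies `layers` in order (hypothetical helper)."""
    def apply(x):
        for layer in layers:
            x = layer(x)
        return x
    return apply

def relu(x):
    # Elementwise nonlinearity; ReLU is an assumption for this sketch.
    return np.maximum(x, 0.0)

def make_feedforward(w_in, w_out):
    # w_in: [embedding, neurons]; w_out: [neurons, embedding]
    return sequential(
        lambda x: x @ w_in,   # project "embedding" -> "neurons"
        relu,                 # applied per neuron, so this basis is privileged
        lambda x: x @ w_out,  # project "neurons" -> "embedding"
    )
```

The output has the same “embedding” size as the input, so the block's result can be added back into the residual stream.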