TransformerMetadata

class penzai.models.transformer.model_parts.TransformerMetadata

Bases: object

Common axis sizes and other information for transformer models.

These values are stored on the top-level transformer object so that model transformations that depend on axis sizes or dtypes can infer the shapes of intermediate activations in advance.
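To illustrate the kind of shape inference this enables, here is a minimal sketch using a hypothetical stand-in dataclass, `MetadataSketch`, that mirrors the documented fields (it is not the penzai class). The helper `activation_shapes` is also hypothetical, and the axis names `batch`, `seq`, `projection`, `embedding`, `neurons`, and `vocabulary` are assumptions chosen for illustration. Only the field values drive the computed sizes.

```python
import dataclasses
from typing import Any


# Hypothetical stand-in mirroring the documented fields of TransformerMetadata.
@dataclasses.dataclass
class MetadataSketch:
    common_head_axes: dict[str, int]
    query_only_head_axes: dict[str, int]
    embedding_dim: int
    projection_dim: int
    mlp_hidden_dim: int
    vocab_size: int
    activation_dtype: Any


def activation_shapes(meta: MetadataSketch, batch: int, seq: int) -> dict[str, dict[str, int]]:
    """Infers named-axis shapes of intermediate activations (axis names are assumed)."""
    prefix = {"batch": batch, "seq": seq}
    # Keys and values carry only the head axes shared with queries.
    kv = {**prefix, **meta.common_head_axes, "projection": meta.projection_dim}
    return {
        "residual": {**prefix, "embedding": meta.embedding_dim},
        # Queries additionally carry the query-only head axes.
        "queries": {**kv, **meta.query_only_head_axes},
        "keys": dict(kv),
        "values": dict(kv),
        "mlp_hidden": {**prefix, "neurons": meta.mlp_hidden_dim},
        "logits": {**prefix, "vocabulary": meta.vocab_size},
    }


meta = MetadataSketch(
    common_head_axes={"heads": 8},
    query_only_head_axes={},
    embedding_dim=512,
    projection_dim=64,
    mlp_hidden_dim=2048,
    vocab_size=32000,
    activation_dtype="float32",
)
shapes = activation_shapes(meta, batch=2, seq=16)
print(shapes["queries"])  # {'batch': 2, 'seq': 16, 'heads': 8, 'projection': 64}
```

Because every size comes from the metadata, a transformation can allocate buffers or check shapes without running the model first.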

Variables:
  • common_head_axes (dict[str, int]) – A mapping from axis names to sizes for head axes shared by queries, keys, and values.

  • query_only_head_axes (dict[str, int]) – A mapping from axis names to sizes for head axes used only by queries.

  • embedding_dim (int) – Dimension of the embedding vectors and residual stream.

  • projection_dim (int) – Dimension of the query, key, and value projections.

  • mlp_hidden_dim (int) – Dimensionality of the hidden layer of the MLP blocks in each layer (the “neurons” axis).

  • vocab_size (int) – Number of tokens in the vocabulary.

  • activation_dtype (jax.typing.DTypeLike) – Floating dtype to use for activations and KV cache tables.
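The split between `common_head_axes` and `query_only_head_axes` lets one metadata layout describe several attention variants. The sketch below uses a hypothetical stand-in dataclass, not the penzai class. It assumes illustrative axis names (`heads`, `query_heads`, `head_groups`, `query_heads_per_group`), a hypothetical `num_layers` argument (not a metadata field), and a small dtype-to-bytes table as a simplification. Real JAX code would read the item size from the dtype itself. With those assumptions, it counts query and key/value heads and estimates KV cache size from `projection_dim` and `activation_dtype`.

```python
import dataclasses
import math
from typing import Any


# Hypothetical stand-in mirroring the documented fields of TransformerMetadata.
@dataclasses.dataclass
class MetadataSketch:
    common_head_axes: dict[str, int]
    query_only_head_axes: dict[str, int]
    embedding_dim: int
    projection_dim: int
    mlp_hidden_dim: int
    vocab_size: int
    activation_dtype: Any


# Simplified bytes-per-element table for this sketch only.
_DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2}


def num_query_heads(meta: MetadataSketch) -> int:
    # Queries are indexed by both the shared and the query-only head axes.
    return math.prod(meta.common_head_axes.values()) * math.prod(
        meta.query_only_head_axes.values()
    )


def num_kv_heads(meta: MetadataSketch) -> int:
    # Keys and values are indexed only by the shared head axes.
    return math.prod(meta.common_head_axes.values())


def kv_cache_bytes(meta: MetadataSketch, num_layers: int, batch: int, max_len: int) -> int:
    # One key table and one value table per layer, stored in activation_dtype.
    per_table = batch * max_len * num_kv_heads(meta) * meta.projection_dim
    return 2 * num_layers * per_table * _DTYPE_BYTES[meta.activation_dtype]


def make(common: dict[str, int], query_only: dict[str, int]) -> MetadataSketch:
    return MetadataSketch(common, query_only, embedding_dim=2048, projection_dim=128,
                          mlp_hidden_dim=8192, vocab_size=32000, activation_dtype="bfloat16")


mha = make({"heads": 16}, {})                                  # multi-head attention
mqa = make({}, {"query_heads": 16})                            # multi-query attention
gqa = make({"head_groups": 4}, {"query_heads_per_group": 4})   # grouped-query attention
for name, m in [("MHA", mha), ("MQA", mqa), ("GQA", gqa)]:
    print(name, num_query_heads(m), num_kv_heads(m),
          kv_cache_bytes(m, num_layers=24, batch=1, max_len=4096))
# MHA 16 16 805306368
# MQA 16 1 50331648
# GQA 16 4 201326592
```

All three layouts have 16 query heads, but moving head axes from `common_head_axes` into `query_only_head_axes` shrinks the KV cache in proportion to the number of key/value heads.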

Methods

__init__(common_head_axes, ...)

Attributes

common_head_axes

query_only_head_axes

embedding_dim

projection_dim

mlp_hidden_dim

vocab_size

activation_dtype