TransformerMetadata
- class penzai.models.transformer.model_parts.TransformerMetadata
Bases: object
Common axis sizes and other information for transformer models.
These values are kept on the main transformer object to simplify model transformations that depend on axis sizes or dtypes, by making it possible to infer the shape of intermediate activations in advance.
- Variables:
common_head_axes (dict[str, int]) – A map of axis names to sizes for head axes that are common to queries, keys, and values.
query_only_head_axes (dict[str, int]) – A map of axis names to sizes for head axes that are only used for queries.
embedding_dim (int) – Dimension of the embedding vectors and residual stream.
projection_dim (int) – Dimension of the query, key, and value projections.
mlp_hidden_dim (int) – Dimensionality of the hidden layer of the MLP blocks in each layer (the “neurons” axis).
vocab_size (int) – Number of tokens in the vocabulary.
activation_dtype (jax.typing.DTypeLike) – Floating dtype to use for activations and KV cache tables.
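As a rough illustration of how these fields allow activation shapes to be inferred in advance, the sketch below uses a hypothetical stand-in dataclass, `MetadataSketch`, that only mirrors the documented field names. It is not the penzai class itself. The helper functions, the axis names ("seq", "projection", "embedding", "neurons", "vocabulary", "head_groups", "query_heads") and the numeric sizes are all assumptions chosen for the example.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class MetadataSketch:
    """Hypothetical stand-in mirroring the documented fields (not the penzai class)."""

    common_head_axes: dict[str, int]
    query_only_head_axes: dict[str, int]
    embedding_dim: int
    projection_dim: int
    mlp_hidden_dim: int
    vocab_size: int
    activation_dtype: Any  # e.g. a dtype name; a plain string is used here


def query_axes(meta: MetadataSketch, seq_len: int) -> dict[str, int]:
    # Queries carry both the common head axes and the query-only head axes.
    return {
        "seq": seq_len,
        **meta.common_head_axes,
        **meta.query_only_head_axes,
        "projection": meta.projection_dim,
    }


def key_value_axes(meta: MetadataSketch, seq_len: int) -> dict[str, int]:
    # Keys and values carry only the head axes common to queries, keys, and values.
    return {
        "seq": seq_len,
        **meta.common_head_axes,
        "projection": meta.projection_dim,
    }


def residual_axes(meta: MetadataSketch, seq_len: int) -> dict[str, int]:
    # The residual stream has the embedding dimension.
    return {"seq": seq_len, "embedding": meta.embedding_dim}


def mlp_hidden_axes(meta: MetadataSketch, seq_len: int) -> dict[str, int]:
    # Hidden activations of the MLP block (the "neurons" axis).
    return {"seq": seq_len, "neurons": meta.mlp_hidden_dim}


def logit_axes(meta: MetadataSketch, seq_len: int) -> dict[str, int]:
    # One logit per vocabulary token at each position.
    return {"seq": seq_len, "vocabulary": meta.vocab_size}


# Illustrative sizes only (assumed, not taken from any real model).
meta = MetadataSketch(
    common_head_axes={"head_groups": 2},
    query_only_head_axes={"query_heads": 4},
    embedding_dim=512,
    projection_dim=64,
    mlp_hidden_dim=2048,
    vocab_size=32000,
    activation_dtype="float32",
)

q_shape = query_axes(meta, seq_len=16)
kv_shape = key_value_axes(meta, seq_len=16)
```

In this sketch, `q_shape` includes the "query_heads" axis while `kv_shape` does not, which reflects the documented split between head axes shared by queries, keys, and values and head axes used only for queries.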
Methods
__init__(common_head_axes, ...)
Attributes
common_head_axes
query_only_head_axes
embedding_dim
projection_dim
mlp_hidden_dim
vocab_size
activation_dtype