TransformerMetadata
- class penzai.models.transformer.model_parts.TransformerMetadata
Bases: object
Common axis sizes and other information for transformer models.
These values are kept on the main transformer object to simplify model transformations that depend on axis sizes or dtypes, by making it possible to infer the shape of intermediate activations in advance.
- Variables:
common_head_axes (dict[str, int]) – A map of axis names to sizes for head axes that are common to queries, keys, and values.
query_only_head_axes (dict[str, int]) – A map of axis names to sizes for head axes that are only used for queries.
embedding_dim (int) – Dimension of the embedding vectors and residual stream.
projection_dim (int) – Dimension of the query, key, and value projections.
mlp_hidden_dim (int) – Dimensionality of the hidden layer of the MLP blocks in each layer (the “neurons” axis).
vocab_size (int) – Number of tokens in the vocabulary.
activation_dtype (jax.typing.DTypeLike) – Floating dtype to use for activations and KV cache tables.
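As a rough illustration of how these fields make it possible to infer activation shapes in advance, here is a minimal sketch. It does not import penzai: `MetadataSketch` is a hypothetical stand-in that only mirrors the field names documented above, and `activation_shapes`, the axis names ("batch", "seq", "embedding", "projection", "neurons", "vocab"), and all the sizes are assumptions made for the example, not penzai's actual API or conventions.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MetadataSketch:
    """Hypothetical stand-in mirroring the documented fields (not penzai's class)."""

    common_head_axes: dict[str, int]
    query_only_head_axes: dict[str, int]
    embedding_dim: int
    projection_dim: int
    mlp_hidden_dim: int
    vocab_size: int
    # Any dtype-like value; a string keeps this sketch free of dependencies.
    activation_dtype: str


def activation_shapes(meta, batch, seq):
    """Returns assumed named-axis shapes for a few intermediate activations."""
    base = {"batch": batch, "seq": seq}
    # Keys and values only carry the head axes shared with queries.
    kv_heads = dict(meta.common_head_axes)
    # Queries carry the shared head axes plus the query-only ones.
    q_heads = {**meta.common_head_axes, **meta.query_only_head_axes}
    return {
        "residual": {**base, "embedding": meta.embedding_dim},
        "queries": {**base, **q_heads, "projection": meta.projection_dim},
        "keys": {**base, **kv_heads, "projection": meta.projection_dim},
        "mlp_hidden": {**base, "neurons": meta.mlp_hidden_dim},
        "logits": {**base, "vocab": meta.vocab_size},
    }


# Illustrative sizes only: 2 shared heads, each serving 4 query-only heads.
meta = MetadataSketch(
    common_head_axes={"kv_heads": 2},
    query_only_head_axes={"q_rep": 4},
    embedding_dim=64,
    projection_dim=16,
    mlp_hidden_dim=256,
    vocab_size=1000,
    activation_dtype="float32",
)
shapes = activation_shapes(meta, batch=3, seq=5)
print(shapes["queries"])
print(shapes["keys"])
```

Splitting the head axes into a common set and a query-only set lets the same description cover layouts where several query heads share one key/value head: the query shape includes both sets, while the key and value shapes include only the common ones.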
Methods
__init__(common_head_axes, ...)
Attributes
common_head_axes
query_only_head_axes
embedding_dim
projection_dim
mlp_hidden_dim
vocab_size
activation_dtype