GemmaTransformerConfig
- class penzai.deprecated.v1.example_models.gemma.model_core.GemmaTransformerConfig
Bases: object
Common configuration parameters for the Gemma transformer architecture.
These are held in a single configuration object to simplify argument passing during construction of the model.
- Variables:
num_heads (int) – The number of attention heads to use.
embedding_dim (int) – Dimension of the embedding vectors and residual stream.
projection_dim (int) – Dimension of the query, key, and value projections. Usually embedding_dim // num_heads.
single_kv_head (bool) – Whether a single key head and value head should be shared across all query heads.
mlp_hidden_dim (int) – Dimensionality of the hidden layer of the MLP blocks in each layer (the “neurons” axis).
num_decoder_blocks (int) – Number of transformer decoder blocks in the model.
vocab_size (int) – Number of tokens in the vocabulary.
parameter_dtype (jax.typing.DTypeLike) – Floating dtype to use for all parameters.
activation_dtype (jax.typing.DTypeLike) – Floating dtype to use for activations and KV cache tables.
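As a reference point, here is a minimal sketch of constructing the configuration. The specific values below are illustrative (roughly Gemma-2B-shaped) rather than taken from a released checkpoint; consult the checkpoint metadata for the exact architecture parameters.

```python
import jax.numpy as jnp
from penzai.deprecated.v1.example_models.gemma import model_core

# Illustrative values only; not guaranteed to match any released Gemma model.
config = model_core.GemmaTransformerConfig(
    num_heads=8,
    embedding_dim=2048,
    projection_dim=256,        # embedding_dim // num_heads
    single_kv_head=True,       # share one key/value head across all query heads
    mlp_hidden_dim=16384,
    num_decoder_blocks=18,
    vocab_size=256_128,
    parameter_dtype=jnp.float32,
    activation_dtype=jnp.float32,
)
```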
Methods
__init__(num_heads, embedding_dim, ...)
Attributes
num_heads
embedding_dim
projection_dim
single_kv_head
mlp_hidden_dim
num_decoder_blocks
vocab_size
parameter_dtype
activation_dtype