LlamalikeTransformerConfig
- class penzai.models.transformer.variants.llamalike_common.LlamalikeTransformerConfig
Bases: object
Common configuration parameters for a “llama-like” transformer.
This config encompasses the parameters for the Llama, Mistral, and Gemma model families.
These are held in a single configuration object to simplify argument passing during construction of the model.
- Variables:
num_kv_heads (int) – The number of key-value attention heads or head groups.
query_head_multiplier (int) – The number of query heads for each KV head.
embedding_dim (int) – Dimension of the embedding vectors and residual stream.
projection_dim (int) – Dimension of the query, key, and value projections. Usually embedding_dim // num_heads, where num_heads = num_kv_heads * query_head_multiplier is the total number of query heads.
mlp_hidden_dim (int) – Dimensionality of the hidden layer of the MLP blocks in each layer (the “neurons” axis).
num_decoder_blocks (int) – Number of transformer decoder blocks in the model.
vocab_size (int) – Number of tokens in the vocabulary.
mlp_variant (Literal['geglu_approx', 'swiglu']) – Gated linear unit variant for MLPs.
tie_embedder_and_logits (bool) – Whether to tie the weights of the input token embedding and output logit layers. If True, also scales down input token embeddings by sqrt(embedding_dim). (This is used by Gemma.)
rope_wavelength (float) – Wavelength for RoPE layers.
rms_norm_eps (float) – Epsilon for RMSNorm layers.
attention_type (AttentionType | Sequence[AttentionType]) – A single attention type or a sequence of per-layer attention types. If a sequence, its length should evenly divide the number of decoder blocks; the sequence is then repeated to match the number of blocks.
use_post_attn_norm (bool) – Whether to add a normalization layer after the attention block.
use_post_ffw_norm (bool) – Whether to add a normalization layer after the feedforward block.
final_logit_softcap (float | None) – If not None, used as the tanh soft cap for the final transformer logits.
attn_logits_soft_cap (float | None) – If not None, used as the tanh soft cap for the attention logits.
query_scaling_factor (float | Literal['default']) – Scaling factor for the query vectors. If “default”, 1 / sqrt(projection_dim) is used.
parameter_dtype (jax.typing.DTypeLike) – Floating dtype to use for all parameters.
activation_dtype (jax.typing.DTypeLike) – Floating dtype to use for activations and KV cache tables.
use_layer_stack (bool) – Whether to stack the blocks together using a LayerStack.
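The head-count fields above fit together as follows: each of the num_kv_heads key-value heads is shared by query_head_multiplier query heads, so the total number of query heads is their product, and the usual projection_dim is embedding_dim divided by that total. The sketch below is plain Python arithmetic, not penzai code; the helper name `head_layout` and the example values are hypothetical.

```python
def head_layout(num_kv_heads: int, query_head_multiplier: int, embedding_dim: int) -> tuple[int, int]:
    """Hypothetical helper (not part of penzai): returns (num_query_heads, usual projection_dim)."""
    # Each KV head (or head group) is shared by query_head_multiplier query heads.
    num_query_heads = num_kv_heads * query_head_multiplier
    # "Usually embedding_dim // num_heads"; the real projection_dim is a separate
    # config field and is not required to equal this quotient.
    usual_projection_dim = embedding_dim // num_query_heads
    return num_query_heads, usual_projection_dim


# Illustrative values only.
print(head_layout(num_kv_heads=32, query_head_multiplier=1, embedding_dim=4096))  # (32, 128): multi-head attention
print(head_layout(num_kv_heads=8, query_head_multiplier=4, embedding_dim=4096))   # (32, 128): grouped-query attention
print(head_layout(num_kv_heads=1, query_head_multiplier=32, embedding_dim=4096))  # (32, 128): multi-query attention
```

All three settings have the same number of query heads; they differ in how many distinct key-value heads are stored, which is what determines the size of the KV cache.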
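The two mlp_variant options differ in the nonlinearity applied to the gate branch of the gated linear unit. As a minimal sketch, assuming the standard formulations (SwiGLU gates with SiLU; “geglu_approx” is read here as GeGLU using the tanh approximation of GELU, which the name suggests but this page does not spell out), the elementwise gating step looks like this. The function names are hypothetical, and the input and output projections around this step are omitted.

```python
import math


def silu(x: float) -> float:
    # SiLU / swish: x * sigmoid(x), written to avoid overflow for large |x|.
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def gelu_tanh_approx(x: float) -> float:
    # Tanh approximation of GELU (assumed meaning of "geglu_approx").
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def gated_unit(gate: list[float], up: list[float], mlp_variant: str) -> list[float]:
    """Elementwise activation(gate) * up for one hidden vector of size mlp_hidden_dim."""
    activation = {"swiglu": silu, "geglu_approx": gelu_tanh_approx}[mlp_variant]
    return [activation(g) * u for g, u in zip(gate, up)]


print(gated_unit([1.0, -2.0], [3.0, 3.0], "swiglu"))
print(gated_unit([1.0, -2.0], [3.0, 3.0], "geglu_approx"))
```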
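rope_wavelength plays the role of the base in the usual rotary position embedding formula (often 10,000): feature pair i of a projection_dim-dimensional query or key at position p is rotated by the angle p / rope_wavelength^(2i / projection_dim). The sketch below assumes that standard formula and a split-half pairing (feature i paired with feature i + projection_dim / 2); pairing conventions differ between implementations, and the function names are hypothetical rather than penzai's own.

```python
import math


def rope_angles(position: int, projection_dim: int, rope_wavelength: float) -> list[float]:
    """Rotation angle for each of the projection_dim // 2 feature pairs."""
    return [
        position / rope_wavelength ** (2 * i / projection_dim)
        for i in range(projection_dim // 2)
    ]


def apply_rope(vec: list[float], position: int, rope_wavelength: float = 10_000.0) -> list[float]:
    """Rotates feature i together with feature i + d/2 (split-half pairing, an assumption)."""
    d = len(vec)
    if d % 2:
        raise ValueError("projection_dim must be even for RoPE")
    half = d // 2
    first, second = vec[:half], vec[half:]
    out_first, out_second = [], []
    for x, y, theta in zip(first, second, rope_angles(position, d, rope_wavelength)):
        c, s = math.cos(theta), math.sin(theta)
        # 2-D rotation of the pair (x, y) by theta.
        out_first.append(x * c - y * s)
        out_second.append(x * s + y * c)
    return out_first + out_second


print(apply_rope([1.0, 2.0, 3.0, 4.0], position=0))  # [1.0, 2.0, 3.0, 4.0]: position 0 is the identity
```

Because queries and keys are both rotated, the dot product between a query at position m and a key at position n depends only on m - n; a larger rope_wavelength makes the slowest-rotating pairs rotate even more slowly.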
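rms_norm_eps is the small constant added inside the square root of each RMSNorm layer so that a near-zero activation vector does not cause a division by zero. A minimal sketch of the textbook RMSNorm computation, assuming a learned per-feature scale (some checkpoints parameterize the scale differently, so this is not necessarily penzai's exact layer); unlike LayerNorm, no mean is subtracted:

```python
import math


def rms_norm(x: list[float], scale: list[float], rms_norm_eps: float) -> list[float]:
    # Mean of squared features, stabilized by epsilon before the square root.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + rms_norm_eps)
    # Normalize, then apply the learned per-feature scale.
    return [v * inv_rms * s for v, s in zip(x, scale)]


print(rms_norm([2.0, -2.0, 2.0, -2.0], [1.0] * 4, rms_norm_eps=0.0))  # [1.0, -1.0, 1.0, -1.0]
print(rms_norm([0.0, 0.0], [1.0, 1.0], rms_norm_eps=1e-6))  # [0.0, 0.0] instead of ZeroDivisionError
```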
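The rule for a sequence-valued attention_type (its length must divide num_decoder_blocks, and the pattern is tiled across the blocks) can be sketched as follows. The strings "local" and "global" are hypothetical stand-ins for actual AttentionType values, and expand_attention_types is an illustrative helper, not a penzai function.

```python
def expand_attention_types(attention_type, num_decoder_blocks: int) -> list:
    """Returns one attention type per decoder block (illustrative, not penzai's code)."""
    # A list/tuple is a per-layer pattern; anything else is a single type for every block.
    # Checking list/tuple explicitly keeps the string stand-ins from being treated as sequences.
    if isinstance(attention_type, (list, tuple)):
        pattern = list(attention_type)
        if not pattern or num_decoder_blocks % len(pattern) != 0:
            raise ValueError(
                f"pattern length {len(pattern)} must evenly divide {num_decoder_blocks} blocks"
            )
        # Repeat the pattern until it covers every block.
        return pattern * (num_decoder_blocks // len(pattern))
    return [attention_type] * num_decoder_blocks


print(expand_attention_types(("local", "global"), 6))
# ['local', 'global', 'local', 'global', 'local', 'global']
print(expand_attention_types("global", 3))
# ['global', 'global', 'global']
```

A two-element pattern like this expresses models that alternate between local (sliding-window) and global attention layers without listing every layer explicitly.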
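query_scaling_factor and the two soft-cap fields all act on logits. The sketch below computes a single attention logit as a scaled dot product and then applies a tanh soft cap of the form cap * tanh(x / cap), which leaves small values nearly unchanged while bounding their magnitude by cap; the same soft_cap function would apply to the final vocabulary logits when final_logit_softcap is set. The function names are hypothetical, and this is plain Python rather than penzai's implementation.

```python
import math
from typing import Optional, Sequence, Union


def soft_cap(x: float, cap: Optional[float]) -> float:
    # tanh soft cap: roughly linear for |x| << cap, saturating toward +/- cap.
    if cap is None:
        return x
    return cap * math.tanh(x / cap)


def attention_logit(
    query: Sequence[float],
    key: Sequence[float],
    projection_dim: int,
    query_scaling_factor: Union[float, str] = "default",
    attn_logits_soft_cap: Optional[float] = None,
) -> float:
    # "default" means 1 / sqrt(projection_dim), as described for query_scaling_factor.
    if query_scaling_factor == "default":
        scale = 1.0 / math.sqrt(projection_dim)
    else:
        scale = float(query_scaling_factor)
    # Scale the query vector, then take its dot product with the key.
    logit = sum((q * scale) * k for q, k in zip(query, key))
    return soft_cap(logit, attn_logits_soft_cap)


print(attention_logit([1.0] * 4, [1.0] * 4, projection_dim=4))  # 2.0
print(soft_cap(100.0, 50.0))  # about 48.2, below the cap of 50
```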
Methods
__init__(*, num_kv_heads, ...[, ...])
Attributes
attention_type
attn_logits_soft_cap
final_logit_softcap
query_scaling_factor
rms_norm_eps
rope_wavelength
use_layer_stack
use_post_attn_norm
use_post_ffw_norm
num_kv_heads
query_head_multiplier
embedding_dim
projection_dim
mlp_hidden_dim
num_decoder_blocks
vocab_size
mlp_variant
tie_embedder_and_logits