LlamalikeTransformerConfig

class penzai.models.transformer.variants.llamalike_common.LlamalikeTransformerConfig

Bases: object

Common configuration parameters for a “llama-like” transformer.

This config encompasses the parameters for the Llama, Mistral, and Gemma model families.

These are held in a single configuration object to simplify argument passing during construction of the model.

Variables:
  • num_kv_heads (int) – The number of key-value attention heads or head groups.

  • query_head_multiplier (int) – The number of query heads for each KV head.

  • embedding_dim (int) – Dimension of the embedding vectors and residual stream.

  • projection_dim (int) – Dimension of the query, key, and value projections. Usually embedding_dim // num_heads, where num_heads is the total number of query heads (num_kv_heads * query_head_multiplier).

  • mlp_hidden_dim (int) – Dimensionality of the hidden layer of the MLP blocks in each layer (the “neurons” axis).

  • num_decoder_blocks (int) – Number of transformer decoder blocks in the model.

  • vocab_size (int) – Number of tokens in the vocabulary.

  • mlp_variant (Literal['geglu_exact', 'geglu_approx', 'swiglu']) – Gated linear unit variant for MLPs.

  • tie_embedder_and_logits (bool) – Whether to tie the weights of the input token embedding and output logit layers. If True, also scales down input token embeddings by sqrt(embedding_dim). (This is used by Gemma.)

  • rope_wavelength (float) – Wavelength for global RoPE layers (and for local RoPE layers if local_rope_wavelength is not set).

  • rms_norm_eps (float) – Epsilon for RMSNorm layers.

  • attention_type (AttentionType | Sequence[AttentionType]) – A single attention type, or a sequence of per-layer attention types. If a sequence, its length must evenly divide the number of decoder blocks, and the sequence is repeated to cover all of the blocks.

  • use_post_attn_norm (bool) – Whether to add a normalization layer after the attention block.

  • use_post_ffw_norm (bool) – Whether to add a normalization layer after the feedforward block.

  • final_logit_softcap (float | None) – If not None, used as the tanh soft cap for the final transformer logits.

  • attn_logits_soft_cap (float | None) – If not None, used as the tanh soft cap for the attention logits.

  • query_scaling_factor (float | Literal['default']) – Scaling factor for the query vectors. If “default”, defaults to 1 / sqrt(projection_dim).

  • parameter_dtype (jax.typing.DTypeLike) – Floating dtype to use for all parameters.

  • activation_dtype (jax.typing.DTypeLike) – Floating dtype to use for activations and KV cache tables.

  • use_layer_stack (bool) – Whether to stack the blocks together using a LayerStack.

  • use_qk_norm (bool) – Whether to use QK normalization.

  • global_scale_factor (float | None) – Scale factor for the global RoPE layers. (The scale factor for the local RoPE layers defaults to 1.0.)

  • local_rope_wavelength (float | None) – Wavelength for the local RoPE layers. If None, local RoPE layers will use the same wavelength as global RoPE layers (config.rope_wavelength).
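Several of the fields above describe small computations: repeating an attention_type sequence across the decoder blocks, resolving the "default" query scaling factor, and applying a tanh soft cap. The following is a minimal sketch of those rules in plain Python, not penzai's implementation. The helper names are hypothetical, the strings "local" and "global" are placeholders for real AttentionType objects, and the soft cap is assumed to take the common form cap * tanh(x / cap).

```python
import math


def expand_attention_types(attention_type, num_decoder_blocks):
    """Repeat a per-layer pattern so it covers every decoder block.

    Hypothetical helper: a single value is treated as a pattern of length 1.
    """
    if isinstance(attention_type, (list, tuple)):
        pattern = list(attention_type)
    else:
        pattern = [attention_type]
    if num_decoder_blocks % len(pattern) != 0:
        raise ValueError(
            "pattern length must evenly divide the number of decoder blocks"
        )
    return pattern * (num_decoder_blocks // len(pattern))


def resolve_query_scale(query_scaling_factor, projection_dim):
    """Map "default" to 1 / sqrt(projection_dim); pass numbers through."""
    if query_scaling_factor == "default":
        return 1.0 / math.sqrt(projection_dim)
    return float(query_scaling_factor)


def tanh_soft_cap(logit, cap):
    """Squash a logit into (-cap, cap); a cap of None disables capping.

    Assumes the common form cap * tanh(logit / cap).
    """
    if cap is None:
        return logit
    return cap * math.tanh(logit / cap)


# A two-entry pattern repeated over six blocks alternates local and global.
layer_types = expand_attention_types(["local", "global"], 6)
query_scale = resolve_query_scale("default", projection_dim=256)
capped = tanh_soft_cap(1000.0, cap=30.0)
```

With these illustrative inputs, a pattern of length 2 is repeated three times to fill six blocks, and a very large logit saturates near the cap value.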

Methods

__init__(*, num_kv_heads, ...[, ...])

Attributes

attention_type

attn_logits_soft_cap

final_logit_softcap

global_scale_factor

local_rope_wavelength

query_scaling_factor

rms_norm_eps

rope_wavelength

use_layer_stack

use_post_attn_norm

use_post_ffw_norm

use_qk_norm

num_kv_heads

query_head_multiplier

embedding_dim

projection_dim

mlp_hidden_dim

num_decoder_blocks

vocab_size

mlp_variant

tie_embedder_and_logits

activation_dtype

alias of float32

parameter_dtype

alias of float32