GPTNeoXTransformerConfig
- class penzai.models.transformer.variants.gpt_neox.GPTNeoXTransformerConfig
Bases: object
Configuration parameters for a GPT Neo-X transformer.
These are held in a single configuration object to simplify argument passing during construction of the model.
- Variables:
num_attention_heads (int) – The number of attention heads.
embedding_dim (int) – Dimension of the embedding vectors and residual stream.
projection_dim (int) – Dimension of the query, key, and value projections. Usually embedding_dim // num_attention_heads.
mlp_hidden_dim (int) – Dimensionality of the hidden layer of the MLP blocks in each layer (the “neurons” axis).
num_decoder_blocks (int) – Number of transformer decoder blocks in the model.
vocab_size (int) – Number of tokens in the vocabulary.
activation_fn (Literal['relu', 'selu', 'gelu_exact', 'gelu_approx']) – Activation function used in the MLP blocks.
rope_subset_size (int) – Number of projection dimensions to allocate to rotary position embeddings.
rope_wavelength (float) – Wavelength for RoPE layers.
layernorm_epsilon (float) – Epsilon for layer normalization layers.
parameter_dtype (jax.typing.DTypeLike) – Floating dtype to use for all parameters.
activation_dtype (jax.typing.DTypeLike) – Floating dtype to use for activations and KV cache tables.
use_layer_stack (bool) – Whether to stack the blocks together using a LayerStack.
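The relationships between these fields can be illustrated with a small sketch. The dataclass below is a hypothetical stand-in that only mirrors the documented field names; it is not the penzai class, and the `validate` and `has_conventional_projection_dim` helpers, the plain-string dtypes, and all numeric values are assumptions chosen for illustration rather than settings from any published checkpoint.

```python
import dataclasses
from typing import Any, Literal

ALLOWED_ACTIVATIONS = ("relu", "selu", "gelu_exact", "gelu_approx")


@dataclasses.dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in mirroring the documented fields (not the penzai class)."""

    num_attention_heads: int
    embedding_dim: int
    projection_dim: int
    mlp_hidden_dim: int
    num_decoder_blocks: int
    vocab_size: int
    activation_fn: Literal["relu", "selu", "gelu_exact", "gelu_approx"]
    rope_subset_size: int
    rope_wavelength: float
    layernorm_epsilon: float
    parameter_dtype: Any   # assumption: a dtype name string stands in for a DTypeLike
    activation_dtype: Any  # assumption: same as above
    use_layer_stack: bool


def validate(cfg: SketchConfig) -> None:
    """Sanity checks suggested by the field descriptions (assumed, not penzai's own)."""
    if cfg.activation_fn not in ALLOWED_ACTIVATIONS:
        raise ValueError(f"unknown activation_fn: {cfg.activation_fn!r}")
    # RoPE is applied to a subset of the projection dimensions, so the subset
    # cannot be larger than the projection itself.
    if not 0 <= cfg.rope_subset_size <= cfg.projection_dim:
        raise ValueError("rope_subset_size must lie in [0, projection_dim]")


def has_conventional_projection_dim(cfg: SketchConfig) -> bool:
    """True when projection_dim follows the usual embedding_dim // heads rule."""
    return cfg.projection_dim == cfg.embedding_dim // cfg.num_attention_heads


# Illustrative values only.
config = SketchConfig(
    num_attention_heads=8,
    embedding_dim=512,
    projection_dim=64,
    mlp_hidden_dim=2048,
    num_decoder_blocks=6,
    vocab_size=50_000,
    activation_fn="gelu_exact",
    rope_subset_size=16,
    rope_wavelength=10_000.0,
    layernorm_epsilon=1e-5,
    parameter_dtype="float32",
    activation_dtype="float32",
    use_layer_stack=False,
)
validate(config)
print(has_conventional_projection_dim(config))
```

With penzai installed, the same keyword arguments would presumably be passed to `GPTNeoXTransformerConfig` itself, with real dtypes such as `jnp.float32` in place of the strings; that usage is an assumption based on the field list above, not something this page confirms.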
Methods
__init__(num_attention_heads, embedding_dim, ...)
Attributes
use_layer_stack
num_attention_heads
embedding_dim
projection_dim
mlp_hidden_dim
num_decoder_blocks
vocab_size
activation_fn
rope_subset_size
rope_wavelength
layernorm_epsilon
parameter_dtype
activation_dtype