GPTNeoXTransformerConfig

class penzai.experimental.v2.models.transformer.variants.gpt_neox.GPTNeoXTransformerConfig

Bases: object

Configuration parameters for a GPT Neo-X transformer.

These are held in a single configuration object to simplify argument passing during construction of the model.
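To show why a single configuration object simplifies construction, here is a minimal sketch. It assumes a trimmed, hypothetical stand-in class (`MiniConfig`) with only four of the documented fields, plus two hypothetical helpers (`attention_weight_shapes`, `mlp_weight_shapes`) that derive weight shapes from the config. None of these names are part of penzai, and the shape layouts are illustrative rather than penzai's actual parameter layout.

```python
import dataclasses


# Hypothetical, trimmed stand-in for the config: only the fields the
# helpers below need. This is NOT the penzai class itself.
@dataclasses.dataclass(frozen=True)
class MiniConfig:
  num_attention_heads: int
  embedding_dim: int
  projection_dim: int
  mlp_hidden_dim: int


# Hypothetical builder helpers: each takes the whole config object instead of
# a long list of separate size arguments.
def attention_weight_shapes(config: MiniConfig) -> dict[str, tuple[int, ...]]:
  heads, proj = config.num_attention_heads, config.projection_dim
  return {
      "query": (config.embedding_dim, heads, proj),
      "key": (config.embedding_dim, heads, proj),
      "value": (config.embedding_dim, heads, proj),
      "output": (heads, proj, config.embedding_dim),
  }


def mlp_weight_shapes(config: MiniConfig) -> dict[str, tuple[int, ...]]:
  return {
      "in": (config.embedding_dim, config.mlp_hidden_dim),
      "out": (config.mlp_hidden_dim, config.embedding_dim),
  }


config = MiniConfig(
    num_attention_heads=8, embedding_dim=512, projection_dim=64, mlp_hidden_dim=2048
)
print(attention_weight_shapes(config)["query"])  # (512, 8, 64)
print(mlp_weight_shapes(config)["out"])  # (2048, 512)
```

Adding a new hyperparameter then only changes the config class and the helpers that read it, not every call site in between.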

Variables:
  • num_attention_heads (int) – The number of attention heads.

  • embedding_dim (int) – Dimension of the embedding vectors and residual stream.

  • projection_dim (int) – Dimension of the query, key, and value projections. Usually embedding_dim // num_attention_heads.

  • mlp_hidden_dim (int) – Dimensionality of the hidden layer of the MLP blocks in each layer (the “neurons” axis).

  • num_decoder_blocks (int) – Number of transformer decoder blocks in the model.

  • vocab_size (int) – Number of tokens in the vocabulary.

  • activation_fn (Literal['relu', 'selu', 'gelu_exact', 'gelu_approx']) – Activation function used in the MLP blocks.

  • rope_subset_size (int) – Number of projection dimensions to allocate to rotary position embeddings.

  • rope_wavelength (float) – Wavelength for RoPE layers.

  • layernorm_epsilon (float) – Epsilon for layer normalization layers.

  • parameter_dtype (jax.typing.DTypeLike) – Floating dtype to use for all parameters.

  • activation_dtype (jax.typing.DTypeLike) – Floating dtype to use for activations and KV cache tables.

  • use_layer_stack (bool) – Whether to stack the blocks together using a LayerStack.
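The relationships between these fields can be checked before building a model. The sketch below uses a hypothetical stand-in dataclass, `GPTNeoXConfigSketch`, that mirrors the documented field names, with dtypes written as plain strings instead of `jax.typing.DTypeLike` so it runs without JAX. The `check_config` helper, the default for `use_layer_stack`, and all numeric values are assumptions for illustration, not penzai's own validation or a real checkpoint's settings.

```python
import dataclasses
from typing import Literal


# Hypothetical stand-in mirroring the documented fields; dtypes are plain
# strings here instead of jax.typing.DTypeLike so the sketch needs no JAX.
@dataclasses.dataclass(frozen=True)
class GPTNeoXConfigSketch:
  num_attention_heads: int
  embedding_dim: int
  projection_dim: int
  mlp_hidden_dim: int
  num_decoder_blocks: int
  vocab_size: int
  activation_fn: Literal["relu", "selu", "gelu_exact", "gelu_approx"]
  rope_subset_size: int
  rope_wavelength: float
  layernorm_epsilon: float
  parameter_dtype: str
  activation_dtype: str
  use_layer_stack: bool = False  # assumed default, for illustration only


def check_config(config: GPTNeoXConfigSketch) -> list[str]:
  """Returns the problems found by some hypothetical sanity checks."""
  problems = []
  if config.num_attention_heads * config.projection_dim != config.embedding_dim:
    # Not strictly required, but this is the usual convention noted above.
    problems.append("projection_dim != embedding_dim // num_attention_heads")
  if not 0 <= config.rope_subset_size <= config.projection_dim:
    problems.append("rope_subset_size must be in [0, projection_dim]")
  if config.rope_subset_size % 2 != 0:
    # RoPE rotates pairs of dimensions, so the rotated subset should be even.
    problems.append("rope_subset_size should be even")
  return problems


config = GPTNeoXConfigSketch(
    num_attention_heads=8,
    embedding_dim=512,
    projection_dim=64,
    mlp_hidden_dim=2048,
    num_decoder_blocks=6,
    vocab_size=50304,
    activation_fn="gelu_exact",
    rope_subset_size=16,  # a quarter of projection_dim, chosen for illustration
    rope_wavelength=10_000.0,
    layernorm_epsilon=1e-5,
    parameter_dtype="float32",
    activation_dtype="float32",
)
print(check_config(config))  # []
bad = dataclasses.replace(config, rope_subset_size=15)
print(check_config(bad))  # ['rope_subset_size should be even']
```

Because rope_subset_size counts projection dimensions, only that many dimensions of each head are rotated; the remaining projection_dim - rope_subset_size dimensions pass through without positional rotation.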

Methods

__init__(num_attention_heads, embedding_dim, ...)

Attributes

  • use_layer_stack

  • num_attention_heads

  • embedding_dim

  • projection_dim

  • mlp_hidden_dim

  • num_decoder_blocks

  • vocab_size

  • activation_fn

  • rope_subset_size

  • rope_wavelength

  • layernorm_epsilon

  • parameter_dtype

  • activation_dtype
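Since the configuration is a flat set of attributes, a variant (for example, one that stacks blocks with a LayerStack and uses a lower-precision activation dtype) can be derived from an existing config without restating every field. The sketch below assumes the real class is an ordinary dataclass, as the generated `__init__` signature suggests; `ConfigSketch` is a hypothetical trimmed stand-in, and the dtypes are plain strings.

```python
import dataclasses


# Hypothetical trimmed stand-in; assumes the real config behaves like an
# ordinary dataclass.
@dataclasses.dataclass(frozen=True)
class ConfigSketch:
  embedding_dim: int
  parameter_dtype: str
  activation_dtype: str
  use_layer_stack: bool = False


base = ConfigSketch(
    embedding_dim=512, parameter_dtype="float32", activation_dtype="float32"
)

# dataclasses.replace builds a new instance, copying every field and
# overriding only the ones given; the original config is left untouched.
stacked_bf16 = dataclasses.replace(
    base, activation_dtype="bfloat16", use_layer_stack=True
)
print(base.use_layer_stack, stacked_bf16.use_layer_stack)  # False True
print(stacked_bf16.embedding_dim)  # 512
```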