build_gpt_neox_block

penzai.experimental.v2.models.transformer.variants.gpt_neox.build_gpt_neox_block(name: str, init_base_rng: jax.Array | None, config: GPTNeoXTransformerConfig) → model_parts.TransformerBlock

Builds a GPT-NeoX “parallel” transformer block from a configuration.

GPT-NeoX uses a parallel formulation of transformer blocks, in which the block's input (the output of the previous block) is fed to the attention and feedforward components at the same time, and both results are added back to that input:

y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x))
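
The formula above can be illustrated with a minimal JAX sketch. This is not Penzai's implementation: `parallel_block`, `toy_attention`, `toy_mlp`, `layer_norm`, and `init_params` are hypothetical helpers written for this example, the attention is a single causal head without rotary embeddings, and giving each branch its own LayerNorm parameters is an assumption made here.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, bias, eps=1e-5):
    # Normalize each token vector over the feature axis.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale + bias


def toy_attention(h, wq, wk, wv, wo):
    # Single-head causal self-attention over h with shape [seq, d_model].
    q, k, v = h @ wq, h @ wk, h @ wv
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    seq = h.shape[0]
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(causal, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v @ wo


def toy_mlp(h, w_in, w_out):
    # Two-layer feedforward network with a GELU nonlinearity.
    return jax.nn.gelu(h @ w_in) @ w_out


def parallel_block(x, p):
    # Both branches read the same block input x (assumption: separate
    # LayerNorm parameters per branch), and both are added to the residual.
    attn_out = toy_attention(layer_norm(x, *p["ln_attn"]), *p["attn"])
    mlp_out = toy_mlp(layer_norm(x, *p["ln_mlp"]), *p["mlp"])
    return x + mlp_out + attn_out


def init_params(key, d_model=8, d_hidden=32):
    # Hypothetical toy initializer; sizes are arbitrary illustration values.
    k = jax.random.split(key, 6)

    def dense(rng, shape):
        return jax.random.normal(rng, shape) / jnp.sqrt(shape[0])

    ones, zeros = jnp.ones(d_model), jnp.zeros(d_model)
    return {
        "ln_attn": (ones, zeros),
        "ln_mlp": (ones, zeros),
        "attn": tuple(dense(k[i], (d_model, d_model)) for i in range(4)),
        "mlp": (dense(k[4], (d_model, d_hidden)),
                dense(k[5], (d_hidden, d_model))),
    }


params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (5, 8))  # [seq, d_model]
y = parallel_block(x, params)  # same shape as x
```

Because neither branch depends on the other's output, the two can be computed independently before the final sum, unlike a sequential block, where the feedforward branch consumes the result of the attention residual.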

Parameters:
  • name – Name of the block.

  • init_base_rng – Base RNG for initializing the parameters.

  • config – The configuration of the model.

Returns:

A full transformer block.