build_gpt_neox_block
- penzai.experimental.v2.models.transformer.variants.gpt_neox.build_gpt_neox_block(name: str, init_base_rng: jax.Array | None, config: GPTNeoXTransformerConfig) → model_parts.TransformerBlock
Builds a GPT-NeoX “parallel” transformer block from a configuration.
GPT-NeoX uses a parallel formulation of transformer blocks, in which the block's input (the output of the previous block) is fed to the attention and feedforward components at the same time, and both results are added to the residual stream:
y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x))
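The formula above can be illustrated with a minimal JAX sketch. This is not Penzai's implementation: every name here (`layer_norm`, `toy_mlp`, `toy_attention`, `parallel_block`, `sequential_block`, the `params` layout) is hypothetical, attention is single-head with no rotary embeddings, and the layer norm has no learned scale or bias. It only shows how the parallel residual update differs from the more common sequential one.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Parameter-free layer norm over the last axis (simplifying assumption)."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def toy_mlp(x, w_in, w_out):
    """Two-layer feedforward network with a GELU nonlinearity."""
    return jax.nn.gelu(x @ w_in) @ w_out


def toy_attention(x, w_q, w_k, w_v, w_o):
    """Single-head causal self-attention over a [seq, dim] input."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    # Each position may attend only to itself and earlier positions.
    causal = jnp.tril(jnp.ones((x.shape[0], x.shape[0]), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ v @ w_o


def parallel_block(x, params):
    """y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x))."""
    attn_out = toy_attention(layer_norm(x), *params["attn"])
    mlp_out = toy_mlp(layer_norm(x), *params["mlp"])  # reads x, not x + attn_out
    return x + mlp_out + attn_out


def sequential_block(x, params):
    """Sequential formulation, shown only for contrast."""
    h = x + toy_attention(layer_norm(x), *params["attn"])
    return h + toy_mlp(layer_norm(h), *params["mlp"])  # MLP sees the attention output


seq_len, dim, hidden = 4, 8, 16
keys = jax.random.split(jax.random.PRNGKey(0), 7)
params = {
    "attn": tuple(0.1 * jax.random.normal(k, (dim, dim)) for k in keys[:4]),
    "mlp": (
        0.1 * jax.random.normal(keys[4], (dim, hidden)),
        0.1 * jax.random.normal(keys[5], (hidden, dim)),
    ),
}
x = jax.random.normal(keys[6], (seq_len, dim))
y = parallel_block(x, params)
```

In the parallel form, the two branches depend only on `x`, so they can be computed independently and summed. In the sequential form, the feedforward branch has to wait for the attention result.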
- Parameters:
name – Name of the block.
init_base_rng – Base RNG for initializing the parameters.
config – The configuration of the model.
- Returns:
A full transformer block.