model_core


Core layers for the Gemma model architecture.

See the Gemma technical report at https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf and the accompanying reference implementation at google-deepmind/gemma.

All of the layers and models in this file use the following axis naming convention:

  • “seq” is the temporal axis of the token sequence, i.e. the axis along which the prompt tokens are laid out. In an attention matrix, it specifically refers to the query token(s) (the ones we are currently processing).

  • “embedding” is the axis for embedding vectors and the residual stream.

  • “projection” is the axis for query, key, and value head projection vectors, i.e. the axis over which query-key dot products are taken, and the axis of the value vectors that each attention head retrieves.

  • “heads” is the axis that ranges across the different attention heads. Note that depending on the configuration, the key and value computations may not have this axis, because they are shared across heads.

  • “kv_seq” is a temporary copy of the “seq” axis that represents the position of the keys and values in an attention matrix.

  • “neurons” is the axis for the neurons in the MLP blocks, which have an activation function (GEGLU) applied elementwise and therefore have a privileged basis.
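The following minimal NumPy sketch makes these axis names concrete. It is not penzai's implementation, since penzai tracks axes by name rather than by position. Here each array's axes are noted in comments, and the einsum letters abbreviate them (s = “seq”, k = “kv_seq”, h = “heads”, p = “projection”, e = “embedding”). The sizes and weights are hypothetical. The sketch also assumes the shared key/value configuration, in which keys and values have no “heads” axis, and uses standard scaled dot-product softmax attention.

```python
import numpy as np

# Hypothetical sizes; "seq" and "kv_seq" coincide because this is uncached
# self-attention over the same tokens.
SEQ = KV_SEQ = 4
HEADS, PROJECTION, EMBEDDING = 2, 8, 16

rng = np.random.default_rng(0)
x = rng.normal(size=(SEQ, EMBEDDING))  # ("seq", "embedding"): the residual stream

# The query projection has a "heads" axis; the key and value projections are
# shared across heads in this configuration, so they have no "heads" axis.
w_q = rng.normal(size=(EMBEDDING, HEADS, PROJECTION))
w_k = rng.normal(size=(EMBEDDING, PROJECTION))
w_v = rng.normal(size=(EMBEDDING, PROJECTION))
w_o = rng.normal(size=(HEADS, PROJECTION, EMBEDDING))

q = np.einsum("se,ehp->shp", x, w_q)  # ("seq", "heads", "projection")
k = np.einsum("ke,ep->kp", x, w_k)    # ("kv_seq", "projection"): a copy of "seq"
v = np.einsum("ke,ep->kp", x, w_v)    # ("kv_seq", "projection")

# Query-key dot products contract the "projection" axis.
logits = np.einsum("shp,kp->shk", q, k) / np.sqrt(PROJECTION)  # ("seq", "heads", "kv_seq")

# Causal mask over ("seq", "kv_seq"), broadcast across "heads".
mask = np.tril(np.ones((SEQ, KV_SEQ), dtype=bool))
logits = np.where(mask[:, None, :], logits, -1e30)

# Softmax over "kv_seq".
weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# Retrieve values (contracting "kv_seq"), then merge heads back into "embedding".
attended = np.einsum("shk,kp->shp", weights, v)     # ("seq", "heads", "projection")
attn_out = np.einsum("shp,hpe->se", attended, w_o)  # ("seq", "embedding")
```

Because the keys and values lack a “heads” axis, the same `k` and `v` are reused by every query head; only the query and output projections carry per-head parameters.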

Additionally, they use the following effect tags:

  • “token_positions” is the name of the side input that provides the position of each token for the purposes of positional embeddings.

  • “attn_mask” is the name of the side input that provides the attention mask for each attention layer.

  • Where applicable, “cache_end_index” is the name of the side input that identifies the current length of the key/value cache state. This determines where the new keys and values are inserted into the cache.

  • Where applicable, “kv_cache” is the name of the local state category that contains all key/value caches.
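The side inputs above fit together during cached decoding roughly as sketched below. This is a hypothetical NumPy illustration, not the sampling_mode implementation. It assumes a single unbatched, fixed-length cache in which cache slot i holds the key (or value) for token position i, and it builds a causal “attn_mask” over the (“seq”, “kv_seq”) axes from “token_positions”. The helpers `update_kv_cache` and `causal_mask` are invented for this example.

```python
import numpy as np

CACHE_LEN, PROJECTION = 8, 4  # hypothetical sizes


def update_kv_cache(cache, new_kv, cache_end_index):
    """Writes new_kv ("seq", "projection") into cache starting at cache_end_index.

    Returns the updated cache and the new cache_end_index.
    """
    seq = new_kv.shape[0]
    if cache_end_index + seq > cache.shape[0]:
        raise ValueError("KV cache overflow")
    updated = cache.copy()
    updated[cache_end_index:cache_end_index + seq] = new_kv
    return updated, cache_end_index + seq


def causal_mask(token_positions, cache_len):
    """Boolean mask with axes ("seq", "kv_seq").

    A query at position p may attend to cache slots 0..p (assumes slot == position).
    """
    kv_positions = np.arange(cache_len)
    return kv_positions[None, :] <= token_positions[:, None]


# Prefill: three prompt tokens at positions 0, 1, 2.
cache = np.zeros((CACHE_LEN, PROJECTION))
cache, end = update_kv_cache(cache, np.ones((3, PROJECTION)), cache_end_index=0)
prefill_mask = causal_mask(np.arange(3), CACHE_LEN)

# Decode: one new token at position 3, inserted at the current cache_end_index.
cache, end = update_kv_cache(cache, np.full((1, PROJECTION), 2.0), cache_end_index=end)
decode_mask = causal_mask(np.array([end - 1]), CACHE_LEN)
```

Under the slot-equals-position assumption, the cache slots at or beyond “cache_end_index” are still empty, and the mask hides them because their positions exceed every query's position.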

Note that the top-level GemmaTransformer and GemmaKVCachingTransformer classes handle these effects for you in most cases, so these conventions matter mainly if you plan to initialize parts of the transformer without using the top-level classes.

The KV caching logic is defined in the separate module penzai.deprecated.v1.example_models.gemma.sampling_mode.

Classes

GemmaAttention

Gemma-specific configuration of the self-attention layer.

GemmaFeedForward

Implementation of the feed-forward block in Gemma.

GemmaInputs

Input structure for GemmaTransformer.

GemmaTransformer

Top-level Gemma transformer decoder, encapsulating all internal effects.

GemmaTransformerBlock

Main decoder block for the Gemma transformer architecture.

GemmaTransformerConfig

Common configuration parameters for the Gemma transformer architecture.
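GemmaFeedForward applies GEGLU along the “neurons” axis. The sketch below shows that computation in plain NumPy: a GELU-activated gate multiplied elementwise by a linear “up” projection, followed by a projection back to “embedding”, with the result added to the residual stream. It is an illustration, not the class's code. The weight names, the sizes, the tanh approximation of GELU, and the residual wiring shown here are all assumptions.

```python
import numpy as np


def gelu(z):
    # Tanh approximation of GELU (an assumption; the exact GELU also exists).
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))


def geglu_feed_forward(x, w_gate, w_up, w_down):
    """Maps ("seq", "embedding") -> ("seq", "embedding") through a "neurons" layer."""
    gate = gelu(x @ w_gate)  # ("seq", "neurons")
    up = x @ w_up            # ("seq", "neurons")
    hidden = gate * up       # elementwise per neuron, hence the privileged basis
    return hidden @ w_down   # ("seq", "embedding")


SEQ, EMBEDDING, NEURONS = 4, 16, 64  # hypothetical sizes
rng = np.random.default_rng(0)
x = rng.normal(size=(SEQ, EMBEDDING))
w_gate = rng.normal(size=(EMBEDDING, NEURONS))
w_up = rng.normal(size=(EMBEDDING, NEURONS))
w_down = rng.normal(size=(NEURONS, EMBEDDING))

# The block's output is added back onto the residual stream ("embedding" axis).
y = x + geglu_feed_forward(x, w_gate, w_up, w_down)
```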