model_core
Core layers for the Gemma model architecture.
See the Gemma technical report at https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf and the accompanying reference implementation at google-deepmind/gemma.
All of the layers and models in this file use the following axis naming convention (a shape-level sketch follows the list):
“seq” is the temporal axis of the token sequence, i.e. the axis along which the prompt tokens are laid out. In an attention matrix, it specifically refers to the query token(s) (the ones we are currently processing).
“embedding” is the axis for embedding vectors and the residual stream.
“projection” is the axis for query, key, and value head projection vectors, i.e. the axis along which query-key dot products are taken and along which attention-head values are retrieved.
“heads” is the axis that ranges across the different attention heads. Note that depending on the configuration, the key and value computations may not have this axis, because they are shared across heads.
“kv_seq” is a temporary copy of the “seq” axis that represents the position of the keys and values in an attention matrix.
“neurons” is the axis for the neurons in the MLP blocks, which have an activation function (GEGLU) applied elementwise and therefore have a privileged basis.
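As a rough illustration, the sketch below lays out plain-JAX arrays whose positional dimensions correspond to these axis names in a single attention layer and MLP block. It is a shape-level sketch only: the sizes are made up, and the real layers operate on penzai named arrays keyed by these axis names rather than on positional dimensions.

```python
import jax.numpy as jnp

# Illustrative sizes only; the actual layers act on penzai NamedArrays keyed
# by these axis names instead of on fixed positional dimensions.
seq = kv_seq = 8            # query positions and key/value positions
heads, projection = 4, 32   # attention heads and per-head projection size
embedding, neurons = 128, 512

x = jnp.zeros((seq, embedding))                  # residual stream: ("seq", "embedding")
w_q = jnp.zeros((embedding, heads, projection))  # query projection weights
w_k = jnp.zeros((embedding, projection))         # key weights shared across heads

q = jnp.einsum("se,ehp->shp", x, w_q)            # ("seq", "heads", "projection")
k = jnp.einsum("te,ep->tp", x, w_k)              # ("kv_seq", "projection")
scores = jnp.einsum("shp,tp->sht", q, k)         # ("seq", "heads", "kv_seq")

mlp_acts = jnp.zeros((seq, neurons))             # GEGLU activations: ("seq", "neurons")
```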
Additionally, they use the following effect tags:
“token_positions” is the name of the side input that provides the position of each token for the purposes of positional embeddings.
“attn_mask” is the name of the side input that provides the attention mask for each attention layer.
Where applicable, “cache_end_index” is the name of the side input that identifies the current length of the key/value cache state. This determines where the new keys and values are inserted into the cache.
Where applicable, “kv_cache” is the name of the local state category that contains all key/value caches.
Note that the top-level GemmaTransformer and GemmaKVCachingTransformer classes will handle these effects for you in most cases, so this is most relevant if you plan to initialize parts of the transformer without using these top-level classes.
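For intuition, here is a minimal sketch of the array values that the “token_positions” and “attn_mask” side inputs typically carry for a single causal prompt segment. It shows only the values, not the penzai side-input plumbing that actually delivers them to the layers, and assumes a simple left-to-right prompt with no padding or packing.

```python
import jax.numpy as jnp

tokens = jnp.array([2, 651, 9456, 576])   # example token ids; length T = 4
seq_len = tokens.shape[0]

# "token_positions": the position of each token, used for positional embeddings.
token_positions = jnp.arange(seq_len)     # axis: "seq"

# "attn_mask": boolean causal mask; entry [q, k] is True if query position q
# may attend to key/value position k.
attn_mask = (
    jnp.arange(seq_len)[:, None] >= jnp.arange(seq_len)[None, :]
)                                         # axes: ("seq", "kv_seq")
```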
The KV caching logic is defined in the separate module penzai.deprecated.v1.example_models.gemma.sampling_mode.
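The following is a hedged sketch of the cache-update pattern that “cache_end_index” describes: new keys (and, analogously, values) are written into a pre-allocated buffer starting at the current cache length. The actual cache handling in sampling_mode is implemented with penzai's local-state effects rather than this bare-JAX form.

```python
import jax
import jax.numpy as jnp

max_kv_len, projection = 16, 32

# Pre-allocated key cache along the "kv_seq" axis ("heads" omitted for brevity).
key_cache = jnp.zeros((max_kv_len, projection))

cache_end_index = 5                   # current cache length, provided as a side input
new_keys = jnp.ones((1, projection))  # key projection for the token being decoded

# Insert the new keys at the current end of the cache.
key_cache = jax.lax.dynamic_update_slice_in_dim(
    key_cache, new_keys, cache_end_index, axis=0
)
```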
Classes
| Class | Description |
| --- | --- |
| GemmaAttention | Gemma-specific configuration of the self-attention layer. |
| GemmaFeedForward | Implementation of the feed-forward block in Gemma. |
| GemmaInputs | Input structure for GemmaTransformer. |
| GemmaTransformer | Top-level Gemma transformer decoder, encapsulating all internal effects. |
| GemmaTransformerBlock | Main decoder block for the Gemma transformer architecture. |
| GemmaTransformerConfig | Common configuration parameters for the Gemma transformer architecture. |