GemmaTransformer
- class penzai.example_models.gemma.model_core.GemmaTransformer
Bases: Layer
Top-level Gemma transformer decoder, encapsulating all internal effects.
This class represents the full Gemma model, and can be loaded from the official Gemma checkpoints.
- Variables:
config (GemmaTransformerConfig) – The configuration for the transformer. Although not directly used when the model is called, this can be useful for re-building the model or converting it to autoregressive sampling mode.
body (pz.LayerLike) – The implementation of the transformer. This is usually a side-input effect handler wrapping the main sequence of transformer blocks, though it may be modified by patching after the model is loaded.
Methods
__init__(config, body)
from_config(config) – Constructs an uninitialized transformer with the Gemma architecture.
from_pretrained(ckpt_params[, ...]) – Constructs a GemmaTransformer from the official Gemma Flax checkpoint.
input_structure()
output_structure()
__call__(inputs) – Scores log-probabilities for the given inputs.
Attributes
config
body
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(inputs: GemmaInputs) → pz.nx.NamedArray
Scores log-probabilities for the given inputs.
- Parameters:
inputs – Structure of input arguments, containing tokens, segment positions, and an attention mask.
- Returns:
The final matrix of logits from the embedding decoding layer, which (in the normal configuration) will have axes “seq” and “vocabulary”.
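Because __call__ returns unnormalized logits, per-token log-probabilities are obtained by normalizing over the "vocabulary" axis. The following is a minimal NumPy sketch, not penzai code. It assumes the logits have already been converted to a plain array of shape (seq, vocabulary) (the unwrapping step from pz.nx.NamedArray is not shown), and it assumes the standard autoregressive convention that position i predicts token i + 1. The helper names log_softmax and next_token_log_probs are hypothetical.

```python
import numpy as np


def log_softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax along `axis`."""
    # Subtract the per-row max before exponentiating to avoid overflow.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))


def next_token_log_probs(logits: np.ndarray, tokens: np.ndarray) -> np.ndarray:
    """Log-probability of each next token under a (seq, vocabulary) logit matrix.

    Assumes position i's logits predict token i + 1, so the final position's
    prediction is unused and the first token is never scored.
    """
    log_probs = log_softmax(logits, axis=-1)  # shape: (seq, vocabulary)
    targets = tokens[1:]                      # tokens predicted by positions 0..seq-2
    return log_probs[np.arange(len(targets)), targets]


# Hypothetical toy example: 4 positions, vocabulary of size 5.
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5))
tokens = np.array([2, 0, 4, 1])
scores = next_token_log_probs(logits, tokens)
print(scores.shape)  # (3,)
print(scores.sum())  # total log-likelihood of tokens[1:]
```

Summing the scores gives the sequence log-likelihood; in a real model the attention mask and segment positions in the inputs would also determine which positions are meaningful.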
- classmethod from_config(config: GemmaTransformerConfig) → GemmaTransformer
Constructs an uninitialized transformer with the Gemma architecture.
- Parameters:
config – The configuration of the Gemma model.
- Returns:
A GemmaTransformer with uninitialized parameters. All side-input effects will already have been appropriately handled.
- classmethod from_pretrained(ckpt_params: dict[str, Any], upcast_activations_to_float32: bool = False) → GemmaTransformer
Constructs a GemmaTransformer from the official Gemma Flax checkpoint.
The parameters of the loaded GemmaTransformer will be close to those in the original checkpoint, with a few modifications:
  - Query, key, and value heads are stored in three separate matrices, instead of either a single matrix (qkv_einsum) or two matrices (q_einsum and kv_einsum).
  - RMSLayerNorm weights have their values increased by one, instead of having one added at call time.
  - Axes of parameters are identified by name instead of by position.
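The first two modifications can be checked numerically: splitting a fused projection does not change what it computes, and folding the "+1" into the stored weight gives the same normalized output. This is a minimal NumPy sketch, not penzai's conversion code. The checkpoint shapes it uses, (3, heads, embed, head_dim) for qkv_einsum and (2, kv_heads, embed, head_dim) for kv_einsum, are assumptions based on the reference Gemma Flax layout, as is the eps value; split_qkv, split_q_kv, and rms_norm are hypothetical helpers.

```python
import numpy as np


def split_qkv(qkv_einsum: np.ndarray):
    """Split a fused (3, heads, embed, head_dim) array into query, key, value.

    Assumes index 0/1/2 along the leading axis selects query/key/value.
    """
    return qkv_einsum[0], qkv_einsum[1], qkv_einsum[2]


def split_q_kv(q_einsum: np.ndarray, kv_einsum: np.ndarray):
    """Split a (2, kv_heads, embed, head_dim) key/value array; q is already separate."""
    return q_einsum, kv_einsum[0], kv_einsum[1]


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMS-normalize the last axis, then scale elementwise by `weight`."""
    normed = x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)
    return normed * weight


rng = np.random.default_rng(0)
heads, embed, head_dim = 2, 8, 4
qkv = rng.normal(size=(3, heads, embed, head_dim))
q, k, v = split_qkv(qkv)

# A fused projection of one embedding vector equals three separate projections.
x = rng.normal(size=(embed,))
fused = np.einsum("d,cndh->cnh", x, qkv)
separate = np.stack([np.einsum("d,ndh->nh", x, w) for w in (q, k, v)])
print(np.allclose(fused, separate))  # True

# Adding one at call time versus storing (scale + 1) once at load time.
scale = rng.normal(size=(embed,))  # checkpoint-style scale parameter
folded_weight = scale + 1.0        # value stored after conversion
call_time = rms_norm(x, 1.0 + scale)
load_time = rms_norm(x, folded_weight)
print(np.allclose(call_time, load_time))  # True
```

Folding the offset at load time means the stored weight can be read directly as the multiplicative scale, with no hidden constant applied in the forward pass.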
- Parameters:
ckpt_params – Nested dictionary of weights from the Gemma checkpoint.
upcast_activations_to_float32 – Whether to cast activations to float32 when the model runs. This is useful for doing interpretability research at higher precision without the additional memory cost of storing the parameters themselves in float32.
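The idea behind upcast_activations_to_float32 can be illustrated with a toy linear layer: the stored weights keep their low-precision dtype, and only the activations are cast. This is a minimal NumPy sketch, not penzai's implementation. It uses float16 as a stand-in for bfloat16 (which NumPy does not provide natively), and the function linear is hypothetical.

```python
import numpy as np


def linear(x: np.ndarray, w: np.ndarray, upcast_activations_to_float32: bool) -> np.ndarray:
    """Apply a weight matrix, optionally computing in float32.

    The stored weight keeps its low-precision dtype either way. When the
    activation is cast to float32, NumPy type promotion converts the weight
    to float32 only temporarily for the matmul, so the result is float32.
    """
    if upcast_activations_to_float32:
        x = x.astype(np.float32)
    return x @ w


rng = np.random.default_rng(0)
w = rng.normal(size=(16, 16)).astype(np.float16)  # stored low-precision weights
x = rng.normal(size=(4, 16)).astype(np.float16)   # low-precision activations

low = linear(x, w, upcast_activations_to_float32=False)
high = linear(x, w, upcast_activations_to_float32=True)
print(low.dtype, high.dtype)  # float16 float32
print(w.dtype, w.nbytes)      # float16 512
```

The weight's memory footprint is unchanged; only the intermediate activations (and any temporaries) are held at higher precision.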
- Returns:
A GemmaTransformer model containing the loaded parameters.