GemmaTransformer#
- class penzai.deprecated.v1.example_models.gemma.model_core.GemmaTransformer[source]#
Bases: Layer
Top-level Gemma transformer decoder, encapsulating all internal effects.
This class represents the full Gemma model, and can be loaded from the official Gemma checkpoints.
- Variables:
config (GemmaTransformerConfig) – The configuration for the transformer. Although not directly used when the model is called, this can be useful for re-building the model or converting it to autoregressive sampling mode.
body (pz.LayerLike) – The implementation of the transformer. Usually a side-input effect handler wrapping the main sequence of transformer blocks, but may be modified after the model is loaded due to patching.
Methods
- __init__(config, body)
- from_config(config) – Constructs an uninitialized transformer with the Gemma architecture.
- from_pretrained(ckpt_params[, ...]) – Constructs a GemmaTransformer from an official Gemma Flax checkpoint.
- input_structure()
- output_structure()
- __call__(inputs) – Scores log-probabilities for the given inputs.
Attributes
config
body
Inherited Methods
- attributes_dict() – Constructs a dictionary with all of the fields in the class.
- from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
- key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
- select() – Wraps this struct in a selection, enabling functional-style mutations.
- tree_flatten() – Flattens this tree node.
- tree_flatten_with_keys() – Flattens this tree node with keys.
- tree_unflatten(aux_data, children) – Unflattens this tree node.
- treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(inputs: GemmaInputs) → pz.nx.NamedArray [source]#
Scores log-probabilities for the given inputs.
- Parameters:
inputs – Structure of input arguments, containing tokens, segment positions, and an attention mask.
- Returns:
The final matrix of logits from the embedding decoding layer, which (in the normal configuration) will have axes “seq” and “vocabulary”.
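The shape contract of the returned logits can be illustrated with a plain NumPy sketch. This is not the penzai API (which uses named arrays rather than positional axes); all sizes below are hypothetical, and the final projection is shown as a simple matrix product against a decoding table:

```python
import numpy as np

# Hypothetical sizes; the real model reads these from its GemmaTransformerConfig.
seq_len, embed_dim, vocab_size = 8, 16, 100

rng = np.random.default_rng(0)

# Final hidden states after the transformer blocks: one vector per token.
hidden = rng.standard_normal((seq_len, embed_dim))

# Embedding decode: project each hidden state onto the vocabulary.
decode_table = rng.standard_normal((vocab_size, embed_dim))
logits = hidden @ decode_table.T

# The result has one axis per token position and one per vocabulary entry,
# corresponding to the named axes "seq" and "vocabulary" in the real model.
assert logits.shape == (seq_len, vocab_size)
```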
- classmethod from_config(config: GemmaTransformerConfig) → GemmaTransformer [source]#
Constructs an uninitialized transformer with the Gemma architecture.
- Parameters:
config – The configuration of the Gemma model.
- Returns:
A GemmaTransformer with uninitialized parameters. All side input effects will have already been appropriately handled.
- classmethod from_pretrained(ckpt_params: dict[str, Any], upcast_activations_to_float32: bool = False) → GemmaTransformer [source]#
Constructs a GemmaTransformer from an official Gemma Flax checkpoint.
The parameters of the loaded GemmaTransformer will be close to those in the original checkpoint, with a few modifications:
- Query, key, and value heads are stored in three separate matrices instead of being stored either as a single matrix (qkv_einsum) or as two (q_einsum and kv_einsum).
- RMSLayerNorm weights have their values increased by one, instead of adding one at call time.
- Axes of parameters are identified by name instead of by position.
- Parameters:
ckpt_params – Nested dictionary of weights from the Gemma checkpoint.
upcast_activations_to_float32 – Whether to cast activations to float32 when the model runs. This is useful for doing interpretability research at higher precision without consuming additional memory.
- Returns:
A GemmaTransformer model containing the loaded parameters.
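The first two checkpoint modifications listed above can be sketched in plain NumPy. This is an illustration, not penzai's actual conversion code; the shapes and variable names are hypothetical stand-ins for the fused `qkv_einsum` tensor and an RMSNorm weight from the checkpoint:

```python
import numpy as np

# Hypothetical shapes for a tiny model; the real values come from the
# official Gemma Flax checkpoint.
num_heads, embed_dim, head_dim = 2, 8, 4

rng = np.random.default_rng(0)

# Fused checkpoint layout: query, key, and value stacked along axis 0.
qkv_einsum = rng.standard_normal((3, num_heads, embed_dim, head_dim))

# Modification 1: split into three separate projection matrices.
q_proj, k_proj, v_proj = qkv_einsum[0], qkv_einsum[1], qkv_einsum[2]
assert q_proj.shape == (num_heads, embed_dim, head_dim)

# Modification 2: checkpoints store a weight w, and the original model
# computes (1 + w) * x at call time. Baking the +1 into the stored scale
# makes the call just scale * x, with identical results.
w = rng.standard_normal(embed_dim)
scale = 1.0 + w
x = rng.standard_normal(embed_dim)
assert np.allclose(scale * x, (1.0 + w) * x)
```

The split projections and pre-shifted scales leave the loaded model numerically equivalent to the original checkpoint while simplifying its call-time computation.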