GemmaTransformer#

class penzai.example_models.gemma.model_core.GemmaTransformer[source]#

Bases: Layer

Top-level Gemma transformer decoder, encapsulating all internal effects.

This class represents the full Gemma model, and can be loaded from the official Gemma checkpoints.

Variables:
  • config (GemmaTransformerConfig) – The configuration for the transformer. Although not directly used when the model is called, this can be useful for re-building the model or converting it to autoregressive sampling mode.

  • body (pz.LayerLike) – The implementation of the transformer. Usually a side-input effect handler wrapping the main sequence of transformer blocks, but it may be modified after loading, for instance by patching in replacement layers (see the sketch below).
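As a rough illustration of such patching, the sketch below uses the inherited select() method to functionally rewrite part of a loaded model. The pz.nn.Attention target and the identity rewrite are stand-ins for whatever layer type and transformation you actually need; treat this as an assumed pattern, not the only supported one.

```python
from penzai import pz

# Hedged sketch: select every attention layer inside the loaded model and
# functionally replace it. `pz.nn.Attention` is assumed to be the relevant
# layer type; the identity lambda stands in for a real rewrite.
patched_model = (
    model.select()
    .at_instances_of(pz.nn.Attention)
    .apply(lambda attn: attn)  # substitute a modified layer here
)
```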

Methods

__init__(config, body)

from_config(config)

Constructs an uninitialized transformer with the Gemma architecture.

from_pretrained(ckpt_params[, ...])

Constructs a GemmaTransformer from the official Gemma Flax checkpoint.

input_structure()

output_structure()

__call__(inputs)

Scores log-probabilities for the given inputs.

Attributes

config

body

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(inputs: GemmaInputs) → pz.nx.NamedArray[source]#

Scores log-probabilities for the given inputs.

Parameters:

inputs – Structure of input arguments, containing tokens, segment positions, and an attention mask.

Returns:

The final matrix of logits from the embedding decoding layer, which (in the normal configuration) will have axes “seq” and “vocabulary”.
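A minimal scoring sketch, assuming the package is imported as gemma and that GemmaInputs provides a from_basic_segments helper that derives positions and a causal mask from the tokens alone (both assumptions; the token IDs are dummies):

```python
import jax.numpy as jnp
from penzai import pz
from penzai.example_models import gemma

# Token IDs as a NamedArray with "batch" and "seq" axes (dummy values).
tokens = pz.nx.wrap(jnp.array([[2, 651, 9456, 576]])).tag("batch", "seq")

# `from_basic_segments` (assumed helper) builds segment positions and a
# causal attention mask for a single contiguous segment of tokens.
logits = model(gemma.model_core.GemmaInputs.from_basic_segments(tokens))
# `logits` has "seq" and "vocabulary" axes (plus "batch" from the input).
```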

classmethod from_config(config: GemmaTransformerConfig) → GemmaTransformer[source]#

Constructs an uninitialized transformer with the Gemma architecture.

Parameters:

config – The configuration of the Gemma model.

Returns:

A GemmaTransformer with uninitialized parameters. All side-input effects will already have been appropriately handled.
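For instance, assuming a GemmaTransformerConfig named config and penzai's pz.nn.initialize_parameters helper (an assumption about the initialization API), a model can be built and given random parameters like this:

```python
import jax
from penzai import pz
from penzai.example_models import gemma

# `config` is assumed to be a GemmaTransformerConfig describing the
# desired architecture (e.g. the Gemma 2B hyperparameters).
model = gemma.model_core.GemmaTransformer.from_config(config)

# The parameters start out uninitialized; fill them in with random values
# (assuming penzai's `pz.nn.initialize_parameters` helper).
model = pz.nn.initialize_parameters(model, jax.random.PRNGKey(0))
```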

classmethod from_pretrained(ckpt_params: dict[str, Any], upcast_activations_to_float32: bool = False) → GemmaTransformer[source]#

Constructs a GemmaTransformer from the official Gemma Flax checkpoint.

The parameters of the loaded GemmaTransformer match those in the original checkpoint, up to a few modifications:

  • Query, key, and value heads are stored in three separate matrices instead of being stored either as a single matrix (qkv_einsum) or as two (q_einsum and kv_einsum).

  • RMSLayerNorm weights have their values increased by one, instead of adding one at call time.

  • Axes of parameters are identified by name instead of by position.

Parameters:
  • ckpt_params – Nested dictionary of weights from the Gemma checkpoint.

  • upcast_activations_to_float32 – Whether to cast activations to float32 while the model runs. This is useful for doing interpretability research at higher precision without consuming additional memory for the (much larger) model parameters.

Returns:

A GemmaTransformer model containing the loaded parameters.
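A loading sketch, assuming the official checkpoint has been downloaded locally and is restored with orbax (the directory path is a placeholder):

```python
import orbax.checkpoint
from penzai.example_models import gemma

# Restore the raw Flax parameter dictionary from the checkpoint directory.
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
ckpt_params = checkpointer.restore("/path/to/gemma-2b-checkpoint")

# Build the penzai model, optionally upcasting activations to float32.
model = gemma.model_core.GemmaTransformer.from_pretrained(
    ckpt_params, upcast_activations_to_float32=True
)
```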