gemma_from_pretrained_checkpoint


penzai.models.transformer.variants.gemma.gemma_from_pretrained_checkpoint(ckpt_params: dict[str, Any], upcast_activations_to_float32: bool = False, use_layer_stack: bool = False, preset_name: Literal['gemma_2b', 'gemma_7b', 'gemma2_2b', 'gemma2_9b', 'gemma2_27b', 'gemma3_1b', 'gemma3_4b', 'gemma3_12b', 'gemma3_27b', 'auto'] = 'auto') → model_parts.TransformerLM

Builds a Gemma model from a pretrained checkpoint.

The parameters of the loaded Transformer will be close to those in the original checkpoint with a few modifications:

  • Query, key, and value heads are stored in three separate matrices instead of being stored either as a single matrix (qkv_einsum) or as two (q_einsum and kv_einsum).

  • RMSLayerNorm weights are stored with one already added to their values, so the model does not need to add one at call time.

  • Axes of parameters are identified by name instead of by position.
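The first two modifications can be illustrated with a small dependency-free sketch. It is not Penzai's implementation: the helper names, the toy values, and the assumption that a fused matrix stacks its query, key, and value parts along a leading axis are illustrative only.

```python
import math


def rms_normalize(x, eps=1e-6):
    """Divide a vector by its root-mean-square (eps avoids division by zero)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


def norm_adding_one_at_call_time(x, ckpt_weights):
    """Checkpoint-style convention: scale by (1 + w) on every call."""
    return [n * (1.0 + w) for n, w in zip(rms_normalize(x), ckpt_weights)]


def norm_with_preadjusted_weights(x, loaded_weights):
    """Loaded-model convention: the weights already include the +1."""
    return [n * w for n, w in zip(rms_normalize(x), loaded_weights)]


def split_fused(stacked):
    """Split a fused matrix (assumed stacked on its leading axis) into parts."""
    return tuple(stacked)


# Toy checkpoint weights and the one-time conversion applied at load time.
ckpt_weights = [0.5, -0.25, 0.0, 2.0]
loaded_weights = [w + 1.0 for w in ckpt_weights]

x = [1.0, -2.0, 3.0, 0.5]
out_ckpt_style = norm_adding_one_at_call_time(x, ckpt_weights)
out_loaded_style = norm_with_preadjusted_weights(x, loaded_weights)

# A toy fused "qkv" parameter: three 2x2 matrices stacked on the leading axis.
qkv = [
    [[1.0, 0.0], [0.0, 1.0]],  # query
    [[2.0, 0.0], [0.0, 2.0]],  # key
    [[3.0, 0.0], [0.0, 3.0]],  # value
]
query, key, value = split_fused(qkv)
```

Both normalization functions produce the same output for the same input, which is why the conversion can be done once when the checkpoint is loaded rather than on every forward pass.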

Parameters:
  • ckpt_params – Nested dictionary of weights from the Gemma checkpoint.

  • upcast_activations_to_float32 – Whether to cast activations to float32 when the model runs. This allows analyzing activations at higher precision without consuming additional memory for parameters.

  • use_layer_stack – Whether to use a layer stack for the decoder blocks.

  • preset_name – Name of the preset that determines the model configuration. If “auto”, the configuration is inferred from the number of layers in the checkpoint and from whether the model needs qk norm.

Returns:

A Transformer model containing the loaded parameters.
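
A typical call might look like the sketch below. It assumes the checkpoint was saved in a format that Orbax's `PyTreeCheckpointer` can restore and that `ckpt_path` points at the checkpoint directory; the helper name `load_gemma_model` is hypothetical and not part of Penzai. The libraries are imported inside the function so that defining the helper does not require them to be installed.

```python
def load_gemma_model(ckpt_path):
    """Restore a Gemma checkpoint from `ckpt_path` and build the Penzai model.

    `load_gemma_model` is an illustrative helper, not a Penzai API.
    """
    # Imported lazily: these are only needed when the helper is called.
    import orbax.checkpoint
    from penzai.models.transformer import variants

    # Restore the nested dictionary of weights from disk (assumes an
    # Orbax-compatible checkpoint layout).
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    ckpt_params = checkpointer.restore(ckpt_path)

    # Keyword arguments are spelled out with their documented defaults.
    return variants.gemma.gemma_from_pretrained_checkpoint(
        ckpt_params,
        upcast_activations_to_float32=False,
        use_layer_stack=False,
        preset_name="auto",
    )
```

Passing an explicit `preset_name` such as `'gemma_2b'` instead of `'auto'` skips the automatic configuration detection.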