gemma_from_pretrained_checkpoint
- penzai.experimental.v2.models.transformer.variants.gemma.gemma_from_pretrained_checkpoint(ckpt_params: dict[str, Any], upcast_activations_to_float32: bool = False, use_layer_stack: bool = False) → model_parts.TransformerLM
Builds a Gemma model from a pretrained checkpoint.
The parameters of the loaded Transformer will be close to those in the original checkpoint, with a few modifications:
- Query, key, and value heads are stored in three separate matrices instead of being stored either as a single matrix (qkv_einsum) or as two (q_einsum and kv_einsum).
- RMSLayerNorm weights have their values increased by one, instead of adding one at call time.
- Axes of parameters are identified by name instead of by position.
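The modifications listed above can be illustrated with a small NumPy sketch. It is not the library's conversion code: the toy sizes, the assumed layout of the fused qkv_einsum matrix (a leading axis of size 3 indexing query, key, and value), the rms_normalize helper, and the axis names are all assumptions chosen for illustration.

```python
import numpy as np

# Assumed toy sizes; real Gemma checkpoints are much larger.
num_heads, embed_dim, head_dim = 2, 4, 3
rng = np.random.default_rng(0)

# Assumed fused layout: the leading axis of size 3 indexes (query, key, value).
qkv_einsum = rng.normal(size=(3, num_heads, embed_dim, head_dim))

# 1. Store query, key, and value heads as three separate matrices.
q_w, k_w, v_w = qkv_einsum[0], qkv_einsum[1], qkv_einsum[2]

# 2. Fold the "+1" into the stored norm weight instead of adding it per call.
ckpt_scale = rng.normal(size=(embed_dim,))
stored_scale = ckpt_scale + 1.0


def rms_normalize(x, eps=1e-6):
    """Hypothetical helper: divides by the root mean square over the last axis."""
    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)


x = rng.normal(size=(5, embed_dim))
call_time = rms_normalize(x) * (1.0 + ckpt_scale)  # add one at call time
folded = rms_normalize(x) * stored_scale           # add one at load time
same_output = np.allclose(call_time, folded)

# 3. Identify axes by name. A plain dict of axis name to size stands in for a
# named array here; the axis names are illustrative.
q_axes = dict(zip(("heads", "embedding", "projection"), q_w.shape))
```

Folding the offset into the stored weight gives the same outputs as adding it on every call, which is what same_output checks.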
- Parameters:
ckpt_params – Nested dictionary of weights from the Gemma checkpoint.
upcast_activations_to_float32 – Whether to cast activations to float32 when the model runs. This allows analyzing activations at higher precision without consuming additional memory for parameters.
use_layer_stack – Whether to use a layer stack for the decoder blocks.
- Returns:
A Transformer model containing the loaded parameters.
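A minimal usage sketch follows. It assumes the checkpoint is stored as an Orbax PyTree checkpoint directory; the load_gemma wrapper and its ckpt_path argument are hypothetical names introduced here, and the imports are placed inside the function so that defining it does not require the packages to be installed.

```python
def load_gemma(ckpt_path, upcast_activations_to_float32=False, use_layer_stack=False):
    """Restores checkpoint weights and builds the model (illustrative wrapper)."""
    # Lazy imports: only needed when the function is actually called.
    import orbax.checkpoint
    from penzai.experimental.v2.models.transformer.variants import gemma

    # Assumption: ckpt_path points at an Orbax PyTree checkpoint directory.
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    ckpt_params = checkpointer.restore(ckpt_path)

    # Pass the nested weight dictionary through with the documented flags.
    return gemma.gemma_from_pretrained_checkpoint(
        ckpt_params,
        upcast_activations_to_float32=upcast_activations_to_float32,
        use_layer_stack=use_layer_stack,
    )
```

Setting upcast_activations_to_float32=True leaves the parameters at their checkpoint precision and only changes the dtype of activations when the model runs.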