GemmaFeedForward#

class penzai.example_models.gemma.model_core.GemmaFeedForward[source]#

Bases: Sequential

Implementation of the feed-forward block in Gemma.

Methods

__init__(sublayers)

from_config(embedding_dim, hidden_dim[, dtype])

Constructs an uninitialized Gemma feed-forward layer.

Attributes

sublayers

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

__call__(value)

Runs each of the sublayers in sequence.

classmethod from_config(embedding_dim: int, hidden_dim: int, dtype: jax.typing.DTypeLike = jnp.float32) → GemmaFeedForward[source]#

Constructs an uninitialized Gemma feed-forward layer.

Gemma’s feed-forward layer uses GELU-based gated linear units (GEGLU), as proposed by Shazeer (2020). We represent this computation as a composition of simpler Penzai primitives to enable patching and post-processing of the various internal activations.

We assume that the input embedding axis is called “embedding”, and the neurons axis is called “neurons”. Other axes will be treated as batch dimensions and vectorized over.
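The GEGLU computation described above can be sketched in plain NumPy. This is an illustrative sketch only, not Penzai's actual implementation: the weight names (`w_gate`, `w_up`, `w_down`) and the use of positional array axes are assumptions for demonstration, whereas the real layer operates on named axes ("embedding" and "neurons") via Penzai primitives.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (as used in many transformer implementations)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def geglu_feedforward(x, w_gate, w_up, w_down):
    """Sketch of a GEGLU feed-forward block: gate the "up" projection with a
    GELU-activated linear projection, then project back down to the
    embedding dimension. Weight names here are illustrative, not Penzai's."""
    gate = gelu(x @ w_gate)      # (..., hidden_dim)
    hidden = gate * (x @ w_up)   # elementwise gating of the up-projection
    return hidden @ w_down       # (..., embedding_dim)

# Toy shapes: embedding_dim=4, hidden_dim=8; leading axes act as batch dims.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4))
w_gate = rng.normal(size=(4, 8))
w_up = rng.normal(size=(4, 8))
w_down = rng.normal(size=(8, 4))
out = geglu_feedforward(x, w_gate, w_up, w_down)
print(out.shape)  # input and output share the embedding dimension
```

Because the leading axes are only touched by elementwise operations and the trailing matrix products, they behave as batch dimensions, mirroring how the Penzai layer vectorizes over all axes other than "embedding" and "neurons".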

Parameters:
  • embedding_dim – The dimensionality of the input and output embeddings.

  • hidden_dim – The dimensionality of the hidden layer.

  • dtype – The data type to use for the parameters.

Returns:

An instance of GemmaFeedForward containing uninitialized parameters of the appropriate shapes and dtypes.