GemmaInputs

class penzai.example_models.gemma.model_core.GemmaInputs

Bases: Struct

Input structure for GemmaTransformer.

Variables:
  • tokens (pz.nx.NamedArray) – Sequence of tokens, as an integer named array with a “seq” axis and possibly batch axes. Usually starts with the beginning-of-sequence token.

  • positions (pz.nx.NamedArray) – Sequence of token positions, as an integer named array with a “seq” axis and possibly batch axes. Usually starts from 0 and increments along the “seq” axis, but can be different to support e.g. example packing.

  • attention_mask (pz.nx.NamedArray) – Boolean attention mask with “seq” and “kv_seq” axes of the same length, and possibly batch axes. Usually a causal mask, but can be different to support e.g. example packing or dropping out inputs.
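To make the example-packing case concrete, here is a minimal sketch of how positions and a mask could be derived when two examples share one sequence. It uses plain NumPy arrays rather than pz.nx.NamedArray, and the helper packed_positions_and_mask and its segment_ids argument are hypothetical names introduced for illustration, not part of Penzai. It assumes rows of the mask correspond to the “seq” (query) axis and columns to the “kv_seq” (key) axis.

```python
import numpy as np

def packed_positions_and_mask(segment_ids):
    """Builds positions and a boolean mask for one packed sequence.

    Hypothetical helper, not a Penzai function. `segment_ids` is a 1-D
    integer array; tokens with the same id belong to the same example,
    and each example is assumed to occupy one contiguous run.
    """
    segment_ids = np.asarray(segment_ids)
    index = np.arange(segment_ids.shape[0])

    # Mark the first token of every segment, then carry each segment's
    # start index forward so positions restart at 0 per example.
    is_start = np.concatenate([[True], segment_ids[1:] != segment_ids[:-1]])
    segment_start = np.maximum.accumulate(np.where(is_start, index, 0))
    positions = index - segment_start

    # Rows play the role of "seq" (queries), columns of "kv_seq" (keys).
    causal = index[:, None] >= index[None, :]
    same_segment = segment_ids[:, None] == segment_ids[None, :]
    attention_mask = causal & same_segment
    return positions, attention_mask

# Two packed examples: three tokens, then two tokens.
positions, attention_mask = packed_positions_and_mask([0, 0, 0, 1, 1])
```

In this sketch the second example's tokens get positions 0 and 1 and cannot attend to the first example, which is the kind of non-causal-only mask the attention_mask description allows for.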

Methods

__init__(tokens, positions, attention_mask)

from_basic_segments(tokens)

Constructs a simple input structure for a batch of single segments.

Attributes

tokens

positions

attention_mask

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

classmethod from_basic_segments(tokens: pz.nx.NamedArray) → GemmaInputs

Constructs a simple input structure for a batch of single segments.

Use this for inputs that need no special position or attention-mask handling: ordinary sequences that are neither packed together nor padded. It augments the tokens with a standard position array and a causal attention mask, as expected by the Gemma model.

Parameters:

tokens – Sequence of tokens, as an integer named array with a “seq” axis and possibly batch axes, which starts with the beginning-of-sequence token. Each 1d vector along the “seq” axis should represent an unpadded sequence.

Returns:

A full input structure containing the provided tokens, along with a simple incrementing position array and a causal mask.
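As a rough illustration of the returned values, the sketch below computes an incrementing position array and a causal mask with plain NumPy. It is not Penzai's implementation: basic_positions_and_mask is a hypothetical helper, the token ids are made up, and the sketch assumes the same positions and mask are shared by every row of the batch.

```python
import numpy as np

def basic_positions_and_mask(tokens):
    """NumPy stand-in for the position array and causal mask.

    Hypothetical helper, not a Penzai function. `tokens` has shape
    (..., seq_len); only the length of the last axis is used.
    """
    tokens = np.asarray(tokens)
    seq_len = tokens.shape[-1]

    # Positions start at 0 and increment along the sequence axis.
    positions = np.arange(seq_len)

    # Entry [i, j] is True when query position i may attend to key
    # position j, i.e. when j is not in the future of i.
    causal_mask = positions[:, None] >= positions[None, :]
    return positions, causal_mask

# Hypothetical token ids for a batch of two unpadded sequences.
tokens = np.array([[2, 651, 4320], [2, 1841, 603]])
positions, causal_mask = basic_positions_and_mask(tokens)
```

With Penzai itself, the tokens would first be wrapped as a named array, for example pz.nx.wrap(tokens).tag("batch", "seq"), and then passed to GemmaInputs.from_basic_segments, which returns the tokens, positions, and mask together as one struct.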