EmbeddingTable

class penzai.nn.embeddings.EmbeddingTable

Bases: Struct

A table of embedding vectors for a vocabulary of tokens.

EmbeddingTable owns the embedding parameters used when either encoding a token to an embedding or decoding an embedding vector to a distribution over tokens. It does not directly provide callable methods, and should be wrapped in an EmbeddingLookup or EmbeddingDecode layer before being inserted into a model. Keeping the parameters in a separate object lets both layers share the same initialization logic, and simplifies parameter sharing when tying the embeddings between the first and last layers of a language model.
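The split between a parameter container and the two layers that use it can be sketched in plain NumPy. This is an illustrative sketch, not penzai's code: ToyEmbeddingTable, ToyLookup, ToyDecode and softmax are hypothetical names, a bare 2-D array replaces the named-axis parameter, and decoding is assumed to be a dot product against every vocabulary row (the usual choice for tied language-model embeddings).

```python
import numpy as np

class ToyEmbeddingTable:
    """Hypothetical data-only container: holds parameters, exposes no call."""
    def __init__(self, embeddings):
        self.embeddings = embeddings  # shape (vocab, embedding)

class ToyLookup:
    """Hypothetical encoding layer: token ids -> embedding vectors."""
    def __init__(self, table):
        self.table = table

    def __call__(self, token_ids):
        # Select one row of the table (vocabulary axis) per token id.
        return self.table.embeddings[token_ids]

class ToyDecode:
    """Hypothetical decoding layer: vectors -> logits over the vocabulary."""
    def __init__(self, table):
        self.table = table

    def __call__(self, vectors):
        # Dot product of each vector with every vocabulary row.
        return vectors @ self.table.embeddings.T

def softmax(logits):
    """Numerically stable softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

table = ToyEmbeddingTable(np.random.default_rng(0).normal(size=(5, 3)))
vectors = ToyLookup(table)(np.array([0, 3, 4]))  # shape (3, 3)
probs = softmax(ToyDecode(table)(vectors))       # shape (3, 5)
```

Because both layers read from the same container, the parameters are created and initialized in exactly one place.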

If you wish to set up weight tying between the encoding and decoding steps, you can wrap the embedding table in pz.nn.mark_shareable, and then wrap the entire model (including both uses of the table) in pz.nn.attach_shared_parameters. (The Gemma example model demonstrates this pattern.)
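What weight tying buys can be illustrated with a hypothetical dict-based "model" in NumPy; this does not use pz.nn.mark_shareable or pz.nn.attach_shared_parameters, and count_unique_params and forward are illustrative helpers only. With tying, the first and last layers refer to one array, so the vocabulary-by-embedding parameters are stored once instead of twice.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embedding_size = 10, 4

# Untied: separate input and output tables.
untied = {"embed": rng.normal(size=(vocab_size, embedding_size)),
          "unembed": rng.normal(size=(vocab_size, embedding_size))}

# Tied: both entries refer to the very same array object.
shared = rng.normal(size=(vocab_size, embedding_size))
tied = {"embed": shared, "unembed": shared}

def count_unique_params(params):
    """Count parameters, counting an aliased array only once."""
    unique = {id(a): a for a in params.values()}
    return sum(a.size for a in unique.values())

def forward(params, token_ids):
    hidden = params["embed"][token_ids]  # first layer: lookup
    # (a real model would apply its transformer blocks here)
    return hidden @ params["unembed"].T  # last layer: decode to logits

print(count_unique_params(untied), count_unique_params(tied))  # 80 40
```

In this mutable sketch, tying is plain object aliasing. JAX PyTrees have no notion of aliasing (a tree containing the same array twice flattens to two independent leaves), which is presumably why penzai routes tying through dedicated helpers rather than shared references.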

Variables:
  • embeddings (parameters.ParameterLike[named_axes.NamedArray]) – The embedding parameters. One axis corresponds to the vocabulary, and all other axes will be considered part of the embedding. (Usually, there will only be one other axis.)

  • vocabulary_axis (str) – The name of the axis that corresponds to the vocabulary. This axis will be indexed into when performing embedding lookups.
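Indexing by axis name rather than by position can be sketched with a plain array plus a tuple of axis names. This is a hypothetical helper, not penzai's NamedArray machinery, and the "seq" axis name for the token dimension is an assumption. The example deliberately places the vocabulary axis in the middle and keeps two embedding axes, to show that every non-vocabulary axis survives the lookup.

```python
import numpy as np

def named_lookup(data, axis_names, vocabulary_axis, token_ids, token_axis="seq"):
    """Hypothetical helper: index the vocabulary axis by name.

    Every axis other than `vocabulary_axis` is kept as part of the
    embedding. Returns the looked-up array and its axis names, with the
    token axis moved to the front.
    """
    pos = axis_names.index(vocabulary_axis)
    token_ids = np.asarray(token_ids)
    if token_ids.ndim != 1:
        raise ValueError("this sketch expects a 1-D array of token ids")
    # np.take replaces the vocabulary axis with the token-id axis in place.
    out = np.take(data, token_ids, axis=pos)
    # Move the token axis to the front for readability.
    out = np.moveaxis(out, pos, 0)
    names = (token_axis,) + axis_names[:pos] + axis_names[pos + 1:]
    return out, names

table = np.arange(2 * 6 * 3, dtype=np.float32).reshape(2, 6, 3)
names = ("heads", "vocabulary", "head_dim")
vecs, vec_names = named_lookup(table, names, "vocabulary", [1, 5])
# vecs.shape == (2, 2, 3); vec_names == ("seq", "heads", "head_dim")
```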

Methods

__init__(embeddings, vocabulary_axis)

from_config(vocab_size, embedding_axes[, ...])

Constructs an EmbeddingTable with uninitialized parameters.

Attributes

embeddings

vocabulary_axis

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

classmethod from_config(vocab_size: int, embedding_axes: dict[str, int], vocabulary_axis: str = 'vocabulary', initializer: linear_and_affine.LinearOperatorWeightInitializer = functools.partial(variance_scaling_initializer, scale=1.0, mode='fan_out', distribution='normal'), dtype: np.typing.DTypeLike = np.float32) -> EmbeddingTable

Constructs an EmbeddingTable with uninitialized parameters.

Parameters:
  • vocab_size – The size of the vocabulary.

  • embedding_axes – A dictionary mapping embedding axis names to their sizes. Will usually be a single-element dictionary of the form {"embedding": embedding_size}.

  • vocabulary_axis – The name of the axis that corresponds to the vocabulary. This axis will be indexed into when performing embedding lookups. Must not appear in embedding_axes.

  • initializer – A weight initializer used to initialize the parameters. For the purposes of initialization, the “input axes” have dimension 1 and the “output axes” are the embedding_axes; the vocabulary axis is treated as a parallel axis. Defaults to a normal variance-scaling initializer with mode='fan_out' and scale=1.0, i.e. fan-out normalization over the embedding axes.

  • dtype – The data type of the embedding parameters.

Returns:

An EmbeddingTable with uninitialized parameters of the given shape.
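The default initializer described above can be approximated as follows. This is a minimal sketch, not penzai's implementation: init_embedding_table is a hypothetical helper that returns a plain array plus its axis names; it assumes distribution='normal' means an untruncated Gaussian with variance scale / fan_out (as in jax.nn.initializers.variance_scaling); and, unlike from_config, it samples values immediately instead of returning uninitialized parameters. With input axes of size 1, fan_out is simply the product of the embedding axis sizes, so a single "embedding" axis of size 64 gives a standard deviation of 1/8.

```python
import numpy as np

def init_embedding_table(vocab_size, embedding_axes, vocabulary_axis="vocabulary",
                         scale=1.0, dtype=np.float32, seed=0):
    """Hypothetical sketch of fan-out variance-scaling initialization.

    Input axes have size 1 (so fan_in == 1), the output axes are the
    embedding axes, and the vocabulary axis is a parallel axis that does
    not enter the fan computation.
    """
    # Mirrors the documented constraint; whether from_config raises is not stated.
    if vocabulary_axis in embedding_axes:
        raise ValueError(f"{vocabulary_axis!r} must not appear in embedding_axes")
    emb_names = tuple(embedding_axes)
    emb_shape = tuple(embedding_axes[name] for name in emb_names)
    fan_out = int(np.prod(emb_shape))  # product of embedding axis sizes
    stddev = np.sqrt(scale / fan_out)  # variance = scale / fan_out
    rng = np.random.default_rng(seed)
    data = rng.normal(0.0, stddev, size=(vocab_size,) + emb_shape).astype(dtype)
    return data, (vocabulary_axis,) + emb_names

data, names = init_embedding_table(1000, {"embedding": 64})
```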