EmbeddingTable
- class penzai.nn.embeddings.EmbeddingTable
Bases: Struct
A table of embedding vectors for a vocabulary of tokens.

EmbeddingTable owns the embedding parameters that are used both when encoding a token into an embedding vector and when decoding an embedding vector into a distribution over tokens. It does not provide callable methods itself. Wrap it in an EmbeddingLookup or EmbeddingDecode layer before inserting it into a model. This design lets both layers share the same initialization logic, and it simplifies parameter sharing when the embeddings of the first and last layers of a language model are tied.

To tie weights between the encoding and decoding steps, initialize a single embedding table and pass that same table to both the EmbeddingLookup and the EmbeddingDecode layers.

- Variables:
embeddings (parameters.ParameterLike[named_axes.NamedArray]) – The embedding parameters. One axis corresponds to the vocabulary; all other axes are treated as part of the embedding. (Usually there is only one other axis.)
vocabulary_axis (str) – The name of the axis that corresponds to the vocabulary. Embedding lookups index into this axis.
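The relationship between lookup, decoding, and weight tying can be sketched in plain NumPy. This is an illustration, not penzai's implementation: a positional array stands in for the NamedArray (axis 0 plays the vocabulary axis, axis 1 a single "embedding" axis), and the helpers `lookup`, `decode`, and `softmax` are hypothetical. Decoding is assumed here to produce logits by contracting over the embedding axis, with a softmax turning them into a distribution over tokens.

```python
import numpy as np

# Hypothetical stand-in for EmbeddingTable.embeddings: axis 0 is the
# vocabulary axis, axis 1 is the single embedding axis.
vocab_size, embedding_size = 8, 4
rng = np.random.default_rng(0)
table = rng.normal(size=(vocab_size, embedding_size)).astype(np.float32)


def lookup(table: np.ndarray, tokens: np.ndarray) -> np.ndarray:
    """Encode token ids by indexing into the vocabulary axis (axis 0)."""
    return table[tokens]


def decode(table: np.ndarray, vectors: np.ndarray) -> np.ndarray:
    """Map embedding vectors to per-token logits by contracting over the
    embedding axis; the result has a trailing vocabulary axis."""
    return vectors @ table.T


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Weight tying: both steps read the very same `table` array.
tokens = np.array([1, 5, 2])
embedded = lookup(table, tokens)          # shape (3, embedding_size)
probs = softmax(decode(table, embedded))  # shape (3, vocab_size)
```

Because both helpers read one shared array, any update to `table` changes encoding and decoding together. Passing a single EmbeddingTable to both layers is meant to achieve the same effect.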
Methods
__init__(embeddings, vocabulary_axis)
from_config(name, init_base_rng, vocab_size, ...) – Constructs an EmbeddingTable.

Attributes
embeddings
vocabulary_axis
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- classmethod from_config(name: str, init_base_rng: jax.Array | None, vocab_size: int, embedding_axes: dict[str, int], vocabulary_axis: str = 'vocabulary', initializer: linear_and_affine.LinearOperatorWeightInitializer = functools.partial(<function variance_scaling_initializer>, scale=1.0, mode='fan_out', distribution='normal'), dtype: np.typing.DTypeLike = <class 'numpy.float32'>) → EmbeddingTable
Constructs an EmbeddingTable.

- Parameters:
name – The name of the layer.
init_base_rng – The base RNG to use for initializing model parameters.
vocab_size – The size of the vocabulary.
embedding_axes – A dictionary mapping embedding axis names to their sizes. This is usually a single-element dictionary of the form {"embedding": embedding_size}.
vocabulary_axis – The name of the axis that corresponds to the vocabulary. Embedding lookups index into this axis. Must not appear in embedding_axes.
initializer – A weight initializer used to initialize the parameters. For the purposes of initialization, the “input axes” have dimension 1, the “output axes” are the embedding_axes, and the vocabulary axis is treated as a parallel axis. Defaults to variance scaling with fan-out normalization over the embedding axes.
dtype – The data type of the embedding parameters.
- Returns:
An EmbeddingTable of the given shape: a vocabulary axis of size vocab_size plus the axes in embedding_axes.
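The shape rules and default initialization described above can be sketched as follows. This is a hypothetical NumPy helper, `init_embedding_table`, not penzai's code: it uses a NumPy `Generator` in place of a JAX RNG key, returns an axis-name-to-size dict plus a positional array instead of a NamedArray, and assumes that `distribution='normal'` means an untruncated Gaussian with variance `scale / fan_out`, as in `jax.nn.initializers.variance_scaling`. Because the input axes have dimension 1, only the embedding axes contribute to the fan-out, which is assumed to be the product of their sizes.

```python
import math

import numpy as np


def init_embedding_table(
    rng: np.random.Generator,
    vocab_size: int,
    embedding_axes: dict[str, int],
    vocabulary_axis: str = "vocabulary",
    scale: float = 1.0,
    dtype=np.float32,
) -> tuple[dict[str, int], np.ndarray]:
    """Hypothetical helper mirroring from_config's shape rules and default init."""
    # The vocabulary axis must be distinct from every embedding axis.
    if vocabulary_axis in embedding_axes:
        raise ValueError(
            f"vocabulary_axis {vocabulary_axis!r} must not appear in embedding_axes"
        )
    # Input axes have dimension 1, so the fan-out (product of the embedding
    # axis sizes) alone sets the variance: var = scale / fan_out.
    fan_out = math.prod(embedding_axes.values())
    stddev = math.sqrt(scale / fan_out)
    # The vocabulary axis is a parallel axis: every row is drawn independently
    # from the same distribution.
    axis_sizes = {vocabulary_axis: vocab_size, **embedding_axes}
    values = rng.normal(0.0, stddev, size=tuple(axis_sizes.values())).astype(dtype)
    return axis_sizes, values


axes, values = init_embedding_table(
    np.random.default_rng(0), vocab_size=1000, embedding_axes={"embedding": 64}
)
```

With a single 64-wide embedding axis, the fan-out is 64, so the entries should have a standard deviation close to 1/8. The result does not depend on vocab_size, since the vocabulary axis only adds independent rows.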