NamedEinsum

class penzai.nn.linear_and_affine.NamedEinsum

Bases: Layer

An Einsum operation that contracts based on axis names.

This layer behaves like a standard einsum tensor contraction, but its axes are identified by name instead of by position. In its full generality, it is specified by a mapping, for each input and for the output, from each named axis to the summation index to use for it. For example, the einsum “thp,shp->hts” could be specified as

NamedEinsum(
    (
        {"tokens":"t", "heads":"h", "projection":"p"},
        {"kv_tokens":"s", "heads":"h", "projection":"p"}
    ),
    {"heads":"h", "tokens":"t", "kv_tokens":"s"}
)
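To see how a mapping-based specification lines up with an ordinary positional einsum string, here is a minimal sketch. The helper `to_positional_subscripts` is hypothetical and is not part of penzai. It assumes that every summation index is a single letter and that each array's positional axis order matches the insertion order of its spec dict. Named arrays have no inherent axis order, so some such convention is needed before a positional einsum can run.

```python
def to_positional_subscripts(input_specs, output_spec):
    """Hypothetical helper: turns named specs into a positional einsum string.

    Assumes every summation index is a single letter and that each array's
    positional axis order matches the insertion order of its spec dict.
    """
    input_indices = set()
    terms = []
    for spec in input_specs:
        terms.append("".join(spec.values()))
        input_indices.update(spec.values())
    output_term = "".join(output_spec.values())
    # As in ordinary einsum, every output index must come from some input.
    missing = set(output_term) - input_indices
    if missing:
        raise ValueError(f"output indices not found in any input: {sorted(missing)}")
    return ",".join(terms) + "->" + output_term


subscripts = to_positional_subscripts(
    (
        {"tokens": "t", "heads": "h", "projection": "p"},
        {"kv_tokens": "s", "heads": "h", "projection": "p"},
    ),
    {"heads": "h", "tokens": "t", "kv_tokens": "s"},
)
print(subscripts)  # thp,shp->hts
```

The axis names only choose which summation index each axis uses. The index letters alone determine what gets contracted. In this example `p` is absent from the output, so it is summed.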

For the common case where each axis name can serve directly as its own summation index (so axes with the same name are matched up across arrays), you can omit the mapping values and just write a tuple of axis names, for example

NamedEinsum(
    (
        ("tokens", "heads", "projection"),
        ("kv_tokens", "heads", "projection"),
    ),
    ("heads", "tokens", "kv_tokens"),
)
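The shorthand can be read as a mapping in which every axis name is its own summation index. The sketch below spells out that reading. The names `normalize_spec` and `contracted_indices` are hypothetical illustrations of the rule described above, not penzai's internal code. It shows that the shorthand contracts the same axis as the explicit form: `projection` here, `p` in the mapping version.

```python
def normalize_spec(spec):
    """Expands the tuple shorthand into a name -> summation-index mapping."""
    if isinstance(spec, dict):
        return dict(spec)
    # Shorthand: each axis name doubles as its own summation index.
    return {name: name for name in spec}


def contracted_indices(input_specs, output_spec):
    """Summation indices that appear in some input but not in the output."""
    inputs = set()
    for spec in input_specs:
        inputs.update(normalize_spec(spec).values())
    return inputs - set(normalize_spec(output_spec).values())


shorthand = (
    ("tokens", "heads", "projection"),
    ("kv_tokens", "heads", "projection"),
)
print(normalize_spec(shorthand[0]))
# {'tokens': 'tokens', 'heads': 'heads', 'projection': 'projection'}
print(contracted_indices(shorthand, ("heads", "tokens", "kv_tokens")))
# {'projection'}
```
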

Additionally, the inputs may carry arbitrary batch axes that none of the specifications mention, as long as every input array has them. These axes are not contracted, and they are added to the output array.
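A batch axis shared by every input behaves like an extra einsum index that appears in every term and in the output, so it is never summed. That is equivalent to running the unbatched contraction once per batch element. The following plain NumPy check uses positional axes. The sizes and the leading position of the batch axis are assumptions made only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes: batch=6, tokens=3, kv_tokens=5, heads=2, projection=4.
q = rng.standard_normal((6, 3, 2, 4))  # (batch, tokens, heads, projection)
k = rng.standard_normal((6, 5, 2, 4))  # (batch, kv_tokens, heads, projection)

# The batch index "b" appears in both inputs and in the output, so it is kept.
batched = np.einsum("bthp,bshp->bhts", q, k)
# Same result, computed one batch element at a time.
looped = np.stack([np.einsum("thp,shp->hts", q[i], k[i]) for i in range(6)])
print(np.allclose(batched, looped))  # True
```
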

Variables:
  • input_axes (tuple[tuple[str, ...] | dict[str, str], ...]) – Tuple of axis name specifications, one for each input. Each specification is either a mapping from axis names to summation index names, or a tuple of axis names if each axis name should also serve as its own summation index.

  • output_axes (tuple[str, ...] | dict[str, str]) – Specification of the axis names in the output, in the same form as each entry of input_axes.

Methods

__init__(input_axes, output_axes)

input_structure()

output_structure()

__call__(x)

Runs the einsum operation.

Attributes

input_axes

output_axes

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(x: tuple[named_axes.NamedArray, ...]) -> named_axes.NamedArray

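Runs the einsum operation on the tuple of inputs x, which should contain one NamedArray per entry of input_axes, and returns the result as a NamedArray.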
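The call semantics described on this page can be approximated outside penzai with a minimal NumPy sketch. Here `named_einsum` is a hypothetical stand-in, not penzai's implementation. A `(names, ndarray)` pair plays the role of a NamedArray, with `names[i]` labeling positional axis `i`. Any axis that a spec does not mention is treated as a batch axis and appended after the named output axes, which is an ordering convention chosen only for this sketch. The sketch assigns one einsum letter to each distinct summation index or batch axis and then calls `np.einsum`. `jax.numpy.einsum` accepts the same subscript strings.

```python
import string

import numpy as np


def _as_mapping(spec):
    # Tuple shorthand: each axis name is its own summation index.
    return dict(spec) if isinstance(spec, dict) else {name: name for name in spec}


def named_einsum(input_axes, output_axes, *arrays):
    """Hypothetical stand-in for NamedEinsum.__call__.

    Each array is a (names, ndarray) pair where names[i] labels positional
    axis i. Axes not mentioned in a spec are treated as batch axes.
    """
    input_specs = [_as_mapping(spec) for spec in input_axes]
    output_spec = _as_mapping(output_axes)
    if not arrays or len(input_specs) != len(arrays):
        raise ValueError("expected one input spec per array")

    letters = {}

    def letter(key):
        # Assign a fresh einsum letter to each distinct index or batch axis.
        if key not in letters:
            letters[key] = string.ascii_letters[len(letters)]
        return letters[key]

    batch = [n for n in arrays[0][0] if n not in input_specs[0]]
    terms = []
    for spec, (names, _) in zip(input_specs, arrays):
        if set(spec) - set(names):
            raise ValueError(f"array with axes {names} lacks axes of {spec}")
        if sorted(n for n in names if n not in spec) != sorted(batch):
            raise ValueError("batch axes must appear in every input")
        terms.append("".join(
            letter(("index", spec[n])) if n in spec else letter(("batch", n))
            for n in names
        ))

    # Output: the named output axes first, then the batch axes.
    out_term = "".join(letter(("index", i)) for i in output_spec.values())
    out_term += "".join(letter(("batch", n)) for n in batch)
    out_names = tuple(output_spec) + tuple(batch)
    result = np.einsum(",".join(terms) + "->" + out_term, *(d for _, d in arrays))
    return out_names, result


rng = np.random.default_rng(0)
q = rng.standard_normal((3, 2, 4))  # tokens=3, heads=2, projection=4
k = rng.standard_normal((5, 2, 4))  # kv_tokens=5, heads=2, projection=4
names, out = named_einsum(
    (("tokens", "heads", "projection"), ("kv_tokens", "heads", "projection")),
    ("heads", "tokens", "kv_tokens"),
    (("tokens", "heads", "projection"), q),
    (("kv_tokens", "heads", "projection"), k),
)
print(names, out.shape)  # ('heads', 'tokens', 'kv_tokens') (2, 3, 5)
```

Keying the letters on `("index", ...)` versus `("batch", ...)` keeps a batch axis from colliding with a summation index that happens to share its name.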