Attention

class penzai.nn.attention.Attention

Bases: Layer

A basic attention combinator.

An attention layer contains five subcomputations: computing queries, computing keys, computing values, combining queries and keys into attention weights, and combining attention weights and values into an output. This class abstracts away the dataflow pattern common to all attention layers and leaves the details of the actual computations to the sublayers.

Variables:
  • input_to_query (layer_base.LayerLike) – A layer that maps the input to an array of queries.

  • input_to_key (layer_base.LayerLike) – A layer that maps the input to an array of keys.

  • input_to_value (layer_base.LayerLike) – A layer that maps the input to an array of values.

  • query_key_to_attn (layer_base.LayerLike) – A layer that maps a tuple of (queries, keys) to attention weights.

  • attn_value_to_output (layer_base.LayerLike) – A layer that maps a tuple of (attention weights, values) to a final output.
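
The dataflow described above can be illustrated with a minimal, library-free sketch. It does not use penzai and is not the library's source: `AttentionSketch`, `softmax`, `dot_product_attn`, and `weighted_sum` are hypothetical names introduced for illustration, and plain Python lists stand in for named arrays. Only the five field names and the tuple-passing convention are taken from the description above.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class AttentionSketch:
    """Hypothetical stand-in that wires five sublayers together."""

    input_to_query: Callable[[Any], Any]
    input_to_key: Callable[[Any], Any]
    input_to_value: Callable[[Any], Any]
    query_key_to_attn: Callable[[Any], Any]
    attn_value_to_output: Callable[[Any], Any]

    def __call__(self, x):
        # The same input feeds the query, key, and value sublayers.
        query = self.input_to_query(x)
        key = self.input_to_key(x)
        value = self.input_to_value(x)
        # The last two sublayers each receive a tuple, as documented.
        attn = self.query_key_to_attn((query, key))
        return self.attn_value_to_output((attn, value))


def softmax(row):
    # Numerically stable softmax over one list of scores.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attn(query_key):
    # Assumed choice of sublayer: scaled dot-product scores, then softmax.
    queries, keys = query_key
    scale = 1.0 / math.sqrt(len(keys[0]))
    return [
        softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
        for q in queries
    ]


def weighted_sum(attn_value):
    # Combines attention weights with values, one output row per query.
    attn, values = attn_value
    dim = len(values[0])
    return [
        [sum(w * v[d] for w, v in zip(row, values)) for d in range(dim)]
        for row in attn
    ]


def identity(x):
    return x


layer = AttentionSketch(identity, identity, identity, dot_product_attn, weighted_sum)
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three tokens, two features each
out = layer(x)
```

Because the combinator only fixes the wiring, swapping `dot_product_attn` for a different scoring function changes the attention variant without touching `__call__`.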

Methods

__init__(input_to_query, input_to_key, ...)

__call__(x)

Runs the attention computation.

Attributes

input_to_query

input_to_key

input_to_value

query_key_to_attn

attn_value_to_output

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(x: named_axes.NamedArray) -> named_axes.NamedArray

Runs the attention computation.

Parameters:

x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.

Returns:

The final output of the attn_value_to_output sublayer.