Attention
- class penzai.nn.attention.Attention
Bases: Layer
A basic attention combinator.
An attention layer contains five subcomputations: computing queries, keys, and values from the input; combining queries and keys into attention weights; and combining attention weights and values into an output. This class implements the dataflow pattern common to all attention layers and leaves the details of each computation to the sublayers.
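The dataflow described above can be sketched in plain Python. This is not penzai's implementation: `AttentionSketch` is a hypothetical stand-in that treats each of the five sublayers as an arbitrary callable and ignores named axes, pytree registration, and structure checking. Each stage is tagged so the wiring is visible in the result.

```python
from dataclasses import dataclass
from typing import Any, Callable


# Hypothetical stand-in for penzai's Attention: every field is any callable,
# mirroring the five sublayers of the real class.
@dataclass(frozen=True)
class AttentionSketch:
    input_to_query: Callable[[Any], Any]
    input_to_key: Callable[[Any], Any]
    input_to_value: Callable[[Any], Any]
    query_key_to_attn: Callable[[tuple], Any]
    attn_value_to_output: Callable[[tuple], Any]

    def __call__(self, x):
        # The same input feeds all three projections.
        query = self.input_to_query(x)
        key = self.input_to_key(x)
        value = self.input_to_value(x)
        # Queries and keys are passed together as a single tuple argument.
        attn = self.query_key_to_attn((query, key))
        # Attention weights and values are likewise combined as a tuple.
        return self.attn_value_to_output((attn, value))


# Usage: tag each stage so the resulting nesting shows the dataflow.
layer = AttentionSketch(
    input_to_query=lambda x: ("q", x),
    input_to_key=lambda x: ("k", x),
    input_to_value=lambda x: ("v", x),
    query_key_to_attn=lambda qk: ("attn", qk),
    attn_value_to_output=lambda av: ("out", av),
)
result = layer("x")
# result == ("out", (("attn", (("q", "x"), ("k", "x"))), ("v", "x")))
```

The nesting of `result` shows that the query/key pair reaches `query_key_to_attn` as one tuple and that the attention weights and values reach `attn_value_to_output` as another.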
- Variables:
input_to_query (layer_base.LayerLike) – A layer that maps the input to an array of queries.
input_to_key (layer_base.LayerLike) – A layer that maps the input to an array of keys.
input_to_value (layer_base.LayerLike) – A layer that maps the input to an array of values.
query_key_to_attn (layer_base.LayerLike) – A layer that maps a tuple of (queries, keys) to attention weights.
attn_value_to_output (layer_base.LayerLike) – A layer that maps a tuple of (attention weights, values) to a final output.
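As one hypothetical choice of sublayers, the sketch below plugs single-head scaled dot-product attention into the same dataflow, using plain Python lists of vectors instead of penzai's NamedArray. In a real model the query, key, and value sublayers would usually be learned projections; here they are assumed to be identity functions to keep the example small. The names `attention`, `dot_product_weights`, and `weighted_sum` are illustrative, not library functions.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def attention(x, input_to_query, input_to_key, input_to_value,
              query_key_to_attn, attn_value_to_output):
    """Hypothetical functional stand-in for the five-sublayer dataflow."""
    query = input_to_query(x)
    key = input_to_key(x)
    value = input_to_value(x)
    attn = query_key_to_attn((query, key))
    return attn_value_to_output((attn, value))


def identity(x):
    return x


def dot_product_weights(
    qk: Tuple[Sequence[Vector], Sequence[Vector]],
) -> List[Vector]:
    """Scaled dot-product scores, softmax-normalized over keys for each query."""
    queries, keys = qk
    scale = 1.0 / math.sqrt(len(queries[0]))
    weights = []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        top = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


def weighted_sum(
    av: Tuple[Sequence[Vector], Sequence[Vector]],
) -> List[Vector]:
    """Each output row is the attention-weighted average of the value vectors."""
    weights, values = av
    dim = len(values[0])
    return [
        [sum(w * v[d] for w, v in zip(row, values)) for d in range(dim)]
        for row in weights
    ]


# Usage: identity projections give unparameterized self-attention.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = attention(x, identity, identity, identity, dot_product_weights, weighted_sum)
```

All of the attention arithmetic lives in the sublayer functions; `attention` itself never inspects the arrays it passes along.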
Methods
__init__(input_to_query, input_to_key, ...)
__call__(x) – Runs the attention computation.
Attributes
input_to_query
input_to_key
input_to_value
query_key_to_attn
attn_value_to_output
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
input_structure() – Returns the input structure of this layer.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
output_structure() – Returns the output structure of this layer.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray) -> named_axes.NamedArray
Runs the attention computation.
- Parameters:
x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.
- Returns:
The final output of the attn_value_to_output sublayer.
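Because `__call__` only wires the sublayers together, changing the attention pattern means swapping a single sublayer. The sketch below, again in plain Python with hypothetical names and identity projections (not penzai's API), replaces the weight computation with a causally masked variant in which query i sees only keys 0 through i; the wiring function is unchanged.

```python
import math


def attention(x, to_query, to_key, to_value, query_key_to_attn, attn_value_to_output):
    # Hypothetical stand-in for the wiring in Attention.__call__.
    query, key, value = to_query(x), to_key(x), to_value(x)
    attn = query_key_to_attn((query, key))
    return attn_value_to_output((attn, value))


def causal_weights(qk):
    """Scaled dot-product weights where query i only attends to keys 0..i."""
    queries, keys = qk
    scale = 1.0 / math.sqrt(len(queries[0]))
    weights = []
    for i, q in enumerate(queries):
        visible = keys[: i + 1]  # keys after position i are masked out
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in visible]
        top = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        # Masked positions receive exactly zero weight.
        weights.append([e / total for e in exps] + [0.0] * (len(keys) - i - 1))
    return weights


def weighted_sum(av):
    # Attention-weighted average of the value vectors, one row per query.
    weights, values = av
    dim = len(values[0])
    return [[sum(w * v[d] for w, v in zip(row, values)) for d in range(dim)]
            for row in weights]


def identity(x):
    return x


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = attention(x, identity, identity, identity, causal_weights, weighted_sum)
```

The first query can only see the first key, so its output row equals the first value vector exactly.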