Attention
- class penzai.nn.attention.Attention
Bases: Layer
A basic attention combinator.
An attention layer contains five subcomputations: computing queries, computing keys, computing values, combining queries and keys into attention weights, and combining attention weights and values into an output. This class abstracts away the dataflow pattern common to all attention layers and leaves the details of the actual computations to the sublayers.
- Variables:
input_to_query (layer_base.Layer) – A layer that maps the input to an array of queries.
input_to_key (layer_base.Layer) – A layer that maps the input to an array of keys.
input_to_value (layer_base.Layer) – A layer that maps the input to an array of values.
query_key_to_attn (layer_base.Layer) – A layer that maps a tuple of (queries, keys) to attention weights.
attn_value_to_output (layer_base.Layer) – A layer that maps a tuple of (attention weights, values) to a final output.
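The dataflow described above can be sketched in plain Python. This is an illustrative stand-in and not penzai's implementation: the class name `AttentionSketch`, the string-building toy sublayers, and the `mask` side input are all hypothetical, and ordinary callables are assumed in place of real `Layer` objects and `NamedArray` values.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class AttentionSketch:
    """Hypothetical stand-in mirroring the documented fields; not penzai's code."""

    input_to_query: Callable[..., Any]
    input_to_key: Callable[..., Any]
    input_to_value: Callable[..., Any]
    query_key_to_attn: Callable[..., Any]
    attn_value_to_output: Callable[..., Any]

    def __call__(self, x: Any, **side_inputs: Any) -> Any:
        # Every sublayer receives the same side inputs.
        query = self.input_to_query(x, **side_inputs)
        key = self.input_to_key(x, **side_inputs)
        value = self.input_to_value(x, **side_inputs)
        # Queries and keys are passed together as a single tuple argument.
        attn = self.query_key_to_attn((query, key), **side_inputs)
        # Attention weights and values are likewise passed as a tuple.
        return self.attn_value_to_output((attn, value), **side_inputs)


# Toy sublayers that build strings, so the result spells out the dataflow.
# The "mask" side input is an assumption made purely for illustration.
layer = AttentionSketch(
    input_to_query=lambda x, **_: f"Q({x})",
    input_to_key=lambda x, **_: f"K({x})",
    input_to_value=lambda x, **_: f"V({x})",
    query_key_to_attn=lambda qk, mask="none", **_: (
        f"attn({qk[0]}, {qk[1]}, mask={mask})"
    ),
    attn_value_to_output=lambda av, **_: f"out({av[0]}, {av[1]})",
)

result = layer("x", mask="causal")
print(result)
```

Because the combinator only fixes how the five sublayers are wired together, swapping any one of them (for example, a different rule for turning queries and keys into attention weights) leaves the rest of the structure unchanged.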
Methods
__init__(input_to_query, input_to_key, ...)
__call__(x, **side_inputs) – Runs the attention computation.
Attributes
input_to_query
input_to_key
input_to_value
query_key_to_attn
attn_value_to_output
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray, **side_inputs: Any) → named_axes.NamedArray
Runs the attention computation.
- Parameters:
x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.
**side_inputs – Side inputs for all sublayers.
- Returns:
The final output of the attn_value_to_output sublayer.