Attention
- class penzai.nn.attention.Attention
Bases: Layer
A basic attention combinator.
An attention layer consists of five subcomputations: computing queries, computing keys, computing values, combining queries and keys into attention weights, and combining attention weights and values into an output. This class abstracts away the dataflow pattern common to all attention layers and delegates the actual computations to its sublayers.
- Variables:
input_to_query (layer_base.Layer) – A layer that maps the input to an array of queries.
input_to_key (layer_base.Layer) – A layer that maps the input to an array of keys.
input_to_value (layer_base.Layer) – A layer that maps the input to an array of values.
query_key_to_attn (layer_base.Layer) – A layer that maps a tuple of (queries, keys) to attention weights.
attn_value_to_output (layer_base.Layer) – A layer that maps a tuple of (attention weights, values) to a final output.
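The dataflow between the five fields can be sketched in plain Python. This is not penzai's implementation: `AttentionSketch` is a hypothetical stand-in whose fields are ordinary callables rather than `Layer`s, arrays are NumPy arrays with positional axes rather than `NamedArray`s, and the sublayers below implement single-head scaled dot-product attention purely as an illustrative choice. The wiring follows the description above: queries, keys, and values come from the same input, the (queries, keys) tuple becomes attention weights, and the (attention weights, values) tuple becomes the output.

```python
import numpy as np
from dataclasses import dataclass
from typing import Any, Callable


# Hypothetical stand-in for the Attention combinator (not the penzai class).
# Each field is a plain callable instead of a penzai Layer.
@dataclass
class AttentionSketch:
    input_to_query: Callable[..., Any]
    input_to_key: Callable[..., Any]
    input_to_value: Callable[..., Any]
    query_key_to_attn: Callable[..., Any]
    attn_value_to_output: Callable[..., Any]

    def __call__(self, x, **side_inputs):
        # Queries, keys, and values are all computed from the same input.
        q = self.input_to_query(x, **side_inputs)
        k = self.input_to_key(x, **side_inputs)
        v = self.input_to_value(x, **side_inputs)
        # Combine queries and keys into attention weights.
        attn = self.query_key_to_attn((q, k), **side_inputs)
        # Combine attention weights and values into the final output.
        return self.attn_value_to_output((attn, v), **side_inputs)


rng = np.random.default_rng(0)
d_model, d_head = 8, 4
# Illustrative projection matrices for a single attention head.
w_q, w_k, w_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product(qk, **_):
    q, k = qk
    scores = q @ k.T / np.sqrt(q.shape[-1])  # shape [seq_q, seq_k]
    return softmax(scores, axis=-1)


def weighted_sum(attn_v, **_):
    attn, v = attn_v
    return attn @ v  # shape [seq_q, d_head]


layer = AttentionSketch(
    input_to_query=lambda x, **_: x @ w_q,
    input_to_key=lambda x, **_: x @ w_k,
    input_to_value=lambda x, **_: x @ w_v,
    query_key_to_attn=scaled_dot_product,
    attn_value_to_output=weighted_sum,
)

x = rng.normal(size=(5, d_model))  # a sequence of 5 tokens
out = layer(x)
print(out.shape)  # (5, 4)
```

Swapping any one callable (for example, a different `query_key_to_attn`) changes that computation without touching the surrounding dataflow, which is the separation of concerns the combinator is meant to provide.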
Methods
__init__(input_to_query, input_to_key, ...)
treescope_color()
__call__(x, **side_inputs) – Runs the attention computation.
Attributes
input_to_query
input_to_key
input_to_value
query_key_to_attn
attn_value_to_output
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
- __call__(x: named_axes.NamedArray, **side_inputs: Any) → named_axes.NamedArray
Runs the attention computation.
- Parameters:
x – The input to the computation, which will be mapped to queries, keys, and values by the sublayers.
**side_inputs – Side inputs for all sublayers.
- Returns:
The final output of the attn_value_to_output sublayer.
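Since `**side_inputs` go to all sublayers, each sublayer must accept keyword arguments it does not use. The sketch below illustrates this with plain NumPy callables; `run_attention` is a hypothetical helper that mirrors the documented dataflow (it is not a penzai function), and `mask` is an illustrative side-input name, not a keyword defined by penzai. Only `query_key_to_attn` reads the mask; the other sublayers absorb it with `**_`.

```python
import numpy as np


# Hypothetical helper mirroring the documented __call__ dataflow: every
# sublayer receives the same **side_inputs. Sublayers are plain callables
# standing in for penzai Layers.
def run_attention(sub, x, **side_inputs):
    q = sub["input_to_query"](x, **side_inputs)
    k = sub["input_to_key"](x, **side_inputs)
    v = sub["input_to_value"](x, **side_inputs)
    attn = sub["query_key_to_attn"]((q, k), **side_inputs)
    return sub["attn_value_to_output"]((attn, v), **side_inputs)


def masked_softmax_attn(qk, *, mask, **_):
    # The only sublayer that consumes the (hypothetical) "mask" side input.
    q, k = qk
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)  # block disallowed positions
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)  # exp(-inf) == 0 for masked positions
    return weights / weights.sum(axis=-1, keepdims=True)


sublayers = {
    # Identity projections keep the example small; extra kwargs are ignored.
    "input_to_query": lambda x, **_: x,
    "input_to_key": lambda x, **_: x,
    "input_to_value": lambda x, **_: x,
    "query_key_to_attn": masked_softmax_attn,
    "attn_value_to_output": lambda av, **_: av[0] @ av[1],
}

x = np.arange(12.0).reshape(4, 3)  # 4 tokens, 3 features each
causal = np.tril(np.ones((4, 4), dtype=bool))  # token i sees tokens 0..i
out = run_attention(sublayers, x, mask=causal)
```

With a causal mask, the first token can attend only to itself, so the first output row equals the first input row.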