ApplyCausalAttentionMask
- class penzai.nn.attention.ApplyCausalAttentionMask
Bases: Layer
Builds and applies a causal attention mask based on token positions.
This layer retrieves the token positions from its side inputs and uses them to build a causal attention mask. Masked-out values are replaced with the masked_out_value attribute, which is usually a large (but finite) negative value.
- Variables:
masked_out_value (jax.typing.ArrayLike) – The value to substitute for masked-out locations.
query_positions_input_name (str) – Key in the side input dictionary that identifies the query token positions, which should be an integer array with the seq_axis axis.
kv_positions_input_name (str) – Key in the side input dictionary that identifies the key/value token positions, which should be an integer array with the seq_axis axis. (This axis will be renamed to match kv_seq_axis.)
seq_axis (str) – Name of the sequence axis, which should be present in both the query and key/value token position side inputs.
kv_seq_axis (str) – Name of the key/value sequence axis, which represents the keys and values in the input logits array.
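The position-based masking described above can be sketched in plain jax.numpy, without penzai's named axes. This is a minimal illustration under stated assumptions, not the library's implementation: the helper names causal_mask_from_positions and apply_mask are hypothetical, and the rule "a query may attend to keys/values at the same or an earlier position" is the standard causal convention, assumed here rather than taken from the text above.

```python
import jax.numpy as jnp


def causal_mask_from_positions(query_positions, kv_positions):
    """Hypothetical helper: True where a query may attend to a key/value.

    Assumes the usual causal rule: a query at position p sees
    keys/values at positions <= p.
    """
    # Queries vary along rows, keys/values along columns.
    return query_positions[:, None] >= kv_positions[None, :]


def apply_mask(logits, mask, masked_out_value):
    """Hypothetical helper: keep logits where mask is True, else substitute."""
    return jnp.where(mask, logits, masked_out_value)


# Full-sequence case: three queries and three keys/values at positions 0..2.
positions = jnp.arange(3)
mask = causal_mask_from_positions(positions, positions)
masked = apply_mask(jnp.zeros((3, 3)), mask, -1e9)

# Decoding-style case (an assumed scenario): one query at position 2
# against four key/value slots at positions 0..3.
decode_mask = causal_mask_from_positions(jnp.array([2]), jnp.arange(4))
```

Building the mask from explicit positions, rather than from a fixed lower-triangular matrix, lets the query and key/value sequences have different lengths, as in the second case.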
Methods
__init__(masked_out_value[, ...])
__call__(x, **side_inputs) – Applies the attention mask to the input array.
Attributes
kv_positions_input_name
kv_seq_axis
query_positions_input_name
seq_axis
masked_out_value
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray, **side_inputs: Any) -> named_axes.NamedArray
Applies the attention mask to the input array.
- Parameters:
x – The input array to mask. Usually the matrix of query-key dot products.
**side_inputs – Side inputs. Must include the keys named by query_positions_input_name and kv_positions_input_name.
- Returns:
An adjusted matrix of logits, in which every value at a location where the mask is False has been replaced with masked_out_value.
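To illustrate what the replaced logits do downstream, the sketch below applies a softmax to masked logits in plain jax.numpy. It is an illustration only: the softmax step and the specific value -1e9 are assumptions, and the comparison with negative infinity shows one plausible reason a finite value is the usual choice, not a documented rationale.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0]])
mask = jnp.array([[True, True, False]])

# Locations where the mask is False receive the substitute value.
masked = jnp.where(mask, logits, -1e9)
weights = jax.nn.softmax(masked, axis=-1)
# The masked location gets zero weight; the other two share all of it.

# Edge case: a row in which every location is masked out.
none_visible = jnp.array([[False, False, False]])
finite_row = jax.nn.softmax(jnp.where(none_visible, logits, -1e9), axis=-1)
inf_row = jax.nn.softmax(jnp.where(none_visible, logits, -jnp.inf), axis=-1)
# finite_row is a well-defined uniform distribution, whereas inf_row
# contains NaN because softmax has to evaluate -inf - (-inf).
```

With a finite substitute, a fully masked row still produces finite attention weights, which keeps NaN values from propagating through the rest of the computation.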