ApplyExplicitAttentionMask
- class penzai.nn.attention.ApplyExplicitAttentionMask
Bases: Layer
Applies an explicit attention mask to its input logit array.
This layer retrieves an attention mask from its side inputs and uses it to mask its main argument. Values at positions where the mask is False are replaced with the masked_out_value attribute, which is usually a large (but finite) negative value.
- Variables:
mask_input_name (str) – Key in the side input dictionary to use to identify the attention mask.
masked_out_value (jax.typing.ArrayLike) – The value to substitute for masked-out locations.
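Ignoring named axes and the side-input plumbing, the substitution described above is an elementwise select. The sketch below shows that select in plain NumPy, where `np.where` has the same semantics as `jnp.where`. It is not penzai's implementation. The helpers `apply_explicit_mask` and `softmax` are hypothetical, and the sketch assumes the mask is boolean and broadcastable against the logits.

The example also shows one plausible reason to prefer a large finite value over `-inf`. It assumes the masked logits feed into a softmax. If every key in a row is masked out, a finite fill value yields a uniform distribution, while `-inf` yields NaNs.

```python
import numpy as np


def apply_explicit_mask(x, mask, masked_out_value):
    """Hypothetical stand-in for this layer's masking step (not penzai code).

    Keeps x where mask is True and substitutes masked_out_value elsewhere.
    """
    # np.where broadcasts the condition, the logits, and the fill value.
    return np.where(np.asarray(mask, dtype=bool), x, masked_out_value)


def softmax(logits):
    # Numerically stable softmax over the last axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


logits = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])
mask = np.array([[True, False, True],
                 [False, False, False]])  # second row is fully masked

# Finite fill value: the fully masked row becomes a uniform distribution.
finite_probs = softmax(apply_explicit_mask(logits, mask, -1e9))

# Infinite fill value: -inf - (-inf) is NaN, so the fully masked row is all NaN.
with np.errstate(invalid="ignore"):
    inf_probs = softmax(apply_explicit_mask(logits, mask, -np.inf))
```

For rows that keep at least one position, the two fill values give the same weights. They differ only on fully masked rows.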
Methods
__init__(mask_input_name, masked_out_value)
__call__(x, **side_inputs) – Applies the attention mask to the input array.
Attributes
mask_input_name
masked_out_value
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray, **side_inputs: Any) -> named_axes.NamedArray
Applies the attention mask to the input array.
- Parameters:
x – The input array to mask. This is usually the matrix of query-key dot products.
**side_inputs – Side inputs. Must include the key named by mask_input_name.
- Returns:
An adjusted matrix of logits in which every value where the mask is False has been replaced with masked_out_value.
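A common choice for the mask side input is a causal mask, under which query position i may attend only to key positions j <= i. The sketch below builds such a mask and applies the same keep-where-True select as this layer, then normalizes over the key axis. It shows that masked-out keys receive zero attention weight. It uses plain NumPy arrays rather than penzai `NamedArray`s, and the helpers `causal_mask` and `masked_softmax` are hypothetical, written only for illustration.

```python
import numpy as np


def causal_mask(seq_len):
    # True on and below the diagonal: query i may attend to keys 0..i.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))


def masked_softmax(logits, mask, masked_out_value=-1e9):
    # Same select as the layer: keep logits where mask is True.
    masked = np.where(mask, logits, masked_out_value)
    # Numerically stable softmax over the last (key) axis.
    shifted = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
# Hypothetical query-key dot products for a length-4 sequence.
scores = rng.normal(size=(4, 4))
probs = masked_softmax(scores, causal_mask(4))
```

Because exp(-1e9) underflows to exactly 0.0 in float64, every position above the diagonal ends up with weight 0.0, and each row still sums to 1.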