ApplyAttentionMask
- class penzai.deprecated.v1.nn.attention.ApplyAttentionMask
Bases: Layer
Applies an attention mask to its input logit array.
This layer retrieves a causal attention mask from its side input and uses it to mask its argument. Masked-out values are replaced with the masked_out_value attribute, which is usually a large (but finite) negative value.
- Variables:
mask (side_input.SideInputEffect[named_axes.NamedArray]) – A side input that provides the attention mask to apply to the input attention scores. This side input should be provided as a boolean array that is broadcastable with the input.
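For reference, the boolean mask expected by this side input can be built with ordinary JAX operations. The sketch below constructs a simple causal mask over query and key positions; the sequence length and variable names are illustrative assumptions, and wrapping the result as a named_axes.NamedArray and routing it through the side-input machinery is omitted, since that depends on the surrounding model.

    import jax.numpy as jnp

    # Minimal sketch of a causal mask (illustrative only, not Penzai-specific):
    # query position q may attend to key position k only when k <= q.
    seq_len = 8  # assumed sequence length
    query_positions = jnp.arange(seq_len)[:, None]  # shape (seq, 1)
    key_positions = jnp.arange(seq_len)[None, :]    # shape (1, kv_seq)
    causal_mask = key_positions <= query_positions  # boolean (seq, kv_seq) mask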
Methods
__init__(mask, masked_out_value)
from_config(mask_tag[, masked_out_value]) – Creates an ApplyAttentionMask layer from a tag and a mask value.
__call__(x) – Applies the attention mask to the input array.
Attributes
mask
masked_out_value
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
input_structure() – Returns the input structure of this layer.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
output_structure() – Returns the output structure of this layer.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray) → named_axes.NamedArray
Applies the attention mask to the input array.
- Parameters:
x – The input array to mask. Usually the matrix of query-key dot products.
- Returns:
An adjusted matrix of logits, in which every value for which the mask is False has been replaced with the masked_out_value argument.
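Functionally, this amounts to a where-select over the logits. A minimal sketch of the operation described above, using plain jax.numpy rather than Penzai's named arrays and the documented default fill value:

    import jax.numpy as jnp

    def apply_mask_sketch(logits, mask, masked_out_value=-2.3819763e+38):
        # Keep logits where the mask is True; elsewhere substitute the large
        # (but finite) negative fill value so softmax gives negligible weight.
        return jnp.where(mask, logits, masked_out_value)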
- classmethod from_config(mask_tag: Any, masked_out_value: jax.typing.ArrayLike = -2.3819763e+38) → ApplyAttentionMask
Creates an ApplyAttentionMask layer from a tag and a mask value.
- Parameters:
mask_tag – Side input tag for the mask side input. This should be used to identify the side inputs that correspond to the same attention mask throughout the model.
masked_out_value – The value to replace masked out values with. This is usually a large (but finite) negative value, so that it maps to a negligible attention weight in a numerically stable way.
- Returns:
A new ApplyAttentionMask layer with the given configuration.
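A construction sketch based on the signature above; the import path follows the class's documented location, the "attn_mask" tag name is an arbitrary example, and the handler that actually supplies the boolean mask for that tag is not shown:

    from penzai.deprecated.v1.nn.attention import ApplyAttentionMask

    # Layers built with the same tag will receive the same attention mask
    # side input once the surrounding model binds that tag.
    mask_layer = ApplyAttentionMask.from_config(mask_tag="attn_mask")

    # Optionally override the fill value used for masked-out logits.
    custom_mask_layer = ApplyAttentionMask.from_config(
        mask_tag="attn_mask", masked_out_value=-1e9)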