ApplyAttentionMask#

class penzai.nn.attention.ApplyAttentionMask[source]#

Bases: Layer

Applies an attention mask to its input logit array.

This layer retrieves an attention mask from its side input and uses it to mask its argument. Masked-out values are replaced with the masked_out_value attribute, which is usually a large (but finite) negative value.

Variables:

mask (side_input.SideInputEffect[named_axes.NamedArray]) – A side input that provides the attention mask to apply to the input attention scores. This side input should be provided as a boolean array that is broadcastable with the input.

Methods

__init__(mask, masked_out_value)

from_config(mask_tag[, masked_out_value])

Creates an ApplyAttentionMask layer from a tag and a mask value.

__call__(x)

Applies the attention mask to the input array.

Attributes

mask

masked_out_value

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(x: named_axes.NamedArray) named_axes.NamedArray[source]#

Applies the attention mask to the input array.

Parameters:

x – The input array to mask. Usually the matrix of query-key dot products.

Returns:

An adjusted matrix of logits, in which every position where the mask is False has been replaced with the masked_out_value attribute.
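The masking semantics can be sketched with plain NumPy (hypothetical shapes and a causal mask chosen for illustration; the real layer operates on penzai NamedArrays and receives its mask through a side input):

```python
import numpy as np

# Hypothetical query-key dot products for a sequence of length 4.
logits = np.arange(16, dtype=np.float32).reshape(4, 4)

# An example causal mask: position i may only attend to positions j <= i.
mask = np.tril(np.ones((4, 4), dtype=bool))

# The default masked_out_value: a large but finite negative float32.
masked_out_value = np.float32(-2.3819763e38)

# Wherever the mask is False, replace the logit with masked_out_value.
masked_logits = np.where(mask, logits, masked_out_value)
```

Because the mask only needs to be broadcastable with the input, the same pattern applies when the mask has fewer axes than the logits.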

classmethod from_config(mask_tag: Any, masked_out_value: jax.typing.ArrayLike = -2.3819763e+38) ApplyAttentionMask[source]#

Creates an ApplyAttentionMask layer from a tag and a mask value.

Parameters:
  • mask_tag – Side input tag for the mask side input. This should be used to identify the side inputs that correspond to the same attention mask throughout the model.

  • masked_out_value – The value to replace masked out values with. This is usually a large (but finite) negative value, so that it maps to a negligible attention weight in a numerically stable way.

Returns:

A new ApplyAttentionMask layer with the given configuration.
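To see why a large but finite negative value works as a masked_out_value, here is a NumPy sketch (a hand-rolled softmax for illustration): after softmax, masked positions receive numerically zero weight, and a finite value avoids the NaNs that -inf can produce when an entire row is masked out.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax along the last axis.
    shifted = x - x.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# One row of attention logits where the last two positions are masked
# out with the default masked_out_value.
masked_out_value = np.float32(-2.3819763e38)
row = np.array([1.0, 2.0, masked_out_value, masked_out_value],
               dtype=np.float32)

weights = softmax(row)
# exp() underflows to 0.0 at the masked positions, so they get zero
# attention weight while the remaining weights still sum to 1.
```

Had the row used -inf instead, a fully masked row would compute -inf - (-inf) inside the shift and yield NaN; with the finite default, a fully masked row degrades gracefully to uniform weights.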