ApplyExplicitAttentionMask

class penzai.nn.attention.ApplyExplicitAttentionMask

Bases: Layer

Applies an explicit attention mask to its input logit array.

This layer retrieves an attention mask from its side inputs, looked up under the key mask_input_name, and uses it to mask its main argument. Masked-out values are replaced with the masked_out_value attribute, which is usually a large (but finite) negative value.

Variables:
  • mask_input_name (str) – Key in the side input dictionary that identifies the attention mask.

  • masked_out_value (jax.typing.ArrayLike) – The value to substitute for masked-out locations.
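The masking rule described above can be sketched in plain Python. This illustrates the semantics only and is not penzai's implementation: the helper apply_mask_sketch, the key "attn_mask", the fill value -1e9, and the nested-list representation are all assumptions made for the example. The actual layer operates on named_axes.NamedArray values.

```python
# Illustrative sketch only: a hypothetical helper, not part of penzai.
def apply_mask_sketch(logits, side_inputs, mask_input_name, masked_out_value):
    """Keep logits where the mask is True; substitute masked_out_value elsewhere."""
    mask = side_inputs[mask_input_name]  # look up the mask by its key
    return [
        [value if keep else masked_out_value for value, keep in zip(row, mask_row)]
        for row, mask_row in zip(logits, mask)
    ]


# 2 queries x 3 keys of query-key dot products, with a causal-style mask.
logits = [[0.5, 1.0, -2.0],
          [0.1, 0.2, 0.3]]
side_inputs = {"attn_mask": [[True, False, False],
                             [True, True, False]]}

masked = apply_mask_sketch(logits, side_inputs, "attn_mask", -1e9)
print(masked)
```

Entries whose mask value is True pass through unchanged; every other entry becomes the fill value.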

Methods

__init__(mask_input_name, masked_out_value)

__call__(x, **side_inputs)

Applies the attention mask to the input array.

Attributes

mask_input_name

masked_out_value

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(x: named_axes.NamedArray, **side_inputs: Any) → named_axes.NamedArray

Applies the attention mask to the input array.

Parameters:
  • x – The input array to mask. Usually the matrix of query-key dot products.

  • **side_inputs – Side inputs. Must include an entry under the key given by mask_input_name.

Returns:

An adjusted matrix of logits, in which every entry where the mask is False has been replaced with the masked_out_value attribute.
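One commonly cited reason for choosing a large but finite fill value, rather than negative infinity, is that a later softmax stays well defined even when an entire row is masked. The plain-Python sketch below illustrates this. The softmax helper and the value -1e9 are assumptions for the example; the source does not say which value penzai models use or how they compute the softmax.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats (illustrative helper)."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# One row of logits after masking: the last two positions were masked out.
partially_masked = [0.5, -1e9, -1e9]
weights = softmax(partially_masked)  # masked positions underflow to weight 0.0

# A fully masked row still yields finite (uniform) weights with a finite fill.
fully_masked_finite = softmax([-1e9, -1e9, -1e9])

# With -inf as the fill, the stabilizing shift (value - max) becomes inf - inf.
neg_inf = float("-inf")
shifted = neg_inf - max([neg_inf, neg_inf])  # not a number

print(weights, fully_masked_finite, shifted)
```

In the partially masked row, all attention weight lands on the unmasked position. In the fully masked row, the finite fill gives a uniform distribution instead of propagating NaN values.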