ApplyCausalAttentionMask

class penzai.nn.attention.ApplyCausalAttentionMask

Bases: Layer

Builds and applies a causal attention mask based on token positions.

This layer retrieves the query and key/value token positions from its side inputs and uses them to build a causal attention mask. Masked-out values are replaced with the masked_out_value attribute, which is usually a large (but finite) negative value.
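The mask logic can be sketched in plain NumPy (np.asarray and np.where have jax.numpy counterparts with the same semantics for this use). This is a minimal sketch, not penzai's implementation: it assumes 1-D position arrays with no batch axis, uses ordinary positional axes instead of penzai's named axes, and the function names and the -1e9 default are hypothetical.

```python
import numpy as np

def causal_mask_from_positions(query_positions, kv_positions):
    """Boolean [q_len, kv_len] mask that is True where a query may attend."""
    q = np.asarray(query_positions)[:, None]  # shape [q_len, 1]
    k = np.asarray(kv_positions)[None, :]     # shape [1, kv_len]
    # A query may attend to any key/value at the same or an earlier position.
    return q >= k

def apply_causal_mask(logits, query_positions, kv_positions, masked_out_value=-1e9):
    """Keeps logits where the causal mask is True; substitutes masked_out_value elsewhere."""
    mask = causal_mask_from_positions(query_positions, kv_positions)
    return np.where(mask, logits, masked_out_value)

# Self-attention over a length-4 sequence at positions 0, 1, 2, 3.
positions = np.arange(4)
masked = apply_causal_mask(np.zeros((4, 4)), positions, positions)
# Lower triangle (including the diagonal) stays 0.0; the upper triangle becomes -1e9.
print(masked)
```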

Variables:
  • masked_out_value (jax.typing.ArrayLike) – The value to substitute for masked-out locations.

  • query_positions_input_name (str) – Key in the side input dictionary to use to identify the query token positions, which should be an integer array with the seq_axis axis.

  • kv_positions_input_name (str) – Key in the side input dictionary to use to identify the key/value token positions, which should be an integer array with the seq_axis axis. (This axis will be renamed to match kv_seq_axis.)

  • seq_axis (str) – Name of the sequence axis, which should be present in both the query and key/value token position side inputs.

  • kv_seq_axis (str) – Name of the key/value sequence axis, which represents the keys and values in the input logits array.
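Because the query and key/value positions are looked up under separate side-input keys, they need not coincide. A hypothetical illustration is cached decoding, where a few new query tokens attend to a longer buffer of key/value slots. The sketch below assumes 1-D NumPy position arrays and illustrative key names; placing the key/value positions on the second array axis stands in for renaming seq_axis to kv_seq_axis, and the helper mask_from_side_inputs is hypothetical.

```python
import numpy as np

def mask_from_side_inputs(
    logits,
    side_inputs,
    query_positions_input_name,
    kv_positions_input_name,
    masked_out_value=-1e9,
):
    """Hypothetical helper: looks up positions by key, then applies a causal mask."""
    query_positions = np.asarray(side_inputs[query_positions_input_name])
    kv_positions = np.asarray(side_inputs[kv_positions_input_name])
    # Queries index axis 0 and keys/values index axis 1, so the comparison
    # broadcasts to a [q_len, kv_len] boolean mask.
    mask = query_positions[:, None] >= kv_positions[None, :]
    return np.where(mask, logits, masked_out_value)

# Two new query tokens at positions 4 and 5 attend to key/value slots 0..5.
side_inputs = {
    "query_positions": np.array([4, 5]),
    "kv_positions": np.arange(6),
}
out = mask_from_side_inputs(
    np.zeros((2, 6)),
    side_inputs,
    query_positions_input_name="query_positions",
    kv_positions_input_name="kv_positions",
)
# The query at position 4 cannot see slot 5; the query at position 5 sees all six slots.
print(out)
```

A fixed lower-triangular mask (np.tril) would be wrong here, since it assumes query i sits at position i; comparing the stored positions avoids that assumption.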

Methods

__init__(masked_out_value[, ...])

__call__(x, **side_inputs)

Applies the attention mask to the input array.

Attributes

kv_positions_input_name

kv_seq_axis

query_positions_input_name

seq_axis

masked_out_value

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(x: named_axes.NamedArray, **side_inputs: Any) → named_axes.NamedArray

Applies the attention mask to the input array.

Parameters:
  • x – The input array to mask. Usually the matrix of query-key dot products.

  • **side_inputs – Side inputs. Must include the keys given by query_positions_input_name and kv_positions_input_name.

Returns:

An adjusted matrix of logits in which every value where the mask is False has been replaced with the masked_out_value attribute.
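One reason a large but finite masked_out_value is usually preferred shows up once the masked logits go through a softmax. The NumPy sketch below (the softmax helper is written out here, and -1e9 is an arbitrary illustrative value) shows that masked entries receive zero weight, and that with -inf a row in which every entry is masked turns into NaN, whereas a finite value yields a uniform distribution instead.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

logits = np.array([[1.0, 2.0, 3.0]])

# Partially masked row: the masked-out entry gets (numerically) zero weight.
partial_mask = np.array([[True, True, False]])
partial = softmax(np.where(partial_mask, logits, -1e9))

# Fully masked row: -inf leads to -inf - (-inf) = NaN inside the softmax,
# while a finite value just produces a uniform distribution.
full_mask = np.zeros_like(partial_mask)
with np.errstate(invalid="ignore"):
    with_inf = softmax(np.where(full_mask, logits, -np.inf))
with_finite = softmax(np.where(full_mask, logits, -1e9))

print(partial)      # approximately [[0.269 0.731 0.   ]]
print(with_inf)     # [[nan nan nan]]
print(with_finite)  # approximately [[0.333 0.333 0.333]]
```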