ApplyCausalSlidingWindowAttentionMask
- class penzai.nn.attention.ApplyCausalSlidingWindowAttentionMask
Bases: Layer
Builds and applies a sliding-window attention mask based on token positions.
This layer retrieves the token positions from its side input and uses them to build a causal sliding-window attention mask, in which values at a distance of sliding_window_size or more from the current token are masked out. Masked-out values are replaced with the masked_out_value attribute, which is usually a large (but finite) negative value.
- Variables:
masked_out_value (jax.typing.ArrayLike) – The value to substitute for masked-out locations.
sliding_window_size (int | jax.typing.ArrayLike) – The size of the sliding window.
query_positions_input_name (str) – Key in the side input dictionary to use to identify the query token positions, which should be an integer array with the seq_axis axis.
kv_positions_input_name (str) – Key in the side input dictionary to use to identify the key/value token positions, which should be an integer array with the seq_axis axis. (This axis will be renamed to match kv_seq_axis.)
seq_axis (str) – Name of the sequence axis, which should be present in both the query and key/value token position side inputs.
kv_seq_axis (str) – Name of the key/value sequence axis, which represents the keys and values in the input logits array.
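The masking rule described above can be sketched in plain jax.numpy. This is a minimal illustration rather than penzai's implementation: the function name causal_sliding_window_mask and its positional-array arguments are hypothetical, and ordinary unnamed arrays stand in for the named axes that the layer actually works with.

```python
import jax.numpy as jnp


def causal_sliding_window_mask(query_positions, kv_positions, window_size):
    """Hypothetical helper: True where a query may attend to a key/value."""
    q = query_positions[:, None]  # shape [num_queries, 1]
    k = kv_positions[None, :]     # shape [1, num_kv]
    causal = k <= q               # never attend to later positions
    in_window = (q - k) < window_size  # distance >= window_size is masked out
    return causal & in_window


positions = jnp.arange(5)
mask = causal_sliding_window_mask(positions, positions, window_size=2)
# With window_size=2, each query sees itself and the one token before it.
```

Because the mask is computed from positions rather than from array indices, the same rule applies when queries and keys/values carry different position arrays, which is presumably why the layer accepts separate side-input names for the two.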
Methods
__init__(masked_out_value, sliding_window_size)
__call__(x, **side_inputs) – Applies the attention mask to the input array.
Attributes
kv_positions_input_name
kv_seq_axis
query_positions_input_name
seq_axis
masked_out_value
sliding_window_size
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(x: named_axes.NamedArray, **side_inputs: Any) -> named_axes.NamedArray
Applies the attention mask to the input array.
- Parameters:
x – The input array to mask. Usually the matrix of query-key dot products.
**side_inputs – Side inputs. Must include query_positions_input_name and kv_positions_input_name.
- Returns:
An adjusted matrix of logits, in which every value where the mask is False has been replaced with the masked_out_value argument.
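The replacement step can be illustrated with jnp.where on ordinary arrays. This is a sketch under stated assumptions, not the layer's actual code: apply_mask is a hypothetical helper, the value -1e9 is only an example of a large finite negative number, and the small causal mask is built inline so the example stands alone.

```python
import jax
import jax.numpy as jnp


def apply_mask(logits, mask, masked_out_value=-1e9):
    """Hypothetical helper: keep logits where mask is True, else substitute."""
    return jnp.where(mask, logits, masked_out_value)


logits = jnp.zeros((3, 3))                      # stand-in for query-key dot products
mask = jnp.tril(jnp.ones((3, 3), dtype=bool))   # simple causal mask for illustration
masked = apply_mask(logits, mask)
weights = jax.nn.softmax(masked, axis=-1)       # masked entries get ~0 weight
```

After the softmax, masked-out entries receive effectively zero attention weight. One common reason for preferring a finite value over negative infinity is that a row in which every entry is masked would otherwise produce NaN in the softmax.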