ApplyRoPE

class penzai.nn.embeddings.ApplyRoPE

Bases: Layer

Adjusts input embeddings using rotary position embeddings (RoPE).

Rotary position embeddings, proposed by Su et al. (2021), incorporate relative position information into the queries and keys of attention layers by applying periodic rotations to the elements of the query and key projections.

The ApplyRoPE layer can be inserted into an attention computation immediately before computing query-key dot products in order to add rotational position information to them.

See https://arxiv.org/abs/2104.09864.
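The relative-position property described above can be checked with a minimal pure-Python sketch. It is an illustration of the RoPE idea from Su et al., not this layer's implementation: the helper names `rope_rotate` and `dot` are hypothetical, and the sketch assumes one common pairing convention (element i is rotated together with element i + dim/2) that may differ from what ApplyRoPE does internally.

```python
import math


def rope_rotate(vec, position, max_wavelength=10000):
    """Rotates one embedding vector by angles that depend on its position.

    Hypothetical helper: assumes the half-split pairing convention, where
    element i is paired with element i + dim // 2.
    """
    dim = len(vec)
    assert dim % 2 == 0, "RoPE needs an even embedding size"
    half = dim // 2
    out = [0.0] * dim
    for i in range(half):
        # Pair i rotates by position / timescale; later pairs rotate more slowly.
        timescale = max_wavelength ** (2 * i / dim)
        angle = position / timescale
        c, s = math.cos(angle), math.sin(angle)
        a, b = vec[i], vec[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


def dot(u, v):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(u, v))


query = [0.5, -1.0, 2.0, 0.25]
key = [1.5, 0.75, -0.5, 1.0]

# Both pairs of positions are 4 apart, so the rotated scores should agree.
score_near = dot(rope_rotate(query, 7), rope_rotate(key, 3))
score_far = dot(rope_rotate(query, 104), rope_rotate(key, 100))
assert math.isclose(score_near, score_far, rel_tol=1e-9, abs_tol=1e-9)
```

Because each pair is rotated rather than shifted, the vector norm is unchanged and the query-key dot product depends only on the offset between the two positions, which is why the rotation is applied just before the dot product is computed.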

Variables:
  • embedding_axis (str) – The axis name of the input that contains the embedding vector (e.g. “embedding” or “projection”).

  • max_wavelength (int) – The maximum wavelength of the periodic positional embeddings.

  • positions (side_input.SideInputEffect[named_axes.NamedArray]) – A side input that provides the position of each token in the sequence. It should be an integer array that is broadcastable with the input and that does NOT include the embedding axis.
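To give a feel for what max_wavelength controls, the sketch below computes the per-pair timescales under the geometric schedule from Su et al. The function name `rope_timescales` is hypothetical, and the sketch assumes the layer follows that schedule; check the layer's source before relying on the exact values.

```python
import math


def rope_timescales(embedding_dim, max_wavelength=10000):
    """Returns one timescale per rotated pair (hypothetical helper).

    Assumes the schedule from Su et al.: pair i rotates by
    position / max_wavelength ** (2 * i / embedding_dim) radians.
    """
    half = embedding_dim // 2
    return [max_wavelength ** (2 * i / embedding_dim) for i in range(half)]


timescales = rope_timescales(8)
# For embedding_dim=8 these are approximately 1, 10, 100 and 1000.

# Under this schedule, pair i completes a full turn every
# 2 * pi * timescale positions.
periods = [2 * math.pi * t for t in timescales]
```

Under this assumption the timescales grow geometrically from 1 toward max_wavelength, so a larger max_wavelength makes the slowest pairs rotate more slowly and lets them distinguish positions that are farther apart.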

Methods

__init__(embedding_axis, max_wavelength, ...)

from_config(positions_tag, embedding_axis[, ...])

Constructs an ApplyRoPE layer for a given axis and side input tag.

input_structure()

output_structure()

__call__(inputs)

Attributes

embedding_axis

max_wavelength

positions

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

classmethod from_config(positions_tag: Any, embedding_axis: str, max_wavelength: int = 10000) -> ApplyRoPE

Constructs an ApplyRoPE layer for a given axis and side input tag.

Parameters:
  • positions_tag – Side input tag for the position side input. All side inputs that should receive the same position information should share this tag. The same tag should then (usually) be passed to the pz.de.WithSideInputsFromInputTuple handler that actually provides this side input.

  • embedding_axis – Name of the axis that contains the embedding vector.

  • max_wavelength – The maximum wavelength of the periodic positional embeddings.

Returns:

A new ApplyRoPE layer with the given configuration.