ApplyRoPE

class penzai.nn.embeddings.ApplyRoPE

Bases: Layer

Adjusts input embeddings using rotary position embeddings (RoPE).

Rotary position embeddings, proposed by Su et al. (2021), incorporate relative position information into the queries and keys of attention layers by applying periodic rotations to the elements of the query and key projections.

The ApplyRoPE layer can be inserted into an attention computation immediately before the query-key dot products are computed, so that the rotated queries and keys carry position information into those dot products.

See https://arxiv.org/abs/2104.09864.

Variables:
  • embedding_axis (str) – The axis name of the input that contains the embedding vector (e.g. “embedding” or “projection”).

  • max_wavelength (float) – The maximum wavelength of the periodic positional embeddings.

  • positions_input_name (str) – Key for the side input that provides the position of each token in the sequence. This side input should be an integer array that is broadcastable with the input and does NOT include the embedding axis.
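
To make the rotation concrete, here is a minimal plain-JAX sketch of the kind of computation RoPE performs. It is not the layer's actual implementation: the function name rope_sketch, the positional (unnamed) axes, the geometric timescale schedule, and the choice to pair the first half of the embedding with the second half are all assumptions made for illustration. The max_wavelength and positions arguments play the roles described above, and positions deliberately has no embedding axis.

```python
import jax
import jax.numpy as jnp


def rope_sketch(x, positions, max_wavelength=10_000.0):
    """Rotates the last axis of `x` by position-dependent angles (illustrative only).

    x: float array of shape [..., dim], with dim even.
    positions: integer array broadcastable to x.shape[:-1] (no embedding axis).
    """
    dim = x.shape[-1]
    # Assumed schedule: one timescale per rotated pair, growing geometrically
    # from 1 toward max_wavelength.
    timescale = max_wavelength ** (2.0 * jnp.arange(dim // 2) / dim)
    # Add a trailing axis so positions broadcast against the dim // 2 timescales.
    angles = positions[..., None] / timescale
    sin, cos = jnp.sin(angles), jnp.cos(angles)
    # Assumed pairing: element i of the first half rotates together with
    # element i of the second half.
    first, second = jnp.split(x, 2, axis=-1)
    return jnp.concatenate(
        [first * cos - second * sin, second * cos + first * sin], axis=-1
    )


# Hypothetical shapes: queries and keys of shape [seq, heads, dim], and
# positions of shape [seq, 1] so that they broadcast over the heads axis.
q = jax.random.normal(jax.random.PRNGKey(0), (6, 2, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (6, 2, 8))
positions = jnp.arange(6)[:, None]
q_rot = rope_sketch(q, positions)
k_rot = rope_sketch(k, positions)
```

Because each pair of elements is rotated by an angle proportional to the token position, the dot product between a rotated query at position m and a rotated key at position n depends only on the offset m - n. The rotation also leaves the norm of each vector unchanged.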

Methods

__init__(embedding_axis, max_wavelength, ...)

__call__(inputs, **side_inputs)

Attributes

embedding_axis

max_wavelength

positions_input_name

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.