ApplyRoPE
- class penzai.deprecated.v1.nn.embeddings.ApplyRoPE
Bases: Layer
Adjusts input embeddings using rotary position embeddings (RoPE).
Rotary position embeddings, proposed by Su et al. (2021), incorporate relative position information into the queries and keys of attention layers by applying periodic rotations to the elements of the query and key projections.
The ApplyRoPE layer can be inserted into an attention computation immediately before the query-key dot products are computed, so that rotational position information is added to the queries and keys. See https://arxiv.org/abs/2104.09864.
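The rotation can be illustrated with a short NumPy sketch. This is not penzai's implementation. The helper name `apply_rope`, the use of a plain array whose last axis is the embedding axis, and the pairing of feature i with feature i + d/2 (a common implementation variant of the adjacent pairing used in the paper) are all assumptions made for illustration. The example then checks the property that motivates RoPE: rotating a query and a key at their own positions gives a dot product that depends only on the distance between them.

```python
import numpy as np


def apply_rope(x: np.ndarray, positions: np.ndarray, max_wavelength: int = 10_000) -> np.ndarray:
    """Hypothetical helper: rotates feature pairs by position-dependent angles.

    Assumptions (illustrative, not penzai's code): the embedding axis is the
    last axis, it has even size d, and feature i is paired with feature
    i + d // 2. `positions` must broadcast against x.shape[:-1], mirroring the
    requirement that the positions side input omit the embedding axis.
    """
    d = x.shape[-1]
    half = d // 2
    # Timescales grow geometrically from 1 toward max_wavelength.
    timescale = max_wavelength ** (2 * np.arange(half) / d)
    # One angle per (position, feature pair); appends a trailing pair axis.
    angle = positions[..., None] / timescale
    sin, cos = np.sin(angle), np.cos(angle)
    first, second = x[..., :half], x[..., half:]
    # Standard 2-D rotation applied to each (first[i], second[i]) pair.
    return np.concatenate(
        [first * cos - second * sin, second * cos + first * sin], axis=-1
    )


rng = np.random.default_rng(0)
q = rng.normal(size=8)
k = rng.normal(size=8)


def score(q_pos, k_pos):
    # Rotate the query and key at their own positions, then take the dot product.
    return apply_rope(q, np.array(q_pos)) @ apply_rope(k, np.array(k_pos))


# The same offset (3) at different absolute positions gives the same score.
print(np.isclose(score(5, 2), score(105, 102)))  # True
```

Applying the rotation just before the dot product, rather than adding a position vector to the embeddings, is what makes the score unchanged when both positions are shifted by the same amount.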
- Variables:
embedding_axis (str) – The axis name of the input that contains the embedding vector (e.g. “embedding” or “projection”).
max_wavelength (int) – The maximum wavelength of the periodic positional embeddings.
positions (side_input.SideInputEffect[named_axes.NamedArray]) – A side input that provides the position of each token in the sequence. This side input should be provided as an integer array that is broadcastable with the input, and which does NOT include the embedding axis.
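For reference, the standard formulation from Su et al. (2021) can be written as follows. This assumes that max_wavelength plays the role of the base (10000 in the paper), that the embedding size d is even, and that features are rotated in adjacent pairs as in the paper; implementations may pair features differently. A token at integer position p has its i-th feature pair rotated by the angle pθ_i, where R(φ) denotes the 2×2 rotation matrix by angle φ and q, k are corresponding feature pairs of a query at position p and a key at position p':

```latex
\theta_i = \mathrm{max\_wavelength}^{-2i/d}, \qquad i = 0, 1, \dots, \tfrac{d}{2} - 1

\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
  = \begin{pmatrix} \cos(p\,\theta_i) & -\sin(p\,\theta_i) \\
                    \sin(p\,\theta_i) & \cos(p\,\theta_i) \end{pmatrix}
    \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}

\langle R(p\,\theta_i)\, q,\; R(p'\,\theta_i)\, k \rangle
  = q^\top R(p\,\theta_i)^\top R(p'\,\theta_i)\, k
  = \langle q,\; R\big((p' - p)\,\theta_i\big)\, k \rangle
```

Because R(a)ᵀR(b) = R(b − a), each pair's contribution to the query-key score depends only on the offset p' − p. Under these assumptions the slowest-rotating pair has a wavelength of 2π · max_wavelength^((d−2)/d) positions.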
Methods
__init__(embedding_axis, max_wavelength, ...)
from_config(positions_tag, embedding_axis[, ...]) – Constructs an ApplyRoPE layer for a given axis and side input tag.
input_structure()
output_structure()
__call__(inputs)
Attributes
embedding_axis
max_wavelength
positions
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- classmethod from_config(positions_tag: Any, embedding_axis: str, max_wavelength: int = 10000) -> ApplyRoPE
Constructs an ApplyRoPE layer for a given axis and side input tag.
- Parameters:
positions_tag – Side input tag for the position side input. This tag identifies the side inputs that should receive the same position information. The same tag should then (usually) be passed to the pz.de.WithSideInputsFromInputTuple handler that actually provides this side input.
embedding_axis – Name of the axis that contains the embedding vector.
max_wavelength – The maximum wavelength of the periodic positional embeddings.
- Returns:
A new ApplyRoPE layer with the given configuration.