ApplyRoPEToSubset
class penzai.nn.embeddings.ApplyRoPEToSubset
Bases: Layer
Adjusts a subset of embedding dimensions using rotary position embeddings (RoPE).
This is like ApplyRoPE, but applies rotary embeddings only to a prefix of the embedding axis, for compatibility with the GPT-NeoX configuration and similar models.
Variables:
embedding_axis (str) – The axis name of the input that contains the embedding vector (e.g. “embedding” or “projection”).
max_wavelength (float) – The maximum wavelength of the periodic positional embeddings.
rope_subset_size (int) – Size of the prefix of the embedding axis to which rotary embeddings are applied. The remaining suffix is left unchanged.
positions_input_name (str) – Key for the side input that provides the position of each token in the sequence. This side input should be an integer array that does NOT include the embedding axis and is broadcastable with the input.
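For a single token at position $p$, the transformation can be sketched as follows. This rests on two assumptions not stated above: dimension $i$ is paired with dimension $i + d/2$ (the "split halves" convention), and pair $i$ uses the timescale $\lambda^{2i/d}$. Treat both as assumptions about penzai's implementation. Here $d$ is rope_subset_size (assumed even), $\lambda$ is max_wavelength, and $D$ is the full size of the embedding axis.

$$
\begin{aligned}
\theta_i &= \frac{p}{\lambda^{2i/d}}, \qquad i = 0, \dots, \tfrac{d}{2} - 1,\\
y_i &= x_i \cos\theta_i - x_{i+d/2} \sin\theta_i,\\
y_{i+d/2} &= x_{i+d/2} \cos\theta_i + x_i \sin\theta_i,\\
y_j &= x_j, \qquad j = d, \dots, D - 1.
\end{aligned}
$$

Each pair in the prefix is rotated by an angle proportional to $p$. As a result, the dot product between two rotated prefixes depends only on the difference of their positions, while the suffix $x_d, \dots, x_{D-1}$ passes through unchanged.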
Methods
__init__(embedding_axis, max_wavelength, ...)
__call__(inputs, **side_inputs)
Attributes
embedding_axis
max_wavelength
rope_subset_size
positions_input_name
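The calling convention of __call__(inputs, **side_inputs) can be illustrated with a hypothetical NumPy stand-in. RoPESubsetSketch is not penzai's class. It uses a positional axis index where penzai uses a named embedding axis, and its default key "token_positions" is an assumption. It also assumes that dimension i rotates together with dimension i + d/2 using timescale max_wavelength**(2i/d). The positions are looked up in the keyword side inputs under positions_input_name, as described above.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class RoPESubsetSketch:
    """Hypothetical NumPy stand-in for ApplyRoPEToSubset (not penzai's class).

    Uses a positional axis index where penzai uses a named embedding axis.
    """

    embedding_axis: int
    max_wavelength: float
    rope_subset_size: int
    positions_input_name: str = "token_positions"  # assumed key name

    def __call__(self, inputs, **side_inputs):
        # Positions arrive as a keyword side input, looked up by name.
        positions = np.asarray(
            side_inputs[self.positions_input_name], dtype=np.float64
        )
        # Move the embedding axis last; positions must broadcast against the rest.
        x = np.moveaxis(np.asarray(inputs), self.embedding_axis, -1)
        d = self.rope_subset_size
        if d % 2 or d > x.shape[-1]:
            raise ValueError("rope_subset_size must be even and fit in the axis")
        half = d // 2
        first, second, rest = x[..., :half], x[..., half:d], x[..., d:]
        # Assumed pairing: dim i rotates with dim i + d/2, timescale max_wavelength**(2i/d).
        timescale = self.max_wavelength ** (2 * np.arange(half) / d)
        angle = positions[..., None] / timescale
        sin, cos = np.sin(angle), np.cos(angle)
        rotated = np.concatenate(
            [first * cos - second * sin, second * cos + first * sin], axis=-1
        )
        # The suffix beyond rope_subset_size passes through unchanged.
        out = np.concatenate([rotated, rest], axis=-1)
        return np.moveaxis(out, -1, self.embedding_axis)


# Example: 3 tokens with an 8-dim embedding; only the first 4 dims are rotated.
layer = RoPESubsetSketch(
    embedding_axis=-1, max_wavelength=10_000.0, rope_subset_size=4
)
x = np.random.default_rng(0).normal(size=(3, 8))  # (seq, embedding)
y = layer(x, token_positions=np.arange(3))
```

Moving the embedding axis to the end and back is only bookkeeping for positional arrays. With named axes, the embedding axis is identified by its name (embedding_axis) instead.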
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.