ApplyRoPEToSubset


class penzai.nn.embeddings.ApplyRoPEToSubset

Bases: Layer

Applies rotary position embeddings (RoPE) to a subset of the dimensions of each embedding vector.

This behaves like ApplyRoPE, but rotates only a leading subset of the embedding axis and leaves the remaining dimensions unchanged. It exists for compatibility with the GPT-NeoX configuration and similar models.

Variables:
  • embedding_axis (str) – The axis name of the input that contains the embedding vector (e.g. “embedding” or “projection”).

  • max_wavelength (int) – The maximum wavelength of the periodic positional embeddings.

  • rope_subset_size (int) – Number of leading entries (a prefix) along the embedding axis to which rotary embeddings are applied. The remaining suffix is left unchanged.

  • positions_input_name (str) – Key for the side input that provides the position of each token in the sequence. This side input should be provided as an integer array that is broadcastable with the input, and which does NOT include the embedding axis.
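To make the subset behavior concrete, here is a minimal NumPy sketch of the computation this layer describes. It is not penzai's implementation, and it rests on several assumptions. It uses plain positional arrays with the embedding axis last, whereas the real layer works on named axes and reads positions from the side input named by positions_input_name. It pairs dimension i with dimension i + half (the split-halves convention used by GPT-NeoX-style RoPE). It uses the common schedule timescale_i = max_wavelength ** (2i / rope_subset_size). The function names apply_rope and apply_rope_to_subset are hypothetical.

```python
import numpy as np


def apply_rope(x, positions, max_wavelength):
    """Hypothetical helper: rotate x[..., d] (d even) by position-dependent angles."""
    positions = np.asarray(positions)
    d = x.shape[-1]
    assert d % 2 == 0, "RoPE needs an even number of dimensions"
    half = d // 2
    # Assumed frequency schedule: timescale_i = max_wavelength ** (2i / d).
    timescale = max_wavelength ** (2.0 * np.arange(half) / d)
    # positions has no embedding axis; add one so it broadcasts against x.
    radians = positions[..., None] / timescale  # shape [..., half]
    sin, cos = np.sin(radians), np.cos(radians)
    # Assumed pairing: dimension i is rotated together with dimension i + half.
    first, second = x[..., :half], x[..., half:]
    return np.concatenate(
        [first * cos - second * sin, second * cos + first * sin], axis=-1
    )


def apply_rope_to_subset(x, positions, rope_subset_size, max_wavelength):
    """Hypothetical sketch: apply RoPE to the prefix, pass the suffix through."""
    assert 0 < rope_subset_size <= x.shape[-1]
    rotated = apply_rope(x[..., :rope_subset_size], positions, max_wavelength)
    return np.concatenate([rotated, x[..., rope_subset_size:]], axis=-1)


# Usage: 5 tokens with 8-dimensional embeddings; rotate only the first 4 entries.
x = np.random.default_rng(0).standard_normal((5, 8))
y = apply_rope_to_subset(x, np.arange(5), rope_subset_size=4, max_wavelength=10_000)
assert np.allclose(y[:, 4:], x[:, 4:])  # the suffix is untouched
```

Because each pair of prefix entries is rotated rather than shifted, vector norms are preserved, and the dot product between a rotated query and key depends only on their relative position. The suffix carries no positional signal at all, which is the partial-rotary arrangement the GPT-NeoX configuration calls for.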

Methods

__init__(embedding_axis, max_wavelength, ...)

__call__(inputs, **side_inputs)

Attributes

embedding_axis

max_wavelength

rope_subset_size

positions_input_name

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.