TransformerLM
- class penzai.models.transformer.model_parts.TransformerLM
Bases: Layer
Top-level transformer decoder wrapper.
This class is a simple wrapper that holds configuration data and runs safety checks.
- Variables:
body (pz.nn.Layer) – The implementation of the transformer.
metadata (TransformerMetadata) – The axis size and dtype info for the transformer.
Methods
__init__(body, metadata)
__call__(tokens, *[, token_positions]) – Scores log-probabilities for the given inputs.
Attributes
body
metadata
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __call__(tokens: pz.nx.NamedArray, *, token_positions: pz.nx.NamedArray | None = None, **side_inputs) → pz.nx.NamedArray
Scores log-probabilities for the given inputs.
- Parameters:
tokens – Array of token IDs, as an integer named array with a “seq” axis and possibly batch axes. Usually starts with the beginning-of-sequence token.
token_positions – Array of token positions, as an integer named array with a “seq” axis and possibly batch axes. Usually starts with 0. If not provided, positions are inferred to start at 0 and increment along the “seq” axis.
**side_inputs – Side inputs, which will be forwarded to the body.
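As an illustration of the default described for token_positions, the sketch below builds the positions that would be inferred when none are passed. It is not Penzai's implementation: plain nested Python lists stand in for named arrays (outer list as a batch axis, inner list as the “seq” axis), and the helper name `default_token_positions` and the token IDs are hypothetical.

```python
def default_token_positions(tokens):
    """Return positions 0, 1, ..., seq_len - 1 for every row of `tokens`.

    `tokens` is a list of rows (standing in for a batch axis); each row is a
    list of token IDs along the "seq" axis.
    """
    return [list(range(len(row))) for row in tokens]


# Two sequences of four token IDs each; the IDs are made up for illustration.
tokens = [[2, 17, 5, 9], [2, 8, 8, 3]]
positions = default_token_positions(tokens)
print(positions)
```

Passing token_positions explicitly is only needed when positions should not simply count up from 0 along the “seq” axis.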
- Returns:
The final matrix of logits from the embedding decoding layer, which (in the normal configuration) will have axes “seq” and “vocabulary”.
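The return value is described as logits over the “vocabulary” axis. If normalized log-probabilities are needed, one common approach is a log-softmax over that axis. The sketch below shows the arithmetic in plain Python under stated assumptions: nested lists stand in for a named array with axes “seq” and “vocabulary”, and the names `log_softmax`, `token_log_probs`, `logits_seq`, and `next_tokens` are hypothetical, not part of Penzai.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized list."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def token_log_probs(logits_seq, next_tokens):
    """Log-probability assigned to each chosen token.

    `logits_seq` has one vocabulary-sized list per "seq" position, and
    `next_tokens` holds one token ID per position.
    """
    return [log_softmax(row)[t] for row, t in zip(logits_seq, next_tokens)]


# Made-up logits for two positions over a three-token vocabulary.
logits_seq = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
next_tokens = [0, 2]
scores = token_log_probs(logits_seq, next_tokens)
print(scores)
```

With real model outputs the same computation would normally be done with array operations along the “vocabulary” axis rather than Python loops.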