LinearizeAndAdjust
class penzai.deprecated.v1.toolshed.model_rewiring.LinearizeAndAdjust
Bases: Layer
Linearizes a target layer around one adjusted input and evaluates the linearization at a second adjusted input.
This layer splits its input into two paths, and allows each path to be adjusted independently. Then, these two inputs are used to construct and evaluate a first-order approximation of the target layer: the first adjusted input is used as the linearization point, and the second adjusted input is used as the point of evaluation.
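The computation described above can be sketched in plain JAX. This is a minimal sketch, not penzai's implementation: `linearize_and_adjust` is a hypothetical helper, and it assumes `target` maps a single array to a single array. It relies on `jax.jvp`, which returns both `target(x0)` and the Jacobian-vector product `J(x0) @ (x1 - x0)`; their sum is the first-order approximation `target(x0) + J(x0) @ (x1 - x0)` evaluated at `x1`.

```python
import jax
import jax.numpy as jnp


def linearize_and_adjust(target, linearize_around, evaluate_at, inputs):
    """Hypothetical helper (not part of penzai): first-order approximation of
    `target` around linearize_around(inputs), evaluated at evaluate_at(inputs)."""
    x0 = linearize_around(inputs)  # linearization point
    x1 = evaluate_at(inputs)       # evaluation point
    # jax.jvp returns target(x0) and the Jacobian-vector product J(x0) @ (x1 - x0).
    y0, dy = jax.jvp(target, (x0,), (x1 - x0,))
    return y0 + dy


x = jnp.array([0.5, -1.0, 2.0])

# Linearize tanh around zero and evaluate at x. Since tanh(0) = 0 and
# tanh'(0) = 1, the approximation reduces to the identity, so approx == x.
approx = linearize_and_adjust(
    jnp.tanh,
    linearize_around=jnp.zeros_like,
    evaluate_at=lambda v: v,
    inputs=x,
)
exact = jnp.tanh(x)  # differs from approx as |x| grows
```

Because `dy` is linear in `x1 - x0`, the result agrees with `target(x1)` only to first order: here it tracks `tanh` near zero and drifts away from it for larger inputs.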
If linearize_around and evaluate_at are the same (so that both paths produce the same adjusted input), this layer behaves the same as an ordinary sequence of operations: a first-order approximation evaluated at its own linearization point returns exactly the value of the target function at that point.
Methods
__init__(linearize_around, evaluate_at, target)
__call__(inputs)
Attributes
linearize_around
evaluate_at
target
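To show how the three constructor arguments relate to `__call__(inputs)`, here is a hypothetical plain-Python stand-in. It assumes each field is an ordinary callable and that inputs may be a pytree, so differences and sums are taken leafwise with `jax.tree_util.tree_map`; the actual class is a penzai `Layer` and may differ in detail. The usage compares an instance with different adjustments against one with identical adjustments.

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class LinearizeAndAdjustSketch:
    """Hypothetical stand-in with the same three fields; not the penzai class."""

    linearize_around: Callable[[Any], Any]  # adjustment giving the linearization point
    evaluate_at: Callable[[Any], Any]       # adjustment giving the evaluation point
    target: Callable[[Any], Any]            # function approximated to first order

    def __call__(self, inputs):
        x0 = self.linearize_around(inputs)
        x1 = self.evaluate_at(inputs)
        # Leafwise tangent direction, so pytree inputs (e.g. dicts) also work.
        delta = jax.tree_util.tree_map(lambda a, b: a - b, x1, x0)
        # y0 = target(x0); dy = Jacobian of target at x0 applied to delta.
        y0, dy = jax.jvp(self.target, (x0,), (delta,))
        return jax.tree_util.tree_map(lambda p, t: p + t, y0, dy)


def squashed_product(d):
    # A nonlinear target over a dict input.
    return {"out": jnp.tanh(d["a"] * d["b"])}


inputs = {"a": jnp.array(2.0), "b": jnp.array(3.0)}

# Different adjustments: linearize with b zeroed out, evaluate at the raw input.
# First-order estimate: tanh(0) + 1 * (2 * 3) = 6, far from tanh(6).
ablated = LinearizeAndAdjustSketch(
    linearize_around=lambda d: {"a": d["a"], "b": jnp.zeros_like(d["b"])},
    evaluate_at=lambda d: d,
    target=squashed_product,
)

# Identical adjustments: the result equals the target applied directly.
same = LinearizeAndAdjustSketch(
    linearize_around=lambda d: d,
    evaluate_at=lambda d: d,
    target=squashed_product,
)

ablated_out = ablated(inputs)["out"]
same_out = same(inputs)["out"]
```

With identical adjustments the tangent is zero, so `same_out` equals `jnp.tanh(6.0)`; with different adjustments the output is the linear estimate rather than the true value.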
Inherited Methods
attributes_dict(): Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values): Directly instantiates a struct given all of its fields.
input_structure(): Returns the input structure of this layer.
key_for_field(field_name): Generates a JAX PyTree key for a given field name.
output_structure(): Returns the output structure of this layer.
select(): Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten(): Flattens this tree node.
tree_flatten_with_keys(): Flattens this tree node with keys.
tree_unflatten(aux_data, children): Unflattens this tree node.
treescope_color(): Computes a CSS color to display for this object in treescope.