LinearizeAndAdjust

class penzai.toolshed.model_rewiring.LinearizeAndAdjust

Bases: Layer

Linearizes a model around one adjusted input and evaluates the linearization at another.

This layer splits its input into two paths, linearize_around and evaluate_at, each of which can adjust the input independently. The two adjusted inputs are then used to construct and evaluate a first-order approximation of the target layer: the output of linearize_around is the linearization point, and the output of evaluate_at is the point at which the approximation is evaluated.
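Written out as an equation, this is the first-order Taylor expansion of the target around the linearization point. The symbols below are introduced only for illustration: f is the target layer, x is the layer input, x_0 and x_1 are the two adjusted inputs, and J_f(x_0) is the Jacobian of f at x_0. Array-valued inputs are assumed for simplicity; for PyTree inputs the subtraction applies leaf by leaf.

```latex
x_0 = \mathrm{linearize\_around}(x), \qquad x_1 = \mathrm{evaluate\_at}(x)

\hat{f}(x_1) = f(x_0) + J_f(x_0)\,(x_1 - x_0)
```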

If linearize_around and evaluate_at produce the same value, this layer behaves the same as applying linearize_around followed by target, since a first-order approximation evaluated at its own linearization point returns exactly the value of the target function there.
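The following is a minimal sketch of this computation in plain JAX, not Penzai's implementation. The helper name linearize_and_adjust and its function-valued arguments are hypothetical stand-ins for the layer's linearize_around, evaluate_at, and target fields. The sketch uses jax.jvp, which returns both f(x_0) and the Jacobian-vector product J_f(x_0)(x_1 - x_0) in a single pass, and it also checks the equal-inputs case described above.

```python
import jax
import jax.numpy as jnp


def linearize_and_adjust(linearize_around, evaluate_at, target, inputs):
    """Hypothetical helper: first-order approximation of `target`.

    Linearizes `target` around linearize_around(inputs) and evaluates
    that linearization at evaluate_at(inputs).
    """
    x0 = linearize_around(inputs)  # linearization point
    x1 = evaluate_at(inputs)  # point of evaluation
    # Displacement from the linearization point, leaf by leaf for PyTrees.
    delta = jax.tree_util.tree_map(lambda a, b: a - b, x1, x0)
    # jax.jvp returns target(x0) and J_target(x0) @ delta together.
    y0, dy = jax.jvp(target, (x0,), (delta,))
    # f(x0) + J_f(x0) (x1 - x0)
    return jax.tree_util.tree_map(lambda a, b: a + b, y0, dy)


def double(v):
    # Hypothetical adjustment used for both paths in the second example.
    return 2.0 * v


x = jnp.array([0.5, 1.0, 2.0])

# Linearizing sin around zero gives sin(0) + cos(0) * x, which equals x.
approx = linearize_and_adjust(jnp.zeros_like, lambda v: v, jnp.sin, x)

# With identical adjustments the correction term is zero, so the result
# matches applying the adjustment and then the target directly.
same = linearize_and_adjust(double, double, jnp.tanh, x)
```

Using jax.jvp instead of materializing the full Jacobian keeps the cost within a small constant factor of a single forward pass, which matters when the target is a large model.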

Methods

__init__(linearize_around, evaluate_at, target)

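__call__(inputs)

Evaluates the first-order approximation of target, linearized around linearize_around(inputs), at the point evaluate_at(inputs).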

Attributes

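linearize_around

Adjusts the input to produce the linearization point.

evaluate_at

Adjusts the input to produce the point at which the approximation is evaluated.

target

The layer whose first-order approximation is computed.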

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.