LinearizeAndAdjust

class penzai.toolshed.model_rewiring.LinearizeAndAdjust

Bases: Layer

Linearizes and evaluates a model around two adjusted inputs.

This layer splits its input into two paths and allows each path to be adjusted independently. The two adjusted inputs are then used to construct and evaluate a first-order approximation of the target layer: the first adjusted input (linearize_around) is the linearization point, and the second adjusted input (evaluate_at) is the point at which the approximation is evaluated.

If linearize_around and evaluate_at are the same, this layer behaves the same as an ordinary sequence of operations, because a first-order approximation evaluated at its own linearization point equals the value of the target function at that point.
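
The first-order approximation described above can be illustrated in plain JAX. This is a minimal sketch of the underlying math only, not penzai's implementation: the names `target` and `linearize_and_evaluate` are hypothetical stand-ins, and `jax.jvp` is assumed here as one convenient way to compute the Jacobian-vector product.

```python
import jax
import jax.numpy as jnp


def target(x):
    # Hypothetical stand-in for the wrapped target layer:
    # any differentiable function of its input.
    return jnp.tanh(x) * x


def linearize_and_evaluate(fn, linearize_around, evaluate_at):
    # First-order Taylor approximation of fn around `linearize_around`,
    # evaluated at `evaluate_at`:
    #   fn(x0) + J(x0) @ (x1 - x0)
    primal_out, tangent_out = jax.jvp(
        fn, (linearize_around,), (evaluate_at - linearize_around,)
    )
    return primal_out + tangent_out


x0 = jnp.array([0.5, -1.0, 2.0])
x1 = x0 + 0.01

# Evaluating at the linearization point reproduces the target exactly.
same_point = linearize_and_evaluate(target, x0, x0)

# Evaluating nearby gives a close, but not exact, approximation.
nearby = linearize_and_evaluate(target, x0, x1)
```

When the two points coincide, the tangent term vanishes and the result is just `target(x0)`, which matches the behavior described for identical linearize_around and evaluate_at.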

Methods

__init__(linearize_around, evaluate_at, target)

__call__(inputs, **side_inputs)

Attributes

linearize_around

evaluate_at

target

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.