IsolatedSubmodel

class penzai.toolshed.isolate_submodel.IsolatedSubmodel

Bases: Struct

An isolated part of a model, with saved inputs, outputs, and variables.

Variable values are also frozen at the state they had when the model was called, which allows the submodel to be re-executed deterministically. To replay the submodel, you can run:

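# `isolated` is an IsolatedSubmodel; replay the submodel from its saved state.
result, final_var_values = isolated.submodel.stateless_call(
    isolated.initial_var_values,
    isolated.saved_arg,
    **isolated.saved_side_inputs,
)
# `result` should match isolated.saved_output, and `final_var_values`
# should match isolated.final_var_values.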
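As a rough sketch, the replay above can be wrapped in a helper that also checks the replayed output against `saved_output`. The name `replay_matches_saved` and its tolerance defaults are illustrative and not part of penzai. The sketch assumes that `stateless_call` returns `(result, final_var_values)` as in the snippet above and that the output's leaves are numeric arrays; it uses `jax.tree_util` to compare the two outputs leaf by leaf with `np.allclose`.

```python
import jax
import numpy as np


def replay_matches_saved(isolated, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: re-runs an isolated submodel and checks its output.

    Returns (matches, final_var_values), where `matches` is True if every
    leaf of the replayed output is close to the corresponding leaf of
    `isolated.saved_output`.
    """
    result, final_var_values = isolated.submodel.stateless_call(
        isolated.initial_var_values,
        isolated.saved_arg,
        **isolated.saved_side_inputs,
    )
    # tree_map raises a ValueError if the two PyTree structures differ.
    leaf_matches = jax.tree_util.tree_map(
        lambda a, b: bool(np.allclose(a, b, rtol=rtol, atol=atol)),
        result,
        isolated.saved_output,
    )
    return jax.tree_util.tree_all(leaf_matches), final_var_values
```

The helper only reads the documented fields, so it works with anything that exposes them; returning the final variable values lets the caller compare them against `isolated.final_var_values` as well.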
Variables:
  • submodel (pz.nn.Layer) – An individual layer from the larger model. This will match the layer that was originally selected, with a few modifications: parameters will be frozen, and any variables will be unbound and replaced with variable slots.

  • saved_arg (Any) – Positional argument that was passed to the submodel when we isolated it. If the original argument contained variables (although this is rare), this will contain variable slots.

  • saved_side_inputs (dict[str, Any]) – Side inputs (keyword arguments) that were passed to the submodel when we isolated it. If the original side inputs contained variables (e.g. for random number generators), these will contain variable slots.

  • saved_output (Any) – Saved output that the submodel should produce when called with saved_arg.

  • initial_var_values (tuple[pz.StateVariableValue, ...] | None) – Saved variable values at the point when the submodel was called, not including parameters (which are assumed immutable).

  • final_var_values (tuple[pz.StateVariableValue, ...] | None) – Saved variable values at the point after the submodel was called, not including parameters (which are assumed immutable).
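To make the relationship between `initial_var_values`, `saved_output`, and `final_var_values` concrete, here is a toy, pure-Python stand-in (not penzai code). `ToyCounterLayer` is a hypothetical layer whose `stateless_call`-style method takes its variable values explicitly and returns updated ones. Because the state is passed in rather than read from a mutable variable, calling it again from the recorded initial values reproduces both the saved output and the saved final values.

```python
class ToyCounterLayer:
    """Hypothetical layer that adds a call counter to its input.

    The counter is not stored on the layer; it is passed in and returned
    explicitly, mirroring how a stateless call handles variable values.
    """

    def stateless_call(self, var_values, argument):
        (count,) = var_values
        # Output depends on the current state; the state advances by one.
        return argument + count, (count + 1,)


# "Isolate" one call: record the state before the call, the input,
# the output, and the state after the call.
layer = ToyCounterLayer()
initial_var_values = (3,)
saved_arg = 10
saved_output, final_var_values = layer.stateless_call(initial_var_values, saved_arg)

# Replaying from the frozen initial state gives the same output and final state.
replayed_output, replayed_final = layer.stateless_call(initial_var_values, saved_arg)
```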

Methods

__init__(submodel, saved_arg, ...)

Attributes

submodel

saved_arg

saved_side_inputs

saved_output

initial_var_values

final_var_values

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.
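The inherited `attributes_dict()` and `from_attributes()` methods can be combined to build a modified copy of an isolated submodel, for example to try the same layer on a different input. This sketch assumes that `attributes_dict()` returns the fields by name and that `from_attributes` is a classmethod, as its description ("directly instantiates a struct") suggests. `with_saved_arg` is a hypothetical helper name. Because the saved output and final variable values describe the original input, this sketch resets them to `None` rather than leaving them stale.

```python
def with_saved_arg(isolated, new_arg):
    """Hypothetical helper: returns a copy of `isolated` with a new saved_arg.

    The original object is not modified. Fields that depend on the old
    input (saved_output, final_var_values) are reset to None.
    """
    # Copy the field dict so the original struct's fields are untouched.
    fields = dict(isolated.attributes_dict())
    fields["saved_arg"] = new_arg
    fields["saved_output"] = None
    fields["final_var_values"] = None
    return type(isolated).from_attributes(**fields)
```

The copy can then be replayed with `stateless_call` as shown in the class description. `select()` offers a selection-based alternative for the same kind of functional update.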