InterceptedFlaxModuleMethod

class penzai.toolshed.unflaxify.InterceptedFlaxModuleMethod

Bases: Layer

A representation of an intercepted Flax module method call.

An InterceptedFlaxModuleMethod captures the logic that runs when you call a single Flax module method, and reifies its children and variables so that they are accessible in the PyTree structure of the model.

Variables:
  • module (flax.linen.Module) – The unbound Flax module.

  • method_name (str) – The name of the method being called.

  • scope_data (InterceptedFlaxScopeData | None) – Data associated with this Flax module’s scope, including parameters, variables, and random keys used directly by this module (not including its submodules). Can be None if this module does not have any parameters or variables of its own and instead merely defers to its submodules.

  • submodule_calls (dict[tuple[int, str], pz.LayerLike]) – The collection of all submodule calls made by this module method, in call order. Each call is reified as a Penzai layer and can be patched to run arbitrary logic instead of the original Flax module method.
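
The patching idea behind submodule_calls can be illustrated with a minimal, standard-library-only sketch. It does not use Penzai or Flax: ToyInterceptedMethod, body, and call_sub are hypothetical stand-ins, and reading the integer in each key as a call-order index is an assumption made for illustration, not a statement about the library's internals.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable

# Assumed key layout for this sketch: (call-order index, submodule name).
Key = tuple[int, str]


@dataclass(frozen=True)
class ToyInterceptedMethod:
    """Hypothetical stand-in for an intercepted method; not the Penzai class."""

    body: Callable[..., Any]
    submodule_calls: dict[Key, Callable[[Any], Any]]

    def __call__(self, x: Any) -> Any:
        counter = 0

        def call_sub(name: str, value: Any) -> Any:
            # Route each submodule call through the dictionary, so a patched
            # entry runs in place of the original logic.
            nonlocal counter
            key = (counter, name)
            counter += 1
            return self.submodule_calls[key](value)

        return self.body(call_sub, x)


def body(call_sub: Callable[[str, Any], Any], x: Any) -> Any:
    # Hypothetical method body that makes two submodule calls, in order.
    h = call_sub("scale", x)
    return call_sub("shift", h)


original = ToyInterceptedMethod(
    body=body,
    submodule_calls={
        (0, "scale"): lambda v: v * 2,
        (1, "shift"): lambda v: v + 1,
    },
)

# Patch the second call functionally: build a new object, leave the old one alone.
patched = replace(
    original,
    submodule_calls={**original.submodule_calls, (1, "shift"): lambda v: v - 10},
)

original_out = original(3)
patched_out = patched(3)
```

Here the patched copy runs different logic for the second submodule call while the original object is unchanged, which mirrors the functional-update style that the inherited select() method is described as enabling.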

Methods

__init__(module, method_name, scope_data, ...)

treescope_color()

__call__(args_and_kwargs)

Calls the intercepted method with the given arguments and kwargs.

Attributes

module

method_name

scope_data

submodule_calls

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

__call__(args_and_kwargs: ArgsAndKwargs) → Any

Calls the intercepted method with the given arguments and kwargs.

Parameters:

args_and_kwargs – The positional and keyword arguments passed to the Flax method.

Returns:

Whatever the output of the Flax method is.
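
Because __call__ takes a single args_and_kwargs value, the positional and keyword arguments of the Flax method travel together as one object. The following is a rough sketch of that calling convention using only the standard library; ToyArgsAndKwargs, flax_style_method, and call_bundled are hypothetical names introduced for illustration and are not the library's own definitions.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass(frozen=True)
class ToyArgsAndKwargs:
    """Hypothetical bundle of positional and keyword arguments."""

    args: tuple[Any, ...] = ()
    kwargs: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def make(cls, *args: Any, **kwargs: Any) -> "ToyArgsAndKwargs":
        # Collect a normal Python call signature into one value.
        return cls(args=args, kwargs=kwargs)


def flax_style_method(x: float, scale: float = 1.0, bias: float = 0.0) -> float:
    # Stand-in for a module method that takes several arguments.
    return x * scale + bias


def call_bundled(method: Callable[..., Any], bundle: ToyArgsAndKwargs) -> Any:
    # Unpack the single bundled input back into a regular call.
    return method(*bundle.args, **bundle.kwargs)


bundle = ToyArgsAndKwargs.make(3.0, scale=2.0, bias=1.0)
result = call_bundled(flax_style_method, bundle)
```

Bundling the arguments this way lets a multi-argument method fit a one-input, one-output layer interface.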