InterceptedFlaxModuleMethod

class penzai.toolshed.unflaxify.InterceptedFlaxModuleMethod

Bases: Layer

A representation of an intercepted Flax module method call.

An InterceptedFlaxModuleMethod captures the logic that runs when you call a single Flax module method, and re-ifies its children and variables so that they are accessible in the PyTree structure of the model.
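Re-ifying variables involves separating the variables a module owns directly from those owned by its submodules, the distinction the `scope_data` field below draws. The following is a hypothetical plain-Python sketch, not Penzai's implementation: `split_own_variables` is an illustrative name, and it assumes Flax-style variables given as plain nested dicts of the form `{collection: {name: value}}`, where a nested dict under a name holds that submodule's variables.

```python
from typing import Any


def split_own_variables(variables: dict[str, dict[str, Any]]):
    """Hypothetical helper: separate a module's own variables from its submodules'.

    Assumes plain nested dicts {collection: {name: value}}, where a nested dict
    under `name` holds the variables of the submodule called `name`.
    """
    own: dict[str, dict[str, Any]] = {}
    children: dict[str, dict[str, Any]] = {}
    for collection, entries in variables.items():
        for name, value in entries.items():
            if isinstance(value, dict):
                # Nested dict: belongs to a submodule's scope; regroup by submodule.
                children.setdefault(name, {})[collection] = value
            else:
                # Leaf value: owned directly by this module.
                own.setdefault(collection, {})[name] = value
    return own, children


variables = {
    "params": {
        "scale": 1.5,  # a parameter of this module itself
        "Dense_0": {"kernel": [[1.0]], "bias": [0.0]},  # a submodule's parameters
    }
}
own, children = split_own_variables(variables)
```

When the first returned dict is empty, the module owns no variables directly and only defers to its submodules, which corresponds to the case where `scope_data` can be None.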

Variables:
  • module (flax.linen.Module) – The unbound Flax module.

  • method_name (str) – The name of the method being called.

  • scope_data (InterceptedFlaxScopeData | None) – Data associated with this Flax module’s scope, including parameters, variables, and random keys used directly by this module (not including its submodules). Can be None if this module does not have any parameters or variables of its own and instead merely defers to its submodules.

  • submodule_calls (dict[tuple[int, str], pz.nn.Layer]) – The collection of all submodule calls made by this module method, in call order. Each call is re-ified as a Penzai layer and can be patched to run arbitrary logic instead of the original Flax module method.
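Patching an entry of `submodule_calls` can be illustrated with a hypothetical plain-Python sketch; it is not Penzai's code. Here `parent_logic` stands in for a Flax method body, plain callables stand in for re-ified Penzai layers, and the sketch assumes the integer in each key is a running call index shared across all submodule calls.

```python
from typing import Any, Callable

Layer = Callable[[Any], Any]  # plain callables stand in for re-ified layers

# Hypothetical `submodule_calls`: (call index, submodule name) -> layer.
submodule_calls: dict[tuple[int, str], Layer] = {
    (0, "Dense_0"): lambda x: 2.0 * x,
    (1, "Dense_1"): lambda x: x + 1.0,
}


def parent_logic(x: Any, call_submodule: Callable[[str, Any], Any]) -> Any:
    # Stand-in for the Flax method body: it calls Dense_0, then Dense_1.
    h = call_submodule("Dense_0", x)
    return call_submodule("Dense_1", h)


def run_with(calls: dict[tuple[int, str], Layer], x: Any) -> Any:
    counter = 0

    def call_submodule(name: str, arg: Any) -> Any:
        nonlocal counter
        # Look up the layer for this call by its position and submodule name.
        layer = calls[(counter, name)]
        counter += 1
        return layer(arg)

    return parent_logic(x, call_submodule)


original = run_with(submodule_calls, 3.0)  # (2.0 * 3.0) + 1.0
# Patch the second call to run different logic; the method body is unchanged.
patched_calls = {**submodule_calls, (1, "Dense_1"): lambda h: h * 10.0}
patched = run_with(patched_calls, 3.0)  # (2.0 * 3.0) * 10.0
```

Because the method body stays fixed and only dictionary entries change, a patch swaps one submodule call's behavior without touching the surrounding logic. Python dicts preserve insertion order, so overwriting an existing key keeps the original call order.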

Methods

__init__(module, method_name, scope_data, ...)

treescope_color()

__call__(args_and_kwargs[, random_streams])

Calls the intercepted method with the given arguments and kwargs.

Attributes

module

method_name

scope_data

submodule_calls

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

__call__(args_and_kwargs: ArgsAndKwargs, random_streams: dict[str, pz.RandomStream] | None = None, **side_inputs) → Any

Calls the intercepted method with the given arguments and kwargs.

Parameters:
  • args_and_kwargs – The positional and keyword arguments passed to the Flax method.

  • random_streams – Random streams, keyed by RNG name. Should include an entry for each RNG used by the Flax module method. Note that RNG states will NOT necessarily be the same between Flax and Penzai, so random values drawn in Penzai may differ from those drawn in Flax.

  • **side_inputs – Other Penzai side inputs for the model, which will be forwarded to other Penzai layers. (Only used if Penzai layers that use side inputs have been inserted.)

Returns:

The output of the Flax method, whatever its type.
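Packing a call's arguments into one value can be sketched in plain Python. `PackedArgs` and `call_packed` are hypothetical names standing in for the real `ArgsAndKwargs` type and `__call__`; the streams are opaque placeholders rather than `pz.RandomStream` objects, and the up-front check for missing streams is an assumption about how one might enforce the "an entry for each RNG" requirement, not a description of what the real method does.

```python
import dataclasses
from typing import Any, Callable, Iterable, Mapping, Optional


@dataclasses.dataclass(frozen=True)
class PackedArgs:
    """Hypothetical stand-in for `ArgsAndKwargs`: one value holding a call's arguments."""

    args: tuple[Any, ...] = ()
    kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)

    @classmethod
    def capture(cls, *args: Any, **kwargs: Any) -> "PackedArgs":
        # Record positional and keyword arguments exactly as they were passed.
        return cls(args=args, kwargs=kwargs)


def call_packed(
    fn: Callable[..., Any],
    packed: PackedArgs,
    random_streams: Optional[Mapping[str, Any]] = None,
    required_rngs: Iterable[str] = (),
) -> Any:
    # Hypothetical check: every RNG name the method needs must have a stream.
    streams = random_streams or {}
    missing = [name for name in required_rngs if name not in streams]
    if missing:
        raise ValueError(f"missing random streams for: {missing}")
    # Unpack and forward the arguments to the wrapped function.
    return fn(*packed.args, **packed.kwargs)


packed = PackedArgs.capture(4.0, scale=0.5)
result = call_packed(lambda x, scale=1.0: x * scale, packed)  # 4.0 * 0.5
```

Bundling positional and keyword arguments into a single object lets the call fit a layer convention of one positional argument, which the `stateless_call(variable_values, argument, /, ...)` signature listed above also suggests.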