unflaxify_apply


penzai.toolshed.unflaxify.unflaxify_apply(module: flax.linen.Module, variables: flax.typing.VariableDict, *dummy_args, rngs: flax.typing.PRNGKey | flax.typing.RNGSequences | None = None, method: Callable[..., Any] | str | None = None, mutable: flax.core.scope.CollectionFilter = False, **dummy_kwargs) → InterceptedFlaxModuleMethod

Creates an InterceptedFlaxModuleMethod by tracing the application of a Flax module to the given variables and dummy arguments.

Note that this function is intended for interactive exploration and to help migrate Flax code to Penzai. It is not intended to be used in production code. Not all Flax features are supported yet; in particular, transformed layers are not supported and have not been tested.

Parameters:
  • module – The Flax module to apply.

  • variables – A dictionary containing variables keyed by variable collections, with the same interpretation as for flax.linen.Module.apply.

  • *dummy_args – Positional arguments passed to the specified apply method. These can be arbitrary values; their purpose is to enable tracing through the Flax logic.

  • rngs – A dict of PRNGKeys to initialize the PRNG sequences, with the same interpretation as for flax.linen.Module.apply.

  • method – The method to apply, usually a method of the module. If not provided, the module's __call__ method is applied. A string can also be provided to select a method by name.

  • mutable – Specifies which collections should be treated as mutable. Can be a bool, str, or list: a bool makes all (True) or no (False) collections mutable; a str gives the name of a single mutable collection; a list gives the names of several mutable collections.

  • **dummy_kwargs – Keyword arguments passed to the specified apply method. These can be arbitrary values; their purpose is to enable tracing through the Flax logic.

Returns:

An intercepted version of the Flax module call, which can be manipulated using Penzai tools.