InterceptedFlaxScopeData

class penzai.toolshed.unflaxify.InterceptedFlaxScopeData

Bases: Struct

A frozen representation of data in a particular Flax scope.

Flax implements its modules using a “functional core”, which is a stateful manager of the variables, parameters, and random keys for a module and all of its submodules. This class represents a “Penzai view” of the data held in the scope of one particular module, excluding its submodules.
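To make “excluding its submodules” concrete: in the nested variables dictionary that Flax produces, each collection (such as `params`) maps a module's own variable names to values and each submodule's name to a nested dictionary. The sketch below separates those two kinds of entries. `split_direct_entries` and the sample data are hypothetical, and its rule (nested dicts belong to submodules, everything else to the module itself) assumes that no variable is itself a dict; it illustrates the distinction and is not how `unflaxify` reads a scope.

```python
from typing import Any

Collections = dict[str, dict[str, Any]]


def split_direct_entries(collections: Collections) -> tuple[Collections, Collections]:
  """Splits {collection: {name: value}} into direct and submodule entries.

  Hypothetical heuristic: a nested dict is treated as a submodule's scope, and
  any other value is treated as a variable owned by this module.
  """
  direct: Collections = {}
  submodules: Collections = {}
  for collection, entries in collections.items():
    for name, value in entries.items():
      target = submodules if isinstance(value, dict) else direct
      target.setdefault(collection, {})[name] = value
  return direct, submodules


# A module with its own "scale" parameter and a "Dense_0" submodule.
# Plain lists stand in for arrays.
flax_style_variables = {
    "params": {
        "scale": [1.0],
        "Dense_0": {"kernel": [[0.5]], "bias": [0.0]},
    },
    "batch_stats": {"Dense_0": {"mean": [0.0]}},
}

direct, submodules = split_direct_entries(flax_style_variables)
print(direct)            # {'params': {'scale': [1.0]}}
print(list(submodules))  # ['params', 'batch_stats']
```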

Variables:
  • parameters (dict[str, pz.nn.ParameterLike[Any]]) – The collection of named parameters used directly by this module (not by a submodule), represented as Penzai parameters. If the module method was called multiple times, the parameters may be shared parameter references.

  • variables (dict[str, dict[str, pz.de.LocalStateEffect]]) – The collection of other (non-parameter) variables used directly by this module (not by a submodule), represented as Penzai state effects.

  • immutable_variables (dict[str, dict[str, Any]]) – The collection of immutable variables used directly by this module (not by a submodule).

  • rngs (dict[str, pz.de.RandomEffect]) – The collection of RNGs used by this module method. Note that the random numbers generated by Penzai will NOT exactly match those generated by Flax, because Flax has custom logic for splitting and seeding RNGs that is not easy to reproduce directly in Penzai.
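The nesting of the four fields can be sketched with a plain frozen dataclass. This is a hypothetical stand-in, not the Penzai class: `ScopeDataSketch`, the collection names, and the string placeholders for Penzai parameters and effects are illustrative assumptions. Assuming, as the nested types suggest, that the outer key of `variables` and `immutable_variables` is the Flax collection name, the sketch shows those two fields keyed by collection and then by variable name, while `parameters` and `rngs` are keyed by a single name. It also shows that a frozen record rejects field reassignment.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ScopeDataSketch:
  """Hypothetical mirror of the field layout of InterceptedFlaxScopeData."""

  # parameter name -> parameter (pz.nn.ParameterLike in the real class)
  parameters: dict[str, Any]
  # collection name -> variable name -> state (pz.de.LocalStateEffect in the real class)
  variables: dict[str, dict[str, Any]]
  # collection name -> variable name -> plain value
  immutable_variables: dict[str, dict[str, Any]]
  # RNG name -> random effect (pz.de.RandomEffect in the real class)
  rngs: dict[str, Any]


# Strings stand in for Penzai parameter and effect objects.
data = ScopeDataSketch(
    parameters={"kernel": "<parameter kernel>"},
    variables={"batch_stats": {"mean": "<state mean>"}},
    immutable_variables={"constants": {"eps": 1e-6}},
    rngs={"dropout": "<random effect dropout>"},
)

try:
  data.rngs = {}  # Reassigning a field of a frozen dataclass raises.
except dataclasses.FrozenInstanceError:
  print("fields of a frozen record cannot be reassigned")
```

Freezing only blocks field reassignment; in this sketch the dictionaries themselves remain ordinary mutable `dict`s.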

Methods

__init__(parameters, variables, ...)

Attributes

parameters

variables

immutable_variables

rngs

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.
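Most of the inherited methods follow two conventions: struct helpers that convert between an instance and its fields (`attributes_dict`, `from_attributes`), and JAX's custom-PyTree protocol, in which `tree_flatten` returns the children plus static auxiliary data and `tree_unflatten(aux_data, children)` rebuilds the node. The plain-Python sketch below imitates those conventions on a frozen dataclass so that it runs without JAX or Penzai. `StructSketch` and `ScopeRecord` are hypothetical; `key_for_field` returns a bare string here, whereas a class registered with JAX would return a key object such as `jax.tree_util.GetAttrKey`; and routing `tree_unflatten` through `from_attributes` (which skips `__init__`) is an assumed design choice, not a description of Penzai's code.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class StructSketch:
  """Hypothetical base mimicking the inherited Struct method conventions."""

  def attributes_dict(self) -> dict[str, Any]:
    # Shallow: field values are returned as-is, not converted recursively.
    return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

  @classmethod
  def from_attributes(cls, **field_values: Any):
    # Builds the instance directly from its fields, skipping __init__.
    expected = {f.name for f in dataclasses.fields(cls)}
    if set(field_values) != expected:
      raise ValueError(f"expected {sorted(expected)}, got {sorted(field_values)}")
    obj = object.__new__(cls)
    for name, value in field_values.items():
      object.__setattr__(obj, name, value)  # bypasses the frozen guard
    return obj

  def key_for_field(self, field_name: str) -> str:
    # A JAX-registered class would return a key object, not a string.
    return field_name

  def tree_flatten(self):
    # Children are the field values; aux data is the static field names.
    names = tuple(f.name for f in dataclasses.fields(self))
    return tuple(getattr(self, n) for n in names), names

  def tree_flatten_with_keys(self):
    children, names = self.tree_flatten()
    keyed = tuple((self.key_for_field(n), c) for n, c in zip(names, children))
    return keyed, names

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls.from_attributes(**dict(zip(aux_data, children)))


@dataclasses.dataclass(frozen=True)
class ScopeRecord(StructSketch):
  parameters: dict
  rngs: dict


record = ScopeRecord(parameters={"kernel": 1.0}, rngs={"dropout": "key"})
children, aux_data = record.tree_flatten()
print(aux_data)  # ('parameters', 'rngs')
print(ScopeRecord.tree_unflatten(aux_data, children) == record)  # True
```

Because `from_attributes` in this sketch bypasses `__init__`, unflattening would still work for a subclass whose constructor takes different arguments than its fields.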