InterceptedFlaxScopeData
- class penzai.toolshed.unflaxify.InterceptedFlaxScopeData
Bases: Struct
A frozen representation of data in a particular Flax scope.
Flax implements its modules on top of a “functional core”, in which a scope statefully manages the variables, parameters, and random keys for a module and all of its submodules. This class represents a “Penzai view” of the data held in the scope of one particular module, excluding any data that belongs to its submodules.
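To make the "excluding its submodules" distinction concrete, the sketch below walks a Flax-style nested variable dictionary (collection name, then nested module names, then leaf values) and keeps only the leaves owned directly by one module path. The helper `direct_entries`, the example module names, and the use of plain lists instead of arrays are all hypothetical; the sketch also assumes that leaf values are never themselves dicts, so every nested dict is treated as a submodule scope.

```python
from typing import Any

# Hypothetical Flax-style variables: collection -> nested module names -> leaf values.
# Plain lists stand in for arrays.
flax_variables: dict[str, Any] = {
    "params": {
        "Block_0": {
            "scale": [1.0, 1.0],  # owned directly by Block_0
            "Dense_0": {"kernel": [[0.5]], "bias": [0.0]},  # owned by a submodule
        },
    },
    "batch_stats": {
        "Block_0": {"mean": [0.0, 0.0]},  # owned directly by Block_0
    },
}


def direct_entries(
    variables: dict[str, Any], path: tuple[str, ...]
) -> dict[str, dict[str, Any]]:
    """Returns {collection: {name: value}} for leaves owned directly by `path`.

    Hypothetical helper: nested dicts are assumed to be submodule scopes and
    are skipped, so only this module's own data is returned.
    """
    result: dict[str, dict[str, Any]] = {}
    for collection, tree in variables.items():
        node = tree
        # Descend to the scope for `path`; give up if this collection lacks it.
        for name in path:
            if not isinstance(node, dict) or name not in node:
                node = None
                break
            node = node[name]
        if not isinstance(node, dict):
            continue
        leaves = {k: v for k, v in node.items() if not isinstance(v, dict)}
        if leaves:
            result[collection] = leaves
    return result


print(direct_entries(flax_variables, ("Block_0",)))
# -> {'params': {'scale': [1.0, 1.0]}, 'batch_stats': {'mean': [0.0, 0.0]}}
```

Note that `Dense_0`'s kernel and bias do not appear for `Block_0`; they would instead be reported when querying the path `("Block_0", "Dense_0")`.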
- Variables:
parameters (dict[str, pz.Parameter[Any]]) – The named parameters used directly by this module (not by any of its submodules).
variables (dict[str, dict[str, pz.StateVariable[Any]]]) – The other variables used directly by this module (not by any of its submodules).
immutable_variables (dict[str, dict[str, Any]]) – The immutable variable values used directly by this module (not by any of its submodules).
rng_names (frozenset[str]) – Names of the RNGs used by this module method, which will be converted to Flax random number states. Note that random numbers generated by Penzai will NOT exactly match those generated by Flax, because Flax uses custom logic for splitting and seeding RNGs that is not easy to reproduce directly in Penzai.
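The four fields above can be mirrored with a plain frozen dataclass to show how one module's direct data might be partitioned. This is a dependency-free sketch, not the real class: `Parameter`, `StateVariable`, `ScopeDataSketch`, and `build_scope_data` are hypothetical stand-ins, and the partitioning rule (the `"params"` collection becomes `parameters`, collections listed as mutable become `variables`, everything else becomes `immutable_variables`) is an assumption made for illustration rather than unflaxify's exact logic.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Parameter:
    """Hypothetical stand-in for pz.Parameter: a labeled, mutable value."""
    label: str
    value: Any


@dataclasses.dataclass
class StateVariable:
    """Hypothetical stand-in for pz.StateVariable."""
    label: str
    value: Any


@dataclasses.dataclass(frozen=True)
class ScopeDataSketch:
    """Mirrors the field layout of InterceptedFlaxScopeData (not the real class)."""
    parameters: dict[str, Parameter]
    variables: dict[str, dict[str, StateVariable]]
    immutable_variables: dict[str, dict[str, Any]]
    rng_names: frozenset[str]


def build_scope_data(
    direct: dict[str, dict[str, Any]],
    mutable_collections: frozenset[str],
    rng_names: frozenset[str],
    prefix: str,
) -> ScopeDataSketch:
    """Partitions one module's direct variables by collection (assumed rule)."""
    # Assumption: the "params" collection holds the module's parameters.
    parameters = {
        name: Parameter(label=f"{prefix}.{name}", value=value)
        for name, value in direct.get("params", {}).items()
    }
    variables: dict[str, dict[str, StateVariable]] = {}
    immutable: dict[str, dict[str, Any]] = {}
    for collection, entries in direct.items():
        if collection == "params":
            continue
        if collection in mutable_collections:
            # Mutable collections are wrapped so their values can be updated.
            variables[collection] = {
                name: StateVariable(label=f"{prefix}.{collection}.{name}", value=value)
                for name, value in entries.items()
            }
        else:
            # Everything else is stored as plain, read-only values.
            immutable[collection] = dict(entries)
    return ScopeDataSketch(parameters, variables, immutable, frozenset(rng_names))


data = build_scope_data(
    direct={"params": {"scale": [1.0]}, "batch_stats": {"mean": [0.0]}, "consts": {"eps": 1e-6}},
    mutable_collections=frozenset({"batch_stats"}),
    rng_names=frozenset({"dropout"}),
    prefix="Block_0",
)
print(sorted(data.parameters), sorted(data.variables), sorted(data.immutable_variables))
# -> ['scale'] ['batch_stats'] ['consts']
```

Keying `variables` and `immutable_variables` by collection first, then by name, matches the nested `dict[str, dict[str, ...]]` types listed above, while `parameters` needs only one level of keys.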
Methods
__init__(parameters, variables, ...)
Attributes
parameters
variables
immutable_variables
rng_names
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
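The round trip between `attributes_dict` and `from_attributes` can be sketched on a frozen dataclass. `StructSketch` is a hypothetical stand-in, not Penzai's `Struct`; in particular, having `from_attributes` set fields directly and skip `__init__` is one plausible reading of "directly instantiates" and is an assumption here. The final `dataclasses.replace` call is a plain-Python analogue of a functional-style update, not Penzai's `select()` API.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class StructSketch:
    """Hypothetical stand-in illustrating two inherited Struct helpers."""
    parameters: dict[str, Any]
    rng_names: frozenset[str]

    def attributes_dict(self) -> dict[str, Any]:
        # One entry per dataclass field, in declaration order.
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

    @classmethod
    def from_attributes(cls, **field_values: Any) -> "StructSketch":
        # Assumption: set every field directly, bypassing any custom __init__.
        expected = {f.name for f in dataclasses.fields(cls)}
        if set(field_values) != expected:
            raise ValueError(f"expected fields {sorted(expected)}, got {sorted(field_values)}")
        obj = object.__new__(cls)
        for name, value in field_values.items():
            # object.__setattr__ is needed because the dataclass is frozen.
            object.__setattr__(obj, name, value)
        return obj


original = StructSketch(parameters={"scale": 1.0}, rng_names=frozenset({"dropout"}))
rebuilt = StructSketch.from_attributes(**original.attributes_dict())
assert rebuilt == original

# Functional-style update: build a new instance rather than mutating the old one.
updated = dataclasses.replace(original, rng_names=frozenset())
print(updated.rng_names, original.rng_names)
# -> frozenset() frozenset({'dropout'})
```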