tree_util
Additional tree functionality related to jax.tree_util.
Classes
Subclass-friendly variant of jax.tree_util.GetAttrKey.
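For reference, the upstream jax.tree_util.GetAttrKey, which the class above is a variant of, labels a PyTree child by the attribute it was read from. The sketch below uses only the standard JAX API and does not show the subclass-friendly variant itself. `Point`, `_flatten_with_keys`, and `_unflatten` are hypothetical names introduced for illustration.

```python
import jax


class Point:
    """Hypothetical container used only for this illustration."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def _flatten_with_keys(p):
    # Pair each child with a GetAttrKey naming the attribute it came from.
    children = (
        (jax.tree_util.GetAttrKey("x"), p.x),
        (jax.tree_util.GetAttrKey("y"), p.y),
    )
    return children, None  # no static auxiliary data


def _unflatten(aux_data, children):
    # Rebuild the container from its children, in flattening order.
    return Point(*children)


# Register Point as a PyTree node whose children carry key-path entries.
jax.tree_util.register_pytree_with_keys(Point, _flatten_with_keys, _unflatten)

paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(Point(1, 2))
key_strings = [jax.tree_util.keystr(path) for path, _ in paths_and_leaves]
leaves = [leaf for _, leaf in paths_and_leaves]
print(key_strings, leaves)  # ['.x', '.y'] [1, 2]
```

Each GetAttrKey renders as `.name` in a key string, which is what makes attribute-based paths readable.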
Functions
Constructs a pretty name from a keypath and an object.
Flattens a PyTree exactly one level, or returns None if it's not a PyTree.
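The exact format produced by the pretty-name function above is not given here. As a point of comparison, standard JAX builds a key string from the keypath alone with jax.tree_util.keystr. The sketch below is a hypothetical helper, `hypothetical_pretty_keystr`, that also takes the object and uses it only to prefix the root's type name; it illustrates why an object argument can be useful, not how the module's own function behaves.

```python
import jax


def hypothetical_pretty_keystr(keypath, tree) -> str:
    # Hypothetical helper: name the root by its type, then append JAX's
    # standard rendering of each key in the path (e.g. "['params'][0]").
    return type(tree).__name__ + jax.tree_util.keystr(keypath)


tree = {"params": [10, 20]}
paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
names = [hypothetical_pretty_keystr(path, tree) for path, _ in paths_and_leaves]
print(names)  # ["dict['params'][0]", "dict['params'][1]"]
```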
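One way to flatten exactly one level with the standard API is to pass an `is_leaf` predicate that treats every node except the root as a leaf. The sketch below assumes that approach. It is not necessarily how this module implements it, and the name `flatten_one_level_sketch` and its `(keys_and_children, treedef)` return shape are hypothetical.

```python
import jax


def flatten_one_level_sketch(tree):
    """Returns (keys_and_children, treedef), or None if `tree` is a leaf."""
    # Every node other than the root counts as a leaf, so flattening stops
    # after one level. None children are also kept as leaves this way.
    paths_and_children, treedef = jax.tree_util.tree_flatten_with_path(
        tree, is_leaf=lambda node: node is not tree
    )
    # A leaf flattens to a treedef with a single node that is itself a leaf.
    if treedef.num_nodes == 1 and treedef.num_leaves == 1:
        return None
    # Each path has length one here, so keep its single key.
    keys_and_children = [(path[0], child) for path, child in paths_and_children]
    return keys_and_children, treedef


# Usage: the children of a dict are returned unflattened, with their keys.
keys_and_children, treedef = flatten_one_level_sketch({"a": 1, "b": (2, 3)})
print([k.key for k, _ in keys_and_children])  # ['a', 'b']
print(flatten_one_level_sketch(5))  # None
```

Note that None and empty containers are PyTrees with zero children, so the sketch returns an empty child list for them rather than None.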