tree_util
Additional tree functionality related to jax.tree_util.
Functions
Constructs a pretty name from a keypath and an object.
Flattens a PyTree exactly one level, or returns None if it is not a PyTree.
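The two functions above can be approximated with public jax.tree_util calls, as in the minimal sketch below. The helper names pretty_name and flatten_one_level are hypothetical and are not this module's API. The sketch also makes three assumptions. The "object" passed to the naming helper is taken to be the value the keypath points to. "One level" means the root's immediate children. "Not a PyTree" means a bare leaf. It relies only on tree_flatten_with_path (with its is_leaf argument) and the key classes DictKey, SequenceKey, GetAttrKey and FlattenedIndexKey.

```python
from typing import Any, Optional

import jax
from jax.tree_util import DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey


def pretty_name(keypath: tuple, obj: Any) -> str:
    """Hypothetical helper: dotted name for `keypath`, tagged with obj's type."""
    parts = ["root"]
    for key in keypath:
        if isinstance(key, DictKey):
            # String dict keys read like attributes; other keys keep brackets.
            parts.append(f".{key.key}" if isinstance(key.key, str) else f"[{key.key!r}]")
        elif isinstance(key, SequenceKey):
            parts.append(f"[{key.idx}]")
        elif isinstance(key, GetAttrKey):
            parts.append(f".{key.name}")
        elif isinstance(key, FlattenedIndexKey):
            parts.append(f"[{key.key}]")
        else:
            # Fall back to the key's own string form for custom key types.
            parts.append(str(key))
    return "".join(parts) + f": {type(obj).__name__}"


def flatten_one_level(tree: Any) -> Optional[tuple]:
    """Hypothetical helper: (keyed children, treedef) one level down, or None for a leaf."""
    # Treat every node below the root as a leaf, so only the root is expanded.
    keyed_children, treedef = jax.tree_util.tree_flatten_with_path(
        tree, is_leaf=lambda x: x is not tree
    )
    # A bare leaf flattens to a treedef with a single node that is itself a leaf.
    if treedef.num_nodes == 1 and treedef.num_leaves == 1:
        return None
    return keyed_children, treedef


tree = {"params": {"w": 1.0, "b": 0.0}, "layers": [3, 4]}
children, treedef = flatten_one_level(tree)
for path, child in children:
    # JAX visits dict keys in sorted order.
    print(pretty_name(path, child))  # root.layers: list, then root.params: dict

# The treedef rebuilds the original container from the one-level children.
assert treedef.unflatten([c for _, c in children]) == tree
```

Treating every non-root node as a leaf through is_leaf is one way to stop flattening after a single level. For contrast, calling jax.tree_util.keystr on the same keypaths gives JAX's built-in bracketed rendering.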