auto_nmap
Utility to automatically apply nmap to functions in a module.
pz.nx.nmap can be used to “lift” any positional JAX function into a named-axis variant, with simple semantics: operations always vectorize over the named axes, and the function only sees the positional ones. However, it can be tedious to wrap each JAX function with nmap manually.
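To make these semantics concrete, here is a toy sketch in NumPy. It is not Penzai's implementation: ToyNamedArray and toy_nmap are hypothetical names, the toy stores the named axes as the leading axes of a plain array, and it only lifts single-argument functions. The lifted function loops over every index of the named axes and hands the wrapped function only the remaining positional slice.

```python
import numpy as np


class ToyNamedArray:
    """Hypothetical stand-in for a NamedArray (not Penzai's class).

    Assumption: the first len(names) axes of `data` are the named axes;
    all remaining axes are positional.
    """

    def __init__(self, data, names):
        self.data = np.asarray(data)
        self.names = tuple(names)

    @property
    def named_shape(self):
        return dict(zip(self.names, self.data.shape[: len(self.names)]))

    @property
    def positional_shape(self):
        return self.data.shape[len(self.names):]


def toy_nmap(fn):
    """Lift a positional, single-argument function to act on ToyNamedArrays."""

    def lifted(x):
        named_sizes = x.data.shape[: len(x.names)]
        # Visit every combination of named-axis indices; fn only ever receives
        # the positional slice, so it never sees the named axes.
        results = [np.asarray(fn(x.data[idx])) for idx in np.ndindex(*named_sizes)]
        # Reassemble: named axes first, then whatever positional shape fn returned.
        stacked = np.stack(results).reshape(named_sizes + results[0].shape)
        return ToyNamedArray(stacked, x.names)

    return lifted


# Usage: sum over the positional axis, once per element of the "batch" axis.
x = ToyNamedArray(np.arange(12).reshape(3, 4), names=("batch",))
y = toy_nmap(np.sum)(x)
print(y.named_shape, y.positional_shape, y.data.tolist())
```

Here np.sum is written for ordinary positional arrays, yet the lifted version automatically runs once per "batch" index, which is the behavior the paragraph above describes for nmap.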
This module provides syntactic sugar for automatically wrapping all callables in a module with nmap, by exposing a wrapper object that stands in for the original module. Accessing an attribute on the wrapper accesses the same attribute on the underlying module; the result is then either wrapped in nmap or recursively wrapped using the module wrapper, depending on the type of the retrieved value.
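The attribute-forwarding pattern described above can be sketched in plain Python. This is a hypothetical illustration, not Penzai's code: ModuleWrapper, wrap_module_sketch and list_map are made-up names, and the dispatch rule (submodules are wrapped recursively, other callables are passed to the transform, everything else is returned unchanged) is an assumption about what "depending on the type" means. A toy list_map transform stands in for nmap so that the example runs without JAX.

```python
import math
import types


class ModuleWrapper:
    """Hypothetical stand-in for a module that transforms its callables."""

    def __init__(self, module, transform):
        self._module = module
        self._transform = transform

    def __getattr__(self, name):
        # Only called for names not found on the wrapper itself, so every
        # public attribute lookup is forwarded to the underlying module.
        value = getattr(self._module, name)
        if isinstance(value, types.ModuleType):
            # Submodules (e.g. jax.lax) get the same treatment recursively.
            return ModuleWrapper(value, self._transform)
        if callable(value):
            # In the real utility this step would apply nmap instead.
            return self._transform(value)
        # Non-callable attributes such as math.pi pass through unchanged.
        return value


def wrap_module_sketch(module, transform):
    return ModuleWrapper(module, transform)


def list_map(fn):
    """Toy transform: apply fn elementwise over a Python list."""
    return lambda xs: [fn(x) for x in xs]


# Usage with the standard-library math module.
wmath = wrap_module_sketch(math, list_map)
print(wmath.sqrt([1.0, 4.0, 9.0]))  # [1.0, 2.0, 3.0]
print(wmath.pi == math.pi)  # True
```

Because the sketch relies on __getattr__ rather than copying the module, it is lazy: nothing is transformed until it is actually accessed, and the wrapper never needs to know the module's contents in advance.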
The result is that, if you define
    njax = auto_nmap.wrap_module(jax)
    njnp = auto_nmap.wrap_module(jnp)
then you can write code like
    njax.lax.top_k(my_array, k=10)
    njnp.linalg.eigh(my_array)
instead of
    pz.nx.nmap(jax.lax.top_k)(my_array, k=10)
    pz.nx.nmap(jnp.linalg.eigh)(my_array)
You can also use ordinary array constructors directly to build NamedArrays, e.g.
    njnp.array([1, 2, 3]).tag("foo")
    njnp.linspace(-0.5, 0.5, 10).tag("bar")
    njax.random.uniform(key, (10, 10)).tag("baz", "qux")
Classes
Wrapper for a module that automatically applies nmap.
Functions
wrap_module: Wraps a module to automatically apply nmap.