Utility to automatically apply nmap to functions in a module.

pz.nx.nmap can be used to “lift” any positional JAX function into a named-axis variant, with simple semantics: operations always vectorize over the named axes, and the function only sees the positional ones. However, wrapping each JAX function with nmap by hand quickly becomes tedious.

This module provides syntactic sugar for automatically wrapping all callables in a module with nmap, by exposing a wrapper object that stands in for the original module. Accessing an attribute on the wrapper looks up the same attribute on the underlying module. The result is then handled according to its type: callables are wrapped in nmap, and submodules (such as jax.lax) are recursively wrapped using the module wrapper.

The result is that, if you define

njax = auto_nmap.wrap_module(jax)
njnp = auto_nmap.wrap_module(jnp)

then you can write code like

njax.lax.top_k(my_array, k=10)

instead of

pz.nx.nmap(jax.lax.top_k)(my_array, k=10)

You can also directly use ordinary array constructors to construct NamedArrays, e.g.

njnp.array([1, 2, 3]).tag("foo")
njnp.linspace(-0.5, 0.5, 10).tag("bar")
njax.random.uniform(key, (10, 10)).tag("baz", "qux")
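
The attribute-forwarding behavior described above can be illustrated with a minimal standard-library sketch. This is not the actual implementation: `ModuleWrapper` and `lift` are hypothetical names, `lift` stands in for nmap and merely tags its result so the wrapping is visible, and passing non-callable values through unchanged is an assumption.

```python
import math
import os
import types


def lift(fn):
    """Hypothetical stand-in for nmap: tags the result of calling fn."""
    def lifted(*args, **kwargs):
        return ("lifted", fn(*args, **kwargs))
    return lifted


class ModuleWrapper:
    """Illustrative module wrapper (a sketch, not the real class)."""

    def __init__(self, module):
        self._module = module

    def __getattr__(self, name):
        # Only invoked for names that are not found on the wrapper itself.
        value = getattr(self._module, name)
        if isinstance(value, types.ModuleType):
            # Submodules are wrapped recursively.
            return ModuleWrapper(value)
        if callable(value):
            # Callables are lifted (the real wrapper would apply nmap here).
            return lift(value)
        # Assumption: other values pass through unchanged.
        return value


nmath = ModuleWrapper(math)
nos = ModuleWrapper(os)

result = nmath.sqrt(16.0)          # top-level callable is lifted
joined = nos.path.join("a", "b")   # os.path is a submodule, so it is wrapped first
```

Because the lookup happens lazily in `__getattr__`, nothing is wrapped until it is accessed, and nested access such as `nos.path.join` works the same way `njax.lax.top_k` does in the examples above.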



Wrapper for a module that automatically `nmap`s callables.



Wraps a module to automatically apply named_axes.nmap to callables.