scan

penzai.core.named_axes.scan(f: Callable[[Any, Any], Any], axis: AxisName, init, xs=None, **scan_kwargs) -> Any

Scan a function over a named array axis while carrying along state.

This function wraps jax.lax.scan so that it scans over a named axis instead of over the leading axis of xs. The inputs xs must contain NamedArray (or NamedArrayView) instances, and f should usually take and return named arrays as well. f is called with slices of xs along the given named axis, so each value it receives is missing the scanned axis axis.

When xs and the output of f are single NamedArray instances, the semantics of scan are roughly those of the following Python loop:

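```python
from penzai import pz

def scan(f, axis, init, xs):
  # Reference semantics only: a plain Python loop over the named axis.
  carry = init
  ys = []
  for i in range(xs.named_shape[axis]):
    x = xs[{axis: i}]  # slice of xs with the named axis `axis` removed
    carry, y = f(carry, x)
    ys.append(y)
  # Stack the per-step outputs along a new named axis called `axis`.
  return carry, pz.nx.stack(ys, axis)
```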

This is a convenience function that is equivalent to calling jax.lax.scan in combination with the necessary tag, untag, and with_positional_prefix calls.

Parameters:
  • f – The function to scan. As in jax.lax.scan, this function should have signature c -> a -> (c, b), where c is the carry state, a is the current slice of xs along axis, and b is the per-step output that is stacked into the result.

  • axis – The name of the axis to scan over.

  • init – The initial loop carry value of type c, which can be any JAX PyTree. This value must have the same structure as the first element of the pair returned by f.

  • xs – The value or tree of values over which to scan. Must contain Penzai named arrays, each of which must have the named axis given by axis.

  • **scan_kwargs – Additional keyword arguments to pass to jax.lax.scan. In particular, if xs is None, this must include the keyword argument length to specify the number of scan steps.

Returns:

A pair (final_carry, stacked_outputs), where final_carry is the final loop carry value and stacked_outputs is a named array or tree of named arrays holding the second output of f from every step, stacked along a new named axis with the name axis.