scan
- penzai.core.named_axes.scan(f: Callable[[Any, Any], Any], axis: AxisName, init, xs=None, **scan_kwargs) -> Any
Scan a function over a named array axis while carrying along state.

This function wraps jax.lax.scan so that it scans over a named axis instead of over the leading axis of xs. The inputs xs must contain NamedArray (or NamedArrayView) instances, and f should (usually) take and return named arrays as well. The difference is that each value f receives from xs is missing the named axis being scanned over: f is called with slices of xs along that axis.

When xs and the output of f are single NamedArray instances, the semantics of scan are roughly given by:

```python
from penzai import pz

# Reference semantics for a single NamedArray xs (not the actual implementation).
def scan(f, axis, init, xs):
  carry = init
  ys = []
  for i in range(xs.named_shape[axis]):
    x = xs[{axis: i}]  # slice of xs with the scanned axis removed
    carry, y = f(carry, x)
    ys.append(y)
  return carry, pz.nx.stack(ys, axis)
```
This is a convenience function that is equivalent to calling jax.lax.scan in combination with the necessary tag, untag, and with_positional_prefix calls.
- Parameters:
f – The function to scan. As in jax.lax.scan, this function should have signature c -> a -> (c, b), where c is the carry state, a is the current element of xs (a slice along axis), and b is the per-step output. The b values from all steps are stacked along axis to form the second result of the scan.
axis – The name of the axis to scan over.
init – The initial loop carry value of type c, which can be any JAX PyTree. This value must have the same structure as the first element of the pair returned by f.
xs – The value or tree of values over which to scan. Must contain Penzai named arrays, each of which should have the named axis given by axis.
**scan_kwargs – Additional keyword arguments to pass to jax.lax.scan. In particular, if xs is None, this should include the keyword argument length to specify the length of the scan.
- Returns:
A pair (final_carry, stacked_outputs), where final_carry is the final loop carry value, and stacked_outputs is a named array or tree of named arrays representing the second outputs of f, stacked along the named axis given by axis.