scan
penzai.core.named_axes.scan(f: Callable[[Any, Any], Any], axis: AxisName, init, xs=None, **scan_kwargs) -> Any
Scan a function over a named array axis while carrying along state.
This function wraps jax.lax.scan to allow scanning over a named axis instead of over the leading axis of xs. The inputs xs must contain NamedArray (or NamedArrayView) instances, and f should (usually) take and return named arrays as well. f is called with slices of xs along the given named axis, so each value it receives from xs is missing the named axis being scanned over.

When xs and the output of f are single NamedArray instances, the semantics of scan are roughly given by:

    from penzai import pz

    def scan(f, axis, init, xs):
      carry = init
      ys = []
      # Step through xs one index at a time along the named axis.
      for i in range(xs.named_shape[axis]):
        x = xs[{axis: i}]  # this slice no longer has the scanned axis
        carry, y = f(carry, x)
        ys.append(y)
      # Stack the per-step outputs along a new named axis with the same name.
      return carry, pz.nx.stack(ys, axis)
This is a convenience function that is equivalent to calling jax.lax.scan in combination with the necessary tag, untag, and with_positional_prefix calls.

- Parameters:
  - f – The function to scan. As in jax.lax.scan, this function should have signature c -> a -> (c, b), where c is the carry state, a is the current slice of xs (without the scanned axis), and b is the per-step output; the b values from all steps are stacked along axis to form the second return value.
  - axis – The name of the axis to scan over.
  - init – The initial loop carry value of type c, which can be any JAX PyTree. This value must have the same structure as the first element of the pair returned by f.
  - xs – The value or tree of values over which to scan. Must contain Penzai named arrays, each of which should have the named axis given by axis.
  - **scan_kwargs – Additional keyword arguments to pass to jax.lax.scan. In particular, if xs is None, this should include the keyword argument length to specify the length of the scan.
- Returns:
  A pair (final_carry, stacked_outputs), where final_carry is the final loop carry value and stacked_outputs is a named array (or tree of named arrays) containing the second outputs of f, stacked along the named axis given by axis.