truncate_array_and_mask


penzai.treescope.ndarray_summarization.truncate_array_and_mask(array: jax.Array, mask: jax.Array, edge_items_per_axis: tuple[int | None, ...]) -> tuple[jax.Array, jax.Array]

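Truncates an array, together with its mask, along each of its axes, keeping only a fixed number of edge items at the start and end of each truncated axis.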

Parameters:
  • array – Array to truncate.

  • mask – Mask array, which must have the same number of dimensions as array, and whose size along each axis must be either 1 or equal to the size of that axis of array (i.e., the two are broadcast-compatible).

  • edge_items_per_axis – Number of edge items to keep at each end of each axis, given positionally. An entry of None leaves the corresponding axis untruncated.
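The behavior described by these parameters can be sketched in plain NumPy. This is a hypothetical re-implementation for illustration, not penzai's code: the name truncate_array_and_mask_sketch is made up, it uses NumPy instead of jax.numpy so it runs on the host, and it assumes (1) an axis is only truncated when it is longer than 2 * edge_items + 1, and (2) kept elements carry over their value from mask.

```python
import numpy as np


def truncate_array_and_mask_sketch(array, mask, edge_items_per_axis):
    """Hypothetical NumPy sketch of the documented truncation behavior."""
    # Expand the mask to the array's full shape so the placeholder slot can be
    # marked invalid independently of the kept elements.
    mask = np.broadcast_to(mask, array.shape).astype(bool)
    for axis, edge_items in enumerate(edge_items_per_axis):
        size = array.shape[axis]
        # None keeps the axis whole; short axes are also left alone
        # (assumed threshold: truncation must actually save elements).
        if edge_items is None or size <= 2 * edge_items + 1:
            continue
        head = np.arange(edge_items)
        tail = np.arange(size - edge_items, size)
        # A single slice along this axis stands in for the omitted elements.
        filler_shape = list(array.shape)
        filler_shape[axis] = 1
        array = np.concatenate(
            [
                np.take(array, head, axis=axis),
                np.zeros(filler_shape, dtype=array.dtype),
                np.take(array, tail, axis=axis),
            ],
            axis=axis,
        )
        # Kept elements keep their mask value; the placeholder is marked False.
        mask = np.concatenate(
            [
                np.take(mask, head, axis=axis),
                np.zeros(filler_shape, dtype=bool),
                np.take(mask, tail, axis=axis),
            ],
            axis=axis,
        )
    return array, mask
```

Broadcasting the mask up front is a choice made in this sketch: a size-1 mask axis cannot hold the False placeholder alongside the True edge items, so the mask is expanded before it is sliced.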

Returns:

A tuple of the truncated array and a valid mask. Elements taken from the original array are marked True in the valid mask, and each truncated axis gains one extra element in the middle, marked False, that stands in for the omitted elements. Both returned arrays are always fully replicated across devices: the truncated shape cannot be guaranteed to shard evenly, and this function is usually called immediately before copying the result to the host.
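The shape of the returned arrays follows from a little arithmetic. The helpers below are hypothetical (not part of penzai) and assume that an axis of length n with k edge items is truncated only when n > 2k + 1, in which case it ends up with 2k + 1 entries: k from each end plus the single placeholder.

```python
from typing import Optional, Sequence, Tuple


def truncated_length(size: int, edge_items: Optional[int]) -> int:
    """Hypothetical helper: length of one axis after truncation."""
    # None, or an axis that is already short enough, leaves the axis untouched.
    if edge_items is None or size <= 2 * edge_items + 1:
        return size
    # edge_items from each end plus one invalid placeholder in the middle.
    return 2 * edge_items + 1


def truncated_shape(
    shape: Sequence[int], edge_items_per_axis: Sequence[Optional[int]]
) -> Tuple[int, ...]:
    """Hypothetical helper: full shape of the truncated array and mask."""
    return tuple(
        truncated_length(n, k) for n, k in zip(shape, edge_items_per_axis)
    )
```

For example, truncated_shape((1000, 3, 50), (3, None, 3)) gives (7, 3, 7), so only 147 elements would need to be copied to the host instead of 150,000.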