Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Visualizing NDArrays with Treescope

High-dimensional NDArray (or tensor) data is common in many machine learning settings, but most plotting libraries are designed for either 2D image data or 1D time series data. Penzai’s pretty printer (treescope) includes a powerful visualizer for arrays of arbitrarily high dimension. It is designed to make it easy to summarize NDArrays quickly without writing manual plotting logic.

Setup

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]

from __future__ import annotations
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np

import IPython
import penzai
from penzai import pz

Visualizing NDArrays with pz.ts.render_array

Treescope includes a powerful array renderer, defined in penzai.treescope.arrayviz and aliased to pz.ts.render_array for easy use. It is designed to make faceted, interactive visualizations of \(N\)-dimensional arrays.

Visualizing numeric data and customizing colormaps

Arrays can be directly rendered using default settings by passing them to pz.ts.render_array:

help(pz.ts.render_array)
my_array = np.cos(np.arange(300).reshape((10,30)) * 0.2)

pz.ts.render_array(my_array)

Things to notice:

  • The visualization is interactive! (Try zooming in and out, hovering over the array to inspect individual elements, or clicking to remember a particular element.)

  • The shape of the array can be read off by looking at the axis labels.

  • Each array element is drawn as a square "pixel" in arrayviz renderings. (At zoom level 1, each one is exactly 7 by 7 screen pixels.)

The default rendering strategy uses a diverging colormap centered at zero, with blue for positive numbers and red for negative ones, to show you the absolute magnitude and sign of the array. You can toggle to a relative mode by passing the argument around_zero=False:

pz.ts.render_array(my_array, around_zero=False)

You can also customize the upper and lower bounds of the colormap by passing vmin and/or vmax:

pz.ts.render_array(my_array, vmax=0.7)

In this case, the array has values outside of our specified colormap bounds; those out-of-bounds values are rendered with “+” and “-” to indicate that they’ve been clipped.
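If you want to know which elements get the “+” and “-” markers, you can reproduce the clipping with NumPy. This minimal sketch does not use arrayviz. It assumes that vmax=0.7 resolves to the symmetric bounds (-0.7, 0.7).

```python
import numpy as np

my_array = np.cos(np.arange(300).reshape((10, 30)) * 0.2)
vmin, vmax = -0.7, 0.7  # assumed effective bounds for vmax=0.7

# Elements above vmax correspond to "+" markers, elements below vmin to "-".
too_high = my_array > vmax
too_low = my_array < vmin

# Clipping maps every out-of-range value onto the nearest bound's color.
clipped = np.clip(my_array, vmin, vmax)
print(f"{too_high.sum()} values above vmax, {too_low.sum()} below vmin")
```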

Since we didn’t pass around_zero=False, arrayviz automatically set vmin to -vmax for us. You can also set both explicitly:

pz.ts.render_array(my_array, vmin=-0.1, vmax=0.7)
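The bound-resolution rule fits in a tiny function. `resolve_bounds` is a hypothetical helper, not penzai’s code. The vmax-to-vmin mirroring comes from the text above. Mirroring a lone vmin onto vmax is our own assumption, made by symmetry.

```python
def resolve_bounds(vmin=None, vmax=None, around_zero=True):
    """Hypothetical helper mirroring the around_zero bound rule.

    With around_zero on, a bound that was not given is mirrored from the one
    that was. The vmin -> vmax direction is an assumption by symmetry.
    """
    if around_zero:
        if vmin is None and vmax is not None:
            vmin = -vmax
        elif vmax is None and vmin is not None:
            vmax = -vmin
    # A bound left as None would fall back to a data-driven default.
    return vmin, vmax


print(resolve_bounds(vmax=0.7))             # (-0.7, 0.7)
print(resolve_bounds(vmin=-0.1, vmax=0.7))  # (-0.1, 0.7)
```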

If you want to customize the way colors are rendered, you can pass a custom colormap as a list of (R, G, B) color tuples:

import palettable
pz.ts.render_array(my_array, colormap=palettable.matplotlib.Inferno_20.colors)
pz.ts.render_array(my_array, colormap=palettable.cmocean.sequential.Speed_20.colors)
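You don’t strictly need palettable to build a colormap. The sketch below linearly interpolates between two colors. It assumes arrayviz accepts the same 0-255 integer (R, G, B) scale that palettable’s `.colors` lists use. `make_gradient` is a hypothetical helper.

```python
def make_gradient(start_rgb, end_rgb, steps=20):
    """Linearly interpolate between two (R, G, B) colors with 0-255 channels."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    colors = []
    for i in range(steps):
        t = i / (steps - 1)  # runs from 0.0 to 1.0 inclusive
        colors.append(
            tuple(round(a + (b - a) * t) for a, b in zip(start_rgb, end_rgb))
        )
    return colors


white_to_purple = make_gradient((255, 255, 255), (84, 39, 143))
# Assumed usage in the notebook, mirroring the palettable examples above:
# pz.ts.render_array(my_array, colormap=white_to_purple)
```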

Visualizing high-dimensional arrays and NamedArrays

So far we’ve been looking at an array with two axes, but arrayviz works out-of-the-box with arbitrarily high-dimensional arrays as well:

my_4d_array = np.cos(np.arange(5*6*7*8).reshape((5,6,7,8)) * 0.1)
pz.ts.render_array(my_4d_array)

For high-dimensional arrays, the individual axis labels indicate which level of the plot corresponds to which axis. Above, each 7x8 facet represents a slice my_4d_array[i,j,:,:], with individual pixels ranging along axis 2 and axis 3; this is denoted by the axis2 and axis3 labels for that facet. The six columns correspond to slices along axis 1, and the five rows correspond to slices along axis 0, as denoted by the outermost labels for those axes.
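To make the layout concrete, the NumPy-only sketch below arranges the same 4D array into one 2D image using the default layout just described. Axis 0 becomes facet rows, axis 1 becomes facet columns, and axes 2 and 3 become pixels, with NaN gaps between facets. `tile_facets` is a hypothetical helper that approximates the picture; it does not show how arrayviz actually draws it.

```python
import numpy as np


def tile_facets(array_4d, gap=1):
    """Lay out a 4D array as a grid of 2D facets separated by NaN gaps.

    Axis 0 -> facet rows, axis 1 -> facet columns, axes 2 and 3 -> the
    rows and columns of pixels inside each facet.
    """
    n0, n1, n2, n3 = array_4d.shape
    # Pad the two pixel axes with NaN so that neighboring facets are separated.
    padded = np.pad(
        array_4d.astype(float),
        ((0, 0), (0, 0), (0, gap), (0, gap)),
        constant_values=np.nan,
    )
    # Reorder to (facet row, pixel row, facet col, pixel col), then merge each
    # (facet, pixel) pair into a single image coordinate.
    mosaic = padded.transpose(0, 2, 1, 3).reshape(
        n0 * (n2 + gap), n1 * (n3 + gap)
    )
    # Drop the trailing gap along the bottom and right edges.
    return mosaic[: mosaic.shape[0] - gap, : mosaic.shape[1] - gap]


my_4d_array = np.cos(np.arange(5 * 6 * 7 * 8).reshape((5, 6, 7, 8)) * 0.1)
mosaic = tile_facets(my_4d_array)
print(mosaic.shape)  # (39, 53)
```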

You can control which axes get assigned to which direction if you want, specified from innermost to outermost:

pz.ts.render_array(my_4d_array, columns=[2, 0, 1])

Note that the gap between the “axis0” groups is twice as large as the gap between “axis2” groups, so that you can visually distinguish the groups.
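One way to read “innermost to outermost” is as a mixed-radix index. The first axis in `columns` varies fastest from left to right, and each later axis advances only after the earlier ones have been exhausted. The hypothetical helper below computes an element’s horizontal position under that reading, ignoring the gaps.

```python
def column_position(index, shape, columns):
    """Horizontal position of `index` when `columns` lists axes innermost first.

    Gaps between groups are ignored; this only illustrates the ordering.
    """
    position, stride = 0, 1
    for axis in columns:
        position += index[axis] * stride
        stride *= shape[axis]  # the next axis steps over a whole inner group
    return position


shape = (5, 6, 7, 8)
# Element [1, 2, 3, 0] with columns=[2, 0, 1]: 3 * 1 + 1 * 7 + 2 * (7 * 5) = 80
print(column_position((1, 2, 3, 0), shape, columns=[2, 0, 1]))
```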

Arrayviz can also visualize NamedArrays and uses them to label the axes. This means that, if your code is written with NamedArrays, you get labeled visualizations for free! This applies both to axes that have been tagged with a name and to axes that haven’t. (See the NamedArray tutorial for more information on how NamedArrays work in penzai.)

col = pz.nx.wrap(np.linspace(-2, 2, 31)).tag("col")
row = pz.nx.wrap(np.linspace(-2, 2, 31)).tag("row")
sign = pz.nx.wrap(np.array([1, -1])).tag("sign")

my_named_array = sign * (col**2 + row**2)

pz.ts.render_array(my_named_array, columns=["col", "sign"])
pz.ts.render_array(my_named_array.untag("sign"))
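For comparison, here is roughly the same computation with plain positional NumPy arrays. Without names, you have to choose an axis order yourself and remember it when slicing or plotting. The order (sign, row, col) below is an arbitrary choice, and the `_np` variable names are used so the NamedArrays above are not overwritten.

```python
import numpy as np

col_np = np.linspace(-2, 2, 31)
row_np = np.linspace(-2, 2, 31)
sign_np = np.array([1, -1])

# Without names we must pick an axis order, here (sign, row, col), and insert
# singleton axes by hand so that broadcasting lines the values up correctly.
positional = sign_np[:, None, None] * (
    col_np[None, None, :] ** 2 + row_np[None, :, None] ** 2
)
print(positional.shape)  # (2, 31, 31)
```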

Identifying extreme or invalid array values

By default, arrayviz tries to configure the colormap to show interesting detail, clipping outliers. Specifically, it limits the colormap to 3 standard deviations away from the mean (or, technically, from zero if around_zero is set):

my_outlier_array = np.cos(np.arange(300).reshape((10,30)) * 0.2)
my_outlier_array[4, 2] = 10.0
pz.ts.render_array(my_outlier_array)
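To get a feel for where the default bounds land, here is a rough NumPy approximation of that rule. It rests on two assumptions: “standard deviations from zero” is taken to mean the root-mean-square value, and non-finite values are ignored. `approx_default_bounds` is hypothetical and may not match arrayviz exactly.

```python
import numpy as np


def approx_default_bounds(array, around_zero=True, num_std=3.0):
    """Approximate colormap bounds: num_std standard deviations from a center."""
    finite = array[np.isfinite(array)]
    if around_zero:
        # Spread measured around zero (root mean square), giving symmetric bounds.
        spread = num_std * np.sqrt(np.mean(finite ** 2))
        return -spread, spread
    # Spread measured around the mean.
    center, std = finite.mean(), finite.std()
    return center - num_std * std, center + num_std * std


my_outlier_array = np.cos(np.arange(300).reshape((10, 30)) * 0.2)
my_outlier_array[4, 2] = 10.0
low, high = approx_default_bounds(my_outlier_array)
print(low, high)
print(np.sum(my_outlier_array > high))  # number of values clipped at the top
```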

Arrayviz also annotates any invalid array values by drawing annotations on top of the visualization:

numerator = np.linspace(-5, 5, 31)
denominator = np.linspace(-1, 1, 13)
array_with_infs_and_nans = numerator[None, :] / denominator[:, None]
pz.ts.render_array(array_with_infs_and_nans)

Above, “I” (white on a blue background) denotes positive infinity, “-I” (white on a red background) denotes negative infinity, and “X” (in magenta on a black background) denotes NaN. (You can also see the outlier-clipping behavior clipping a few of the largest finite values here.)

This works in relative mode too:

pz.ts.render_array(array_with_infs_and_nans, around_zero=False)
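If you want to locate these values programmatically, NumPy’s predicates map directly onto the three annotations. `invalid_value_markers` is a hypothetical helper that labels each element the way the text above describes. It does not call arrayviz.

```python
import numpy as np


def invalid_value_markers(array):
    """Return the annotation for each element: "I", "-I", "X", or "" if finite."""
    markers = np.full(array.shape, "", dtype=object)
    markers[np.isposinf(array)] = "I"   # positive infinity
    markers[np.isneginf(array)] = "-I"  # negative infinity
    markers[np.isnan(array)] = "X"      # NaN
    return markers


example = np.array([[1.0, np.inf], [-np.inf, np.nan]])
print(invalid_value_markers(example).tolist())  # [['', 'I'], ['-I', 'X']]
```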

If you want, you can mask out data by providing a “valid mask”. Only values where the mask is True will be rendered; masked-out data is shown in gray with black dots.

valid_mask = np.isfinite(array_with_infs_and_nans) & (np.abs(array_with_infs_and_nans) < 10)
pz.ts.render_array(
    array_with_infs_and_nans,
    valid_mask=valid_mask,
)

Visualizing categorical data

Arrayviz also supports rendering categorical data, even with very high numbers of categories. Data with a discrete (integer or boolean) dtype is rendered as categorical by default, with different colors for different categories:

pz.ts.render_array(np.arange(10))
pz.ts.render_array(np.array([True, False, False, True, True]))
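You can check the dtype rule against NumPy’s dtype hierarchy. `renders_as_categorical` is a hypothetical helper that only mirrors the documented default, including the `continuous` override covered later in this section. It doesn’t query arrayviz.

```python
import numpy as np


def renders_as_categorical(array, continuous=False):
    """Mirror the documented default: integer and boolean dtypes are
    categorical unless continuous=True is passed."""
    is_discrete = np.issubdtype(array.dtype, np.integer) or np.issubdtype(
        array.dtype, np.bool_
    )
    return bool(is_discrete) and not continuous


print(renders_as_categorical(np.arange(10)))                   # True
print(renders_as_categorical(np.array([True, False])))         # True
print(renders_as_categorical(np.linspace(0, 1, 5)))            # False
print(renders_as_categorical(np.arange(10), continuous=True))  # False
```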

The values from 0 to 9 are rendered with solid colors, with 0 represented as white. Larger numbers are rendered using nested box patterns, with one box per digit of the number, and the color of the box indicating the value of the digit:

pz.ts.render_array(np.arange(1000).reshape((10,100)))
pz.ts.render_array(
    pz.nx.wrap(jnp.arange(20)).tag("a")
    * pz.nx.wrap(jnp.arange(20)).tag("b")
)

You can also render a single integer on its own, if you want (sometimes useful for custom visualizations). Arrayviz supports integers with up to 7 digits.

pz.ts.integer_digitbox(42, size="30px")
pz.ts.integer_digitbox(1234, size="30px")
pz.ts.integer_digitbox(7654321, size="30px")

Negative integers render the same way as positive ones, but with a black triangle in the corner indicating the sign:

pz.ts.render_array(np.arange(21 * 21).reshape((21, 21)) - 220)
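The information a digit box carries can be spelled out in plain Python. This hypothetical helper only decomposes the number. It says nothing about which box is outermost or which colors arrayviz picks, and raising an error past 7 digits is our own choice.

```python
def digitbox_parts(n, max_digits=7):
    """Split an integer into a sign flag and its decimal digits.

    Each digit corresponds to one nested box and negative numbers get a sign
    marker; rejecting more than 7 digits here is an illustrative choice.
    """
    digits = [int(ch) for ch in str(abs(n))]
    if len(digits) > max_digits:
        raise ValueError(f"{n} has more than {max_digits} digits")
    return n < 0, digits


print(digitbox_parts(42))       # (False, [4, 2])
print(digitbox_parts(-1234))    # (True, [1, 2, 3, 4])
print(digitbox_parts(7654321))  # (False, [7, 6, 5, 4, 3, 2, 1])
```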

If your data has a discrete dtype but you don’t want to render it as categorical, you can pass continuous=True to render it as numeric instead:

pz.ts.render_array(np.arange(21 * 21).reshape((21, 21)) - 220, continuous=True)

Adding axis item labels

For some arrays, it can be useful to associate labels with the individual indices along each axis. For instance, we might want to label a “classes” axis with each individual class, or a “sequence” axis with the tokens of the sequence.

Arrayviz allows you to pass this kind of metadata as an extra argument, and will show it to you when you hover over or click on elements of the array with your mouse.

For positional axes, you can pass any subset of the axes by position:

# Try hovering or clicking:
pz.ts.render_array(
    np.sin(np.linspace(0, 100, 12 * 5 * 7)).reshape((12, 5, 7)),
    axis_item_labels={
        1: ["foo", "bar", "baz", "qux", "xyz"],
        0: ["one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "eleven", "twelve"],
    }
)
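Conceptually, the labels are just a lookup from (axis, index) to a string. Here is a hypothetical sketch of the kind of description you see on hover; the exact format arrayviz displays may differ.

```python
def describe_element(index, axis_item_labels):
    """Hover-style description of a positional index (hypothetical helper).

    Axes with item labels show the label; other axes fall back to the index.
    """
    parts = []
    for axis, i in enumerate(index):
        labels = axis_item_labels.get(axis)
        parts.append(f"axis{axis}={labels[i] if labels is not None else i}")
    return ", ".join(parts)


labels = {
    1: ["foo", "bar", "baz", "qux", "xyz"],
    0: ["one", "two", "three", "four", "five", "six",
        "seven", "eight", "nine", "ten", "eleven", "twelve"],
}
print(describe_element((1, 3, 6), labels))  # axis0=two, axis1=qux, axis2=6
```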

For named axes, you can pass labels by name. Labels for names that aren’t axes of the array (like “classes” below) are simply ignored.

col = pz.nx.wrap(np.linspace(-2, 2, 15)).tag("col")
row = pz.nx.wrap(np.linspace(-2, 2, 15)).tag("row")
sign = pz.nx.wrap(np.array([1, -1])).tag("sign")

my_named_array = sign * (col**2 + row**2)

# Try hovering or clicking:
pz.ts.render_array(
    my_named_array,
    columns=["col", "sign"],
    axis_item_labels={
        "sign": ["positive", "negative"],
        "classes": ["cat", "dog", "mouse", "house"],
    }
)
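The “ignore irrelevant labels” behavior amounts to filtering the label dictionary by the array’s axis names. The sketch below is a hypothetical helper. Its length check is an assumption about how labels should be shaped, not a documented arrayviz check.

```python
def relevant_labels(axis_sizes, axis_item_labels):
    """Keep only labels for axes the array actually has (hypothetical helper)."""
    kept = {}
    for name, labels in axis_item_labels.items():
        if name not in axis_sizes:
            continue  # e.g. "classes": not an axis of this array, so ignored
        if len(labels) != axis_sizes[name]:
            # Assumed requirement: one label per index along the axis.
            raise ValueError(
                f"axis {name!r} has {axis_sizes[name]} items, got {len(labels)} labels"
            )
        kept[name] = labels
    return kept


axis_sizes = {"col": 15, "row": 15, "sign": 2}
print(relevant_labels(axis_sizes, {
    "sign": ["positive", "negative"],
    "classes": ["cat", "dog", "mouse", "house"],
}))  # {'sign': ['positive', 'negative']}
```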

Slicing and “scrubbing” with sliders

It’s sometimes useful to look at only one slice of a large array at a time, instead of viewing all of them at once. In addition to the columns and rows arguments, arrayviz supports a sliders argument, which displays a slider for each of those axes and lets you “scrub” through their indices:

time = pz.nx.wrap(jnp.arange(100)).tag("time")
col = pz.nx.wrap(np.linspace(-2, 2, 15)).tag("col")
row = pz.nx.wrap(np.linspace(-2, 2, 15)).tag("row")

values_over_time = pz.nx.nmap(jax.nn.sigmoid)(
    0.05 * time - 2 - row - pz.nx.nmap(jnp.sin)(2 * col - 0.1 * time)
)

# Try sliding the slider:
pz.ts.render_array(
    values_over_time,
    columns=["col"],
    sliders=["time"],
)
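Each slider position selects one slice along the slider’s axis. The NumPy sketch below uses a hand-chosen axis order (time, row, col) and writes the sigmoid out in place of jax.nn.sigmoid. In those terms, scrubbing to time step t is like viewing `np_values_over_time[t]`. This is only an analogy for what the slider displays, not how arrayviz implements it.

```python
import numpy as np

# Positional stand-ins for the named axes, in an order we chose: (time, row, col).
np_time = np.arange(100)[:, None, None]
np_row = np.linspace(-2, 2, 15)[None, :, None]
np_col = np.linspace(-2, 2, 15)[None, None, :]


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


np_values_over_time = sigmoid(
    0.05 * np_time - 2 - np_row - np.sin(2 * np_col - 0.1 * np_time)
)

# Scrubbing the time slider to step t corresponds to viewing this 2D slice.
t = 40
frame = np_values_over_time[t]
print(np_values_over_time.shape, frame.shape)  # (100, 15, 15) (15, 15)
```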

You can even put sliders for multiple axes simultaneously, if you want:

row_wavelength = pz.nx.wrap(4 * jnp.arange(10) + 4).tag("row_wavelength")
col_wavelength = pz.nx.wrap(4 * jnp.arange(10) + 4).tag("col_wavelength")
col = pz.nx.wrap(np.arange(15)).tag("col")
row = pz.nx.wrap(np.arange(15)).tag("row")

values = (
    pz.nx.nmap(jnp.sin)(2 * np.pi * row / row_wavelength)
    * pz.nx.nmap(jnp.sin)(2 * np.pi * col / col_wavelength)
)

# Try sliding the slider:
pz.ts.render_array(
    values,
    columns=["col"],
    sliders=["row_wavelength", "col_wavelength"],
    axis_item_labels={
        "row_wavelength": [str(v) for v in row_wavelength.untag("row_wavelength").unwrap()],
        "col_wavelength": [str(v) for v in col_wavelength.untag("col_wavelength").unwrap()],
    }
)

Note: Memory usage

One caveat to using arrayviz: whenever you render an array, the entire array is serialized, saved directly into the notebook output cell, and then loaded into your browser’s memory! That’s true even if you use sliders; although only part of your array is visible, all of the data is there in the notebook and in your local browser, so that it can update the view when you move the slider.

This can sometimes be useful, since it means the visualization does not require Colab/IPython to be connected and won’t mess up any of your Python interpreter’s state. On the other hand, it’s easy to end up with very large Colab notebooks this way, and if you have many visualizations open, your web browser may bog down. For a sense of scale, a visualization of a 1000 x 1000 array adds about 5 megabytes to the size of your notebook. (Arrayviz will still happily render an array of that size, though!)
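Here is a rough back-of-the-envelope check on that 5 MB figure. It assumes each value is stored as a 4-byte float32 and embedded as base64 text. Both are assumptions about the serialization that this tutorial does not document. Under them, a million elements come out to about 5.3 MB.

```python
import base64
import math


def estimated_payload_bytes(num_elements, bytes_per_element=4):
    """Size of an array's raw bytes after base64 encoding.

    Assumes float32 storage (4 bytes per element) and base64 text, which
    produces 4 output characters per 3 input bytes (rounded up).
    """
    raw_bytes = num_elements * bytes_per_element
    return 4 * math.ceil(raw_bytes / 3)


estimate = estimated_payload_bytes(1000 * 1000)
print(f"{estimate / 1e6:.2f} MB")  # 5.33 MB

# Cross-check against the standard library encoder on the same number of bytes.
assert len(base64.b64encode(bytes(4 * 1000 * 1000))) == estimate
```

By the same arithmetic, slicing or striding an array before rendering it (for example `arr[::4, ::4]`) shrinks the payload roughly in proportion to the number of elements kept.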

Given this, it’s usually a good idea to avoid saving visualizations of very large arrays into the notebook. One way to do this is to turn on “Omit code cell output when saving this notebook” mode in Colab to avoid saving output from any cell.

Using arrayviz in treescope

In JAX programs, NDArrays often occur as parts of large JAX-compatible data structures (PyTrees). If desired, you can use arrayviz to render these arrays inside a larger treescope rendering, by turning on automatic array visualization.

(You might want to read the tutorial on treescope pretty-printing before reading this section, if you haven’t already.)

Let’s start by registering treescope as the default pretty-printer:

pz.ts.register_as_default()

Ordinarily, treescope shows text representations of NDArrays, and lets you click to expand them:

IPython.display.display({
    "foo": pz.nx.wrap(jnp.arange(10)).tag("a") * pz.nx.wrap(jnp.arange(10)).tag("b"),
    "bar": np.sin(np.arange(100) * 0.1).reshape((10,10))
})

You can enable array visualization by wrapping your display calls in an “autovisualizer” scope:

with pz.ts.active_autovisualizer.set_scoped(
    pz.ts.ArrayAutovisualizer()
):
  IPython.display.display({
      "foo": pz.nx.wrap(jnp.arange(10)).tag("a") * pz.nx.wrap(jnp.arange(10)).tag("b"),
      "bar": np.sin(np.arange(100) * 0.1).reshape((10,10))
  })

In Colab / IPython, you can optionally register the %%autovisualize cell magic that runs an IPython cell inside the autovisualizer scope:

pz.ts.register_autovisualize_magic()

%%autovisualize
{
  "foo": pz.nx.wrap(jnp.arange(10)).tag("a") * pz.nx.wrap(jnp.arange(10)).tag("b"),
  "bar": np.sin(np.arange(100) * 0.1).reshape((10,10))
}

Small arrays are shown in their entirety. Larger arrays are truncated to show only a subset of elements along each axis, to prevent visualizations from getting too large.

%%autovisualize
{
  "foo": pz.nx.wrap(jnp.arange(10_000)).tag("a") * pz.nx.wrap(jnp.arange(20)).tag("b"),
  "bar": np.sin(np.arange(250000) * 0.1).reshape((500,500))
}
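In spirit, this truncation resembles the `edgeitems` summarization in NumPy’s printing. The sketch below keeps the first and last few entries along every sufficiently long axis. `truncate_edges` and its parameter names and defaults are hypothetical, not ArrayAutovisualizer’s actual options or values.

```python
import numpy as np


def truncate_edges(array, edge_items=5, max_len=20):
    """Keep only the first and last `edge_items` entries along long axes."""
    for axis, size in enumerate(array.shape):
        if size > max_len:
            keep = np.r_[0:edge_items, size - edge_items:size]
            array = np.take(array, keep, axis=axis)
    return array


print(truncate_edges(np.zeros((500, 500))).shape)  # (10, 10)
print(truncate_edges(np.arange(100)).tolist())  # [0, 1, 2, 3, 4, 95, 96, 97, 98, 99]
print(truncate_edges(np.zeros((10, 20))).shape)  # (10, 20): short axes are untouched
```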

If you want to enable array visualization all of the time, you can set up an automatic visualizer for your whole session:

# Makes it possible to set penzai's contextual configuration options interactively
pz.enable_interactive_context()

# (Later you can call `pz.disable_interactive_context()` to
# reset all of them to their default values.)
pz.ts.active_autovisualizer.set_interactive(
    pz.ts.ArrayAutovisualizer()
)

Now every value you display, or return from a cell, will be rendered with arrayviz:

np.arange(10)

ArrayAutovisualizer takes options that control the summarization threshold and the number of edge items visualized; see help(pz.ts.ArrayAutovisualizer) for more info.

If you’ve enabled automatic array visualization, you can disable it in a specific cell using %%autovisualize None:

%%autovisualize None
np.arange(10)

Custom treescope visualizations with autovisualizers and figure inlining

Automatic array visualization is a special case of a more general treescope feature, which lets you render arbitrary figures at arbitrary points in pretty-printed PyTrees. To customize automatic visualization, you define an autovisualizer function, with the following signature:

def autovisualizer_fn(
    value: Any,
    path: tuple[Any, ...] | None,
) -> pz.ts.IPythonVisualization | pz.ts.ChildAutovisualizer | None:
  ...

This function will be called on every subtree of the rendered tree, and can return pz.ts.IPythonVisualization(some_figure) to replace the subtree with a visualization, or None to process the subtree normally. (It can also return pz.ts.ChildAutovisualizer if the subtree should be rendered with a different autovisualizer.)
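To build intuition for that contract, here is a toy, pure-Python traversal that follows the same rule. It tries the autovisualizer on each subtree together with its path and recurses only when the result is None. This is a simplified model, not treescope’s renderer, and differs from it in three ways:

  • The path is a tuple of plain keys and indices rather than whatever treescope actually passes.
  • Results are strings rather than `IPythonVisualization` objects.
  • `ChildAutovisualizer` is not modeled.

```python
from typing import Any, Callable, Optional


def walk(value: Any, autovisualizer: Callable[[Any, tuple], Optional[str]], path: tuple = ()):
    """Toy traversal: offer each subtree to the autovisualizer, top-down.

    A non-None result replaces the subtree; None means "render normally",
    which here means recursing into dicts, lists, and tuples.
    """
    replacement = autovisualizer(value, path)
    if replacement is not None:
        return replacement
    if isinstance(value, dict):
        return {k: walk(v, autovisualizer, path + (k,)) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(
            walk(v, autovisualizer, path + (i,)) for i, v in enumerate(value)
        )
    return repr(value)  # leaf: plain text rendering


def summarize_lists(value, path):
    # Replace any list with a short summary; leave everything else alone.
    if isinstance(value, list):
        return f"<list of {len(value)} at {path}>"
    return None


print(walk({"a": [1, 2], "b": (3, "x")}, summarize_lists))
```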

For instance, we can write an autovisualizer that always formats arrays in continuous mode:

def my_continuous_autovisualizer(
    value: Any,
    path: tuple[Any, ...] | None,
):
  if isinstance(value, (np.ndarray, pz.nx.NamedArray)):
    return pz.ts.IPythonVisualization(
        pz.ts.render_array(value, continuous=True, around_zero=False))
with pz.ts.active_autovisualizer.set_scoped(
    my_continuous_autovisualizer
):
  IPython.display.display({
      "foo": pz.nx.wrap(jnp.arange(10)).tag("a") * pz.nx.wrap(jnp.arange(10)).tag("b"),
      "bar": np.sin(np.arange(100) * 0.1).reshape((10,10))
  })

Or, add additional metadata:

def my_verbose_autovisualizer(
    value: Any,
    path: tuple[Any, ...] | None,
):
  if isinstance(value, (np.ndarray, pz.nx.NamedArrayBase)):
    if isinstance(value, pz.nx.NamedArrayBase):
      size = value.data_array.size
    else:
      size = value.size
    return pz.ts.IPythonVisualization(
        pz.ts.inline(
            "Hello world!\n",
            pz.ts.render_array(value),
            f"\nThis array contains {size} elements and has Python id {id(value):,}, which you could tokenize as  ",
            pz.ts.integer_digitbox(id(value) // 1000000000000),
            "   ", pz.ts.integer_digitbox((id(value) // 1000000000) % 1000),
            "   ", pz.ts.integer_digitbox((id(value) // 1000000) % 1000),
            "   ", pz.ts.integer_digitbox((id(value) // 1000) % 1000),
            "   ", pz.ts.integer_digitbox(id(value) % 1000),
            f"\nThe path to this node is {path}",
        )
    )
with pz.ts.active_autovisualizer.set_scoped(
    my_verbose_autovisualizer
):
  IPython.display.display({
      "foo": pz.nx.wrap(jnp.arange(10)).tag("a") * pz.nx.wrap(jnp.arange(10)).tag("b"),
      "bar": np.sin(np.arange(150) * 0.1).reshape((15,10))
  })
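The chain of `//` and `%` expressions above splits `id(value)` into groups of three decimal digits. A small hypothetical helper (not part of penzai) does the same thing with `divmod`. You could then pass each group to `pz.ts.integer_digitbox` in a loop.

```python
def split_into_thousands(n, num_low_groups=4):
    """Split n into base-1000 groups, most significant first.

    The top group keeps whatever is left over, matching the expression
    id(value) // 1000000000000 in the example above.
    """
    groups = []
    for _ in range(num_low_groups):
        n, low = divmod(n, 1000)
        groups.append(low)
    groups.append(n)
    return groups[::-1]


print(split_into_thousands(140234567891234))  # [140, 234, 567, 891, 234]
```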

You can even render values using an external plotting library like plotly!

Treescope can inline any type of figure that has a rich HTML representation (specifically, any object that defines the _repr_html_ magic method that Colab’s IPython kernel looks for).

import plotly.express as px
def my_plotly_autovisualizer(
    value: Any,
    path: tuple[Any, ...] | None,
):
  if isinstance(value, (np.ndarray, jax.Array)):
    return pz.ts.IPythonVisualization(
        px.histogram(
            value.flatten(),
            width=400, height=200
        ).update_layout(
            margin=dict(l=20, r=20, t=20, b=20)
        )
    )
with pz.ts.active_autovisualizer.set_scoped(
    my_plotly_autovisualizer
):
  IPython.display.display({
      "foo": pz.nx.wrap(jnp.arange(10)).tag("a") * pz.nx.wrap(jnp.arange(10)).tag("b"),
      "bar": np.sin(np.arange(150) * 0.1).reshape((15,10))
  })
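Plotly figures are only one option. As noted above, any object with a `_repr_html_` method qualifies, so a tiny class of your own should be enough. `ColoredText` is hypothetical. Returning it from an autovisualizer as `pz.ts.IPythonVisualization(ColoredText(...))` should work by that same rule, but treat that as untested here.

```python
import html


class ColoredText:
    """Minimal object with a rich HTML representation (hypothetical example)."""

    def __init__(self, text, color="steelblue"):
        self.text = text
        self.color = color

    def _repr_html_(self):
        # IPython calls this to get HTML; escape the text so it can't inject markup.
        return f'<span style="color: {self.color}">{html.escape(self.text)}</span>'


print(ColoredText("x < y")._repr_html_())
# <span style="color: steelblue">x &lt; y</span>
```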

You can also pass custom visualizers to the %%autovisualize magic to let it handle the set_scoped boilerplate for you:

%%autovisualize my_plotly_autovisualizer
{
  "foo": pz.nx.wrap(jnp.arange(10)).tag("a") * pz.nx.wrap(jnp.arange(10)).tag("b"),
  "bar": np.sin(np.arange(150) * 0.1).reshape((15,10))
}