Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Pretty-Printing With Treescope#

Treescope is Penzai’s interactive, color-coded HTML pretty-printer for IPython notebooks. It shows you the structure of any model or tree of arrays, and is especially suited to inspecting nested data structures.

As its name suggests, treescope is specifically focused on inspecting treelike data, represented as nodes (Python objects) that contain collections of child nodes (other Python objects). This is a good fit for JAX, because JAX’s PyTrees are already tree-shaped and JAX works with immutable data structures. It’s also a close match to the behavior of the ordinary Python repr, which produces a flat, source-code-like view of an object and its contents. (Treescope has limited support for more general Python reference graphs and cyclic references as well, but it always renders them in a tree-like form.)
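To make the "nodes that contain child nodes" picture concrete, here is a minimal sketch in plain Python. It is not how Treescope is implemented; the helper name `render_tree` is hypothetical, and it only handles dicts, lists, and tuples.

```python
def render_tree(node, label="root", depth=0):
  """Returns indented lines describing a nested structure (illustration only)."""
  indent = "  " * depth
  if isinstance(node, dict):
    lines = [f"{indent}{label}: dict"]
    for key, child in node.items():
      lines.extend(render_tree(child, repr(key), depth + 1))
  elif isinstance(node, (list, tuple)):
    lines = [f"{indent}{label}: {type(node).__name__}"]
    for index, child in enumerate(node):
      lines.extend(render_tree(child, str(index), depth + 1))
  else:
    # Leaves are shown with their ordinary repr.
    lines = [f"{indent}{label}: {node!r}"]
  return lines


example = {"weights": [1.0, 2.0], "config": {"name": "mlp", "dims": (64, 128)}}
tree_lines = render_tree(example)
print("\n".join(tree_lines))
```

Each container becomes one line followed by its indented children, which is the same parent/child shape that an interactive viewer can fold and unfold.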

This notebook shows how to use the treescope pretty-printer to look at nested data structures and Penzai models.

Setup#

Let’s start by setting up the environment.

Imports#

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]

from __future__ import annotations

import typing
from typing import Any

import dataclasses

import jax
import jax.numpy as jnp
import numpy as np

import IPython
import penzai
from penzai import pz

Overview of Treescope#

How does treescope work in practice? Here’s an example. Ordinarily, if you try to inspect a nested object containing NDArrays, you get something pretty hard to interpret. For instance, here’s a dictionary of parameters rendered using the default IPython pretty-printer:

from penzai.example_models import simple_mlp

mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([64, 1264, 128, 128, 64]),
    jax.random.key(1)
)

param_dict = {
    param.name: param.value.data_array
    for param in pz.select(mlp).at_instances_of(pz.nn.Parameter).get_sequence()
}
param_dict

Here’s what it looks like if you print it out using the built-in Python repr:

print(repr(param_dict))

And here’s how it looks in treescope, which is defined in penzai.treescope and aliased to pz.ts for easier use:

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
  pz.ts.display(param_dict)

Treescope renders this object as a syntax-highlighted, color-coded structure that can be interactively folded and unfolded.

(Try clicking any ▶ marker to expand a level of the tree, or any ▼ marker to collapse a level.)

In fact, you can even look at the whole model this way, and get a color-coded view of all the parts of your model:

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
  pz.ts.display(mlp)

Let’s register treescope as the default pretty-printer for IPython. This is the recommended way to use treescope in an interactive setting. Treescope is designed to be a drop-in replacement for the ordinary IPython pretty-printer, so you should be able to start using it right away.

pz.ts.register_as_default()

Foldable and unfoldable nested objects#

Treescope lets you expand and collapse any level of your tree, so you can look at the parts you care about. You can collapse or expand any object that would render as multiple lines, even if treescope doesn’t recognize the type:

import dataclasses

@dataclasses.dataclass
class MyDataclass:
  a: Any
  b: Any
  c: Any

class TheZenOfPython:
  def __repr__(self):
    return "<The Zen of Python:\nBeautiful is better than ugly.\nExplicit is better than implicit.\nSimple is better than complex.\nComplex is better than complicated.\nFlat is better than nested.\nSparse is better than dense.\nReadability counts.\nSpecial cases aren't special enough to break the rules.\nAlthough practicality beats purity.\nErrors should never pass silently.\nUnless explicitly silenced.\nIn the face of ambiguity, refuse the temptation to guess.\nThere should be one-- and preferably only one --obvious way to do it.\nAlthough that way may not be obvious at first unless you're Dutch.\nNow is better than never.\nAlthough never is often better than *right* now.\nIf the implementation is hard to explain, it's a bad idea.\nIf the implementation is easy to explain, it may be a good idea.\nNamespaces are one honking great idea -- let's do more of those!>"

[
    MyDataclass('a' * i, 'b' * i, ('cccc\n') * i)
    for i in range(10)
] + [
    MyDataclass(TheZenOfPython(), TheZenOfPython(), TheZenOfPython())
]

Copyable key paths#

Want to pull out an object deep inside a tree? You can click the copy icon next to any subtree to copy a function that accesses that subtree, as a Python source-code lambda expression. You can then paste it into a code cell and call it on the original object to pull out the subtree you wanted.

Try it on one of the parameters of the Penzai model below! (If you run this notebook yourself, you should be able to copy paths with one click. If you are viewing this notebook on Colab without running it, you’ll need to click and then copy the path manually due to Colab’s security restrictions.)

mlp
# for example
(lambda root: root.sublayers[6].sublayers[0].weights.value)(mlp)
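The idea behind a copyable key path can be sketched in plain Python, assuming a tree made only of dicts, lists, and tuples. The helper `iter_key_paths` below is hypothetical and is not part of Treescope; it just shows how a path can be written out as lambda source code and evaluated again.

```python
def iter_key_paths(node, path="root"):
  """Yields (path_expression, leaf) pairs for nested dicts, lists, and tuples."""
  if isinstance(node, dict):
    for key, child in node.items():
      yield from iter_key_paths(child, f"{path}[{key!r}]")
  elif isinstance(node, (list, tuple)):
    for index, child in enumerate(node):
      yield from iter_key_paths(child, f"{path}[{index}]")
  else:
    yield path, node


tree = {"layers": [{"weights": 1.5, "bias": 0.5}, {"weights": 2.5}]}
accessors = {f"(lambda root: {path})": leaf for path, leaf in iter_key_paths(tree)}

# Each accessor is a source-code string. Evaluating it and calling the result
# on the original tree recovers the same leaf.
for source, leaf in accessors.items():
  assert eval(source)(tree) == leaf

print(list(accessors)[0])
```

Because the accessor takes the root as an argument, the same pasted expression can be reused on any other tree with the same structure.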

Structural color-coding for pz.Struct and pz.Layer#

Penzai’s base class for PyTree dataclasses includes customizable hooks for treescope rendering. In particular, neural network layers are color-coded by type, so you can see at a glance which parts of your model object are the same type.

By default, any block subclass that defines __call__ gets a randomly-selected color based on the type name:

@pz.pytree_dataclass
class MyShiftLayer(pz.Layer):
  shift: float

  def __call__(self, value):
    return value + self.shift

@pz.pytree_dataclass
class MyNoOpLayer(pz.Layer):
  def __call__(self, value):
    return value

pz.nn.Sequential([
    MyShiftLayer(1),
    MyShiftLayer(2),
    MyNoOpLayer(),
    MyShiftLayer(3),
])

But you can customize the color by overriding treescope_color:

@pz.pytree_dataclass
class MyFancyShiftLayer(pz.Layer):
  name: str = dataclasses.field(metadata={"pytree_node": False})
  shift: float

  def __call__(self, value):
    return value + self.shift

  def treescope_color(self):
    return pz.color_from_string(str(self.name))

pz.nn.Sequential([
    MyFancyShiftLayer("foo", 1),
    MyFancyShiftLayer("bar", 2),
    MyNoOpLayer(),
    MyFancyShiftLayer("foo", 3),
])

@pz.pytree_dataclass
class MyObject(pz.Struct):
  a: float
  b: float

  def treescope_color(self):
    return "cyan"

MyObject(a=1.0, b=2.0)

This is used throughout Penzai to provide color-coded representations that emphasize the behavior of individual parts of complex models.

(There are also some Penzai-specific renderers for understanding dataflow in Penzai models in particular, which are described in the other tutorial notebooks.)
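One way to picture a color that is chosen "based on the type name" is to hash the name into a hue, so the same name always gets the same color. The sketch below uses only the standard library; `hue_color_from_string` is a hypothetical helper and is not claimed to match what `pz.color_from_string` does internally.

```python
import colorsys
import hashlib


def hue_color_from_string(key: str) -> str:
  """Maps a string to a stable CSS color (illustrative, not Penzai's scheme)."""
  digest = hashlib.sha256(key.encode("utf-8")).digest()
  # Use the first two bytes of the digest to pick a hue in [0, 1).
  hue = int.from_bytes(digest[:2], "big") / 65536
  red, green, blue = colorsys.hls_to_rgb(hue, 0.7, 0.6)
  return "#{:02x}{:02x}{:02x}".format(
      round(red * 255), round(green * 255), round(blue * 255)
  )


# The same name always yields the same color; different names usually differ.
print(hue_color_from_string("MyShiftLayer"))
print(hue_color_from_string("MyNoOpLayer"))
```

Hashing rather than drawing a random number keeps colors stable across notebook restarts, which is what makes them useful for recognizing a type at a glance.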

Copyable code and roundtrip mode#

Documentation for Python’s repr says:

For many types, this function makes an attempt to return a string that would yield an object with the same value when passed to eval(); otherwise, the representation is a string enclosed in angle brackets that contains the name of the type of the object together with additional information often including the name and address of the object.

Treescope follows this principle for everything it renders. Almost all of the output of treescope is valid Python syntax, and any extra annotations are either hidden from selection or represented as Python comments.

For instance, we’ll show again the example MLP model:

mlp

This printout is for the most part executable Python code, specifying the types and fields for each value. However, the individual types may not be in scope, and some values (like jax.Array) cannot be directly rebuilt from their repr.

You can fix this by running treescope in “roundtrip mode”, which

  • adds qualified names to all types

  • disables non-roundtrippable summaries for special-cased types (like NamedArray)

  • wraps any non-roundtrippable type (like jax.Array or any type treescope doesn’t know how to render) with a weak reference, enabling it to be copy-pasted within the current interpreter session (as long as the original object isn’t garbage collected)

To toggle roundtrip mode, click on any output of treescope and press the “r” key. (Try it above!) Alternatively, pass roundtrip_mode=True to the renderer:

pz.ts.display(mlp, roundtrip_mode=True)

In roundtrip mode, as long as you’ve imported the necessary top-level modules, you should be able to select any part of the output, copy it, and paste it into another cell to rebuild an equivalent subtree to the one you copied.
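The two halves of this idea can be illustrated with the standard library alone. The sketch below first shows a value whose repr evaluates back to an equal value, then a type whose default repr does not, and finally a weak-reference lookup that lets generated source code refer back to a live object. The `registry` and `lookup` names are hypothetical and are not Treescope's API.

```python
import weakref

# Builtin containers and literals roundtrip through repr and eval.
value = {"shape": (2, 3), "names": ["a", "b"], "scale": 0.5}
assert eval(repr(value)) == value


class Opaque:
  """A type with the default angle-bracket repr, which is not valid Python."""


obj = Opaque()
try:
  eval(repr(obj))
  roundtrips = True
except SyntaxError:
  roundtrips = False
print(roundtrips)  # False: the angle-bracket repr cannot be evaluated.

# Sketch of the weak-reference idea: keep a registry of weak references keyed
# by id, and emit source code that looks the object up again.
registry = weakref.WeakValueDictionary()


def lookup(key):
  return registry[key]


registry[id(obj)] = obj
source = f"lookup({id(obj)})"
assert eval(source) is obj  # Works only while `obj` is still alive.
```

Using weak references means the registry does not keep objects alive by itself, so the pasted expression stops working once the original object has been garbage collected.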

Function reflection and canonical aliases#

Treescope has support for rendering useful information about functions and closures. The repr for functions isn’t always very helpful, especially if wrapped by JAX:

repr(jax.nn.relu)

Treescope tries to figure out where functions, function-like objects, and other constants are defined, and uses that to summarize them when collapsed. This works for ordinary functions defined anywhere and also for function-like objects in the JAX public API (see well_known_aliases.py):

jax.nn.relu

For ordinary functions, it can even identify the file where the function was defined:

jnp.sum

This works even for locally-defined notebook functions:

def my_function():
  print("hello world!")

my_function
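Python already exposes most of the information needed for this kind of summary. As a rough sketch (not Treescope's actual lookup logic, which also handles canonical aliases), the hypothetical helper `describe_callable` below reads a function's module, qualified name, and source file using the standard `inspect` module.

```python
import inspect
import json


def describe_callable(fn):
  """Returns a 'module.qualname' summary plus the defining file, if any."""
  module = getattr(fn, "__module__", None)
  qualname = getattr(fn, "__qualname__", None)
  try:
    source_file = inspect.getsourcefile(fn)
  except TypeError:
    # Builtins implemented in C have no Python source file.
    source_file = None
  return f"{module}.{qualname}", source_file


summary, source_file = describe_callable(json.dumps)
print(summary)  # json.dumps
```

The module and qualified name give a short label for the collapsed view, while the source file is only available for functions written in Python.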

Embedded NDArray visualizer (arrayviz) and customizable figure inlining#

Treescope includes a custom interactive NDArray visualizer designed to visualize the elements of high-dimensional arrays:

arr = (
    np.linspace(-10, 10, 20)
    * np.linspace(-10, 10, 15)[:, np.newaxis]
    * np.linspace(-1, 1, 5)[:, np.newaxis, np.newaxis]
)
pz.ts.render_array(arr)

It’s integrated with the rest of treescope, making it possible to directly visualize entire nested containers of arrays at once. (Large arrays get automatically truncated along one or more axes to keep the visualization relatively small.)

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
  pz.ts.display((lambda root: root.sublayers[6].sublayers[0].weights.value)(mlp))

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
  # Visualizations with many items get collapsed by default.
  # Click to expand them.
  pz.ts.display(mlp)

If you want more control over how arrays and other objects are visualized, you can write your own visualization function and configure treescope to use it:

import plotly.express as px

def visualize_with_histograms(value, path):
  if isinstance(value, (np.ndarray, jax.Array)):
    # You can use any rich display object, for instance a plotly figure:
    return pz.ts.IPythonVisualization(
        px.histogram(
            value.flatten(),
            width=400, height=200
        ).update_layout(
            margin=dict(l=20, r=20, t=20, b=20)
        )
    )

with pz.ts.active_autovisualizer.set_scoped(visualize_with_histograms):
  pz.ts.display((lambda root: root.sublayers[6].sublayers[0].weights.value)(mlp))

You can use the %%autovisualize IPython magic to enable automatic visualization in a cell:

pz.ts.register_autovisualize_magic()

%%autovisualize
(lambda root: root.sublayers[6].sublayers[0].sublayers[0].weights.value)(mlp)

See the separate array visualization tutorial for more info on how to visualize arrays and customize layouts!

Where you can use treescope#

In IPython / Colab#

Treescope works great in IPython and Colab notebooks, and is designed as a drop-in replacement for the IPython pretty-printer.

We’ve already done it above, but you can configure treescope as the default IPython formatter by calling

pz.ts.register_as_default()

or manually display specific objects with

pz.ts.display(["some object"])

There’s also a helper function to show rich objects with syntax similar to Python’s print:

pz.show("A value:", ["some object"])

If you register treescope as the default IPython formatter, you can also just do

["some object"]

In the IPython / Colab debugger#

It’s also possible to use treescope inside the Colab debugger or IPython’s ipdb. This isn’t specific to treescope, and it may be somewhat fragile, but you can monkey-patch pdb so that it uses IPython.display.display to display values that are output to the console. If you’ve registered treescope with IPython, that means you get access to all of its formatting while inspecting stack frames!

There’s an experimental wrapper that sets this up for you:

from penzai.toolshed import patch_ipdb
patch_ipdb.patch_ipdb()

Try running this and dropping into the debugger:

def my_function(some_input):
  assert some_input is None

# Uncomment me:
# my_function({"a": 1, "b": np.arange(1000)})

In ordinary Python for offline viewing#

Treescope can render directly to static HTML, without requiring any dynamic communication between the Python kernel and the HTML renderer. This means you can directly save the output of a treescope rendering to an HTML file, and open it later to view whatever was formatted.

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
  contents = pz.ts.render_to_html(mlp, roundtrip_mode=True)

with open("/tmp/treescope_output.html", "w") as f:
  f.write(contents)

# Uncomment to download the file:
# import google.colab.files
# google.colab.files.download("/tmp/treescope_output.html")
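To see why a saved file can stay interactive without a running kernel, note that folding can live entirely in the HTML itself. The sketch below is a deliberately simplified illustration using the standard `<details>` element; it is not how Treescope generates its output, and `to_static_html` is a hypothetical helper.

```python
import html


def to_static_html(node, label="root"):
  """Renders nested dicts and lists as collapsible <details> elements."""
  if isinstance(node, dict):
    items = [(repr(key), child) for key, child in node.items()]
  elif isinstance(node, (list, tuple)):
    items = [(str(index), child) for index, child in enumerate(node)]
  else:
    # Leaves are escaped so that their repr cannot inject markup.
    return f"<div>{html.escape(label)}: {html.escape(repr(node))}</div>"
  body = "".join(to_static_html(child, name) for name, child in items)
  summary = f"{html.escape(label)}: {type(node).__name__}"
  return f"<details open><summary>{summary}</summary>{body}</details>"


page = to_static_html({"a": [1, 2], "b": "<tag>"})
```

The resulting string can be written to a file and opened in any browser, and each level can be collapsed or expanded with no Python process involved.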

Things treescope can render#

Treescope has support for a large number of common Python objects.

Dicts, lists, tuples, and sets#

[
    [(), (1,), (1, 2, 3)],
    {"foo": "bar", "baz": "qux"},
    {(1,2,3):(4,5,6), (7,8,9):(10,11,12)},
    {"a", "b", "c", "d"}
]

Builtins and literals#

(with special handling for multiline strings)

[
    [1, 2, 3, 4],
    ["a", "b", "c", "d"],
    [True, False, None, NotImplemented, Ellipsis],
    ["a\n  multiline\n    string"]
]

Dataclasses and namedtuples#

class Foo(typing.NamedTuple):
  a: int
  b: str

Foo(a=1, b="bar")

@dataclasses.dataclass(frozen=True)
class Bar:
  c: str
  d: int
  some_list: list = dataclasses.field(default_factory=list)

IPython.display.display(Bar(c="bar", d=2))

In roundtrip mode, treescope will even help you rebuild dataclasses with weird __init__ methods:

@dataclasses.dataclass
class WeirdInitClass:
  foo: int

  def __init__(self, half_foo: int):
    self.foo = 2 * half_foo

# This shows as WeirdInitClass(foo=4):
pz.ts.display(WeirdInitClass(2))

# But in roundtrip mode (explicit or after pressing `r`), it shows as
#   pz.dataclass_from_attributes(WeirdInitClass, foo=4)
# which bypasses __init__ and rebuilds the dataclass's attributes directly,
# since __init__ doesn't take `foo` as an argument.
pz.ts.display(WeirdInitClass(2), roundtrip_mode=True)
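Rebuilding a dataclass from its attributes, rather than through `__init__`, can be sketched with the standard library. The helper `rebuild_from_attributes` below is hypothetical and only illustrates the general technique; it is not claimed to be how `pz.dataclass_from_attributes` is implemented.

```python
import dataclasses


@dataclasses.dataclass
class WeirdInit:
  foo: int

  def __init__(self, half_foo: int):
    self.foo = 2 * half_foo


def rebuild_from_attributes(cls, **attributes):
  """Builds a dataclass instance without calling __init__ (hypothetical helper)."""
  field_names = {field.name for field in dataclasses.fields(cls)}
  if set(attributes) != field_names:
    raise ValueError(f"Expected fields {sorted(field_names)}")
  # Allocate the instance directly, skipping the custom __init__.
  instance = object.__new__(cls)
  for name, value in attributes.items():
    # object.__setattr__ also works for frozen dataclasses.
    object.__setattr__(instance, name, value)
  return instance


original = WeirdInit(2)
rebuilt = rebuild_from_attributes(WeirdInit, foo=original.foo)
assert rebuilt == original  # Both have foo=4.
```

Passing the stored field values directly avoids having to invert whatever transformation the custom `__init__` applies to its arguments.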

NDArrays and NamedArrays#

Treescope summarizes the shape, mean, standard deviation, bounds, and number of special values in any array. It also supports automatic visualization (as described above).

[
    jnp.arange(1000),
    np.array([[np.nan] * 100, [0] * 50 + [1] * 50]),
    pz.nx.arange("foo", 10) * pz.nx.arange("bar", 15),
]

%%autovisualize
[
    jnp.arange(1000),
    np.array([[np.nan] * 100, [0] * 50 + [1] * 50]),
    pz.nx.arange("foo", 10) * pz.nx.arange("bar", 15),
]

When used in IPython, Treescope will try to render the tree structure first and then insert array visualizations later. This can make visualization faster and can sometimes let you see the shape of JAX arrays before JAX has finished computing their values.

pz.Structs, layers, and models#

pz.Struct and pz.Layer types are dataclasses, and they render similarly to ordinary dataclasses, but with a few extra features:

  • Layers are color-coded by type, and other blocks can opt in to color-coding by defining treescope_color.

  • Layers also print a summary of their input and output structures, if known.

  • Complex models built using penzai.data_effects have extra annotations for tracking the effects and their handlers.

initialized_mlp = pz.nn.initialize_parameters(
  simple_mlp.MLP.from_config(
    feature_sizes=[32, 64, 64, 16]
  ), jax.random.key(0)
)
initialized_mlp

Functions#

(As discussed in the “Function reflection and canonical aliases” section above.)

[
    jnp.sum,
    dataclasses.dataclass,
    lambda x: x + 2,
    jax.vmap(lambda x: x),
]

Arbitrary PyTree types#

Treescope uses a fallback rendering strategy to show the children of any PyTree type registered with JAX, even if it isn’t usually supported by treescope.

jax.tree_util.Partial(lambda x, y, z: x + y, 10, y=100)

Partial support: Repeated Python object references#

Treescope will warn you if it sees multiple references to the same mutable object, since that can cause unexpected behavior. (In this case, copying the output won’t copy the shared reference structure.)

my_shared_list = []

{
    "foo": my_shared_list,
    "bar": my_shared_list,
    "baz": [1, 2, my_shared_list]
}
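Detecting this situation comes down to noticing that the same mutable object is reached more than once while walking the tree. A minimal sketch of that check, assuming only builtin containers (the helper `find_shared_mutables` is hypothetical and is not Treescope's implementation):

```python
def find_shared_mutables(root):
  """Returns mutable containers reachable from `root` more than once."""
  seen = {}
  shared = []

  def visit(node):
    if isinstance(node, (list, dict, set)):
      if id(node) in seen:
        if not any(node is other for other in shared):
          shared.append(node)
        return  # Do not descend again; this also guards against cycles.
      seen[id(node)] = node
    if isinstance(node, dict):
      children = node.values()
    elif isinstance(node, (list, tuple, set)):
      children = node
    else:
      children = ()
    for child in children:
      visit(child)

  visit(root)
  return shared


shared_list = []
tree = {"foo": shared_list, "bar": shared_list, "baz": [1, 2, shared_list]}
duplicates = find_shared_mutables(tree)
assert len(duplicates) == 1 and duplicates[0] is shared_list
```

The check compares object identity rather than equality, so two separate but equal lists are not reported; only a genuinely shared object is.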