Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Selectors and Selections

Penzai is designed to make targeted modifications to models and parameter trees easy. For this purpose, Penzai provides a powerful “selection” system that lets you identify, visualize, and modify arbitrary JAX PyTrees. In short, pz.select enables .at[...].set(...)-style modification for arbitrary PyTree types.

This notebook describes the basics of Penzai’s selection object and shows how you can use it to make a variety of modifications.

Setup

We’ll start by setting up the environment.

Imports

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations
import typing
import traceback

import jax
import jax.numpy as jnp
import penzai
from penzai import pz
from penzai.example_models import simple_mlp

Setting up Penzai

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment.

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()

What are selectors?

If you’re familiar with JAX, you probably know that JAX arrays are immutable. This means that you can’t assign directly to elements of a JAX array:

array = jnp.zeros((10, 10))
array
try:
  array[1, 2] = 42
except TypeError:
  import traceback
  traceback.print_exc()

Instead, to modify a JAX array you can use the special .at property, which returns a modified copy of the original array with changes at the location you specified:

modified_array = array.at[1, 2].set(42)
modified_array
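Besides .set, the same .at[...] syntax supports other functional updates such as .add, as well as .get for indexed reads. A short sketch with an arbitrary 3x3 array:

```python
import jax.numpy as jnp

array = jnp.zeros((3, 3))

# Each .at[...] method returns a new array; nothing is modified in place.
updated = array.at[1, 2].set(42.0)
incremented = updated.at[1, 2].add(1.0)
row = updated.at[1].get()  # same values as updated[1]

# The original array still holds zeros.
print(array[1, 2], updated[1, 2], incremented[1, 2])
```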

JAX also introduces the concept of PyTrees: nested containers of data that JAX knows how to traverse. Most JAX transformations are designed to work with functions that take PyTrees as input and return PyTrees as output.

Although JAX treats dictionaries and lists as PyTrees, it generally assumes that those trees will not be mutated, and expects you to write code in a functional style that returns new values instead of mutating its arguments in place.
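A quick way to see this functional style in practice is jax.tree_util.tree_map, which builds a new tree instead of editing the old one. The tree below is an arbitrary example:

```python
import jax
import jax.numpy as jnp

tree = {"a": 1, "b": [jnp.arange(3), 2.0]}

# tree_map returns a new tree with the same structure; `tree` is unchanged.
scaled = jax.tree_util.tree_map(lambda leaf: leaf * 10, tree)

# Flattening separates the leaves from the (hashable) tree structure,
# and unflattening puts them back together.
leaves, treedef = jax.tree_util.tree_flatten(tree)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```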

pz.select enables .at[...].set(...)-style modification for arbitrary PyTrees, making it possible to perform complex path-based and type-driven modifications to the objects you’re already using:

my_nested_object = {
    "a": 1,
    "b": jnp.arange(10),
    "c": [
        {"value": jnp.arange(12)},
        {"value": jnp.zeros([7])},
        {"value": 3},
    ]
}
(
    pz.select(my_nested_object)
      .at(lambda root: root["c"])
      .at_instances_of(jax.typing.ArrayLike)
      .apply(lambda value: value + 100)
)
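For comparison, here is roughly what the selection above saves you from writing by hand in plain JAX. This is only a sketch: it copies the root dictionary and rebuilds the "c" branch with jax.tree_util.tree_map, which matches the .at_instances_of(jax.typing.ArrayLike) step here only because every leaf under "c" happens to be array-like.

```python
import jax
import jax.numpy as jnp

my_nested_object = {
    "a": 1,
    "b": jnp.arange(10),
    "c": [
        {"value": jnp.arange(12)},
        {"value": jnp.zeros([7])},
        {"value": 3},
    ],
}

# Shallow-copy the root, then replace only the "c" branch.
updated = dict(my_nested_object)
updated["c"] = jax.tree_util.tree_map(
    lambda value: value + 100, my_nested_object["c"]
)
```

Every extra level of nesting or type-based condition makes this hand-written version longer, which is the bookkeeping pz.select handles for you.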

Selectors are first-class objects, defined in penzai.core.selectors but aliased to pz.select and pz.Selection. Internally, they are ordinary PyTrees:

my_selection = (
    pz.select(my_nested_object)
      .at(lambda root: root["c"])
      .at_instances_of(jax.typing.ArrayLike)
)

pz.show(my_selection)

By default, though, selections are rendered in a way that emphasizes what you’ve selected, so you can tell at a glance that you’re making the changes you expect to:

my_selection

(This fancy rendering mode is triggered when you display a single selection after enabling pz.ts.register_as_default(), or when you call .show_selection() on a selection object.)

Building selections

The Penzai selector API is designed around method chaining. You generally start by creating a trivial selection containing only a single object, the root object, using pz.select:

pz.select(my_nested_object)

You can also use the .select() method if you know your root object is a subclass of pz.Struct:

@pz.pytree_dataclass
class MyStruct(pz.Struct):
  foo: typing.Any

MyStruct(4).select()

You then call methods on the selection to refine it to a subset of the currently selected values. These methods all operate relative to the currently selected object, so if you chain them, each method’s output is restricted to the values selected by the previous method.

Selecting with a function

If you want to select a specific part of a tree, you can use selection.at(...). The .at method takes as input a function that extracts the part you want to select, and returns a new selection that selects the extracted part:

pz.select(my_nested_object).at(lambda root: root["c"])
pz.select(my_nested_object).at(lambda root: root["c"][1]["value"])
pz.select(my_nested_object).at(lambda root: root["b"])

Later .at calls in a chain are relative to the currently selected part:

(
    pz.select(my_nested_object)
    .at(lambda root: (root["c"][0], root["c"][1]))
    .at(lambda subtree: subtree["value"])
)

There are a few restrictions on the function that is used to select a subtree. In particular:

  • your function can’t depend on the actual value passed in, only on its PyTree structure,

  • your function should return a single node or tuple of nodes from the PyTree.

Internally, Selection.at is implemented using equinox.tree_at, which takes care of most of the heavy lifting.

Selecting by type

If you want to select all subtrees with a particular type, you can use Selection.at_instances_of:

pz.select(my_nested_object).at_instances_of(int)

This selects any subtree for which isinstance(subtree, requested_type) evaluates to True.

Note that selections cannot be nested, so this only selects the outermost value with the given type:

pz.select(my_nested_object).at_instances_of(dict)
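The “outermost match only” behavior resembles JAX’s own traversal when you pass an is_leaf predicate, because JAX stops descending as soon as is_leaf returns True. The helper below, outermost_instances, is a hypothetical illustration of that idea built on jax.tree_util.tree_flatten_with_path, not how Penzai implements at_instances_of:

```python
import jax


def outermost_instances(tree, cls):
    """Hypothetical helper: (keypath, subtree) pairs for the outermost
    subtrees of `tree` that are instances of `cls`."""
    # is_leaf makes JAX treat every match as a leaf, so matches nested
    # inside an already-matched subtree are never visited.
    flat, _ = jax.tree_util.tree_flatten_with_path(
        tree, is_leaf=lambda node: isinstance(node, cls)
    )
    # Ordinary leaves (like the int 2 below) are also returned; drop them.
    return [(path, node) for path, node in flat if isinstance(node, cls)]


tree = [{"x": {"y": 1}}, 2, {"z": 3}]
matches = outermost_instances(tree, dict)
print([jax.tree_util.keystr(path) for path, _ in matches])
```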

Selecting by condition

More generally, you can select all subtrees for which a predicate function returns True:

(
    pz.select(my_nested_object)
    .at_subtrees_where(
        lambda subtree: isinstance(subtree, jax.Array) and subtree.size <= 10)
)

In fact, at_instances_of is a thin wrapper around at_subtrees_where. There’s also a convenience method, at_equal_to, for selecting values equal to a given value:

(
    pz.select(my_nested_object)
    .at_equal_to(1)
)
(
    pz.select({"foo": "foo", "bar": [1, 2, 3]}).at_equal_to([1, 2, 3])
)
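When the predicate only ever matches PyTree leaves, as the size filter above does, you can reproduce the same matches with jax.tree_util.tree_flatten_with_path. The sketch below lists the keypaths of the matching leaves; unlike at_subtrees_where it never considers non-leaf subtrees, and is_small_array is just an illustrative name:

```python
import jax
import jax.numpy as jnp

my_nested_object = {
    "a": 1,
    "b": jnp.arange(10),
    "c": [
        {"value": jnp.arange(12)},
        {"value": jnp.zeros([7])},
        {"value": 3},
    ],
}


def is_small_array(leaf):
    # Same condition as the at_subtrees_where example above.
    return isinstance(leaf, jax.Array) and leaf.size <= 10


flat, _ = jax.tree_util.tree_flatten_with_path(my_nested_object)
small_paths = [
    jax.tree_util.keystr(path) for path, leaf in flat if is_small_array(leaf)
]
print(small_paths)  # ["['b']", "['c'][1]['value']"]
```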

Selecting by JAX keypath

PyTree nodes in JAX are associated with key paths identifying their location in the tree:

jax.tree_util.tree_map_with_path(lambda key, node: key, my_nested_object)

You can directly select nodes based on their JAX keypath:

pz.select(my_nested_object).at_keypaths([
    (jax.tree_util.DictKey(key='c'), jax.tree_util.SequenceKey(idx=2), jax.tree_util.DictKey(key='value')),
    (jax.tree_util.DictKey(key='c'), jax.tree_util.SequenceKey(idx=0)),
])

(You might have noticed that the fancy rendering for Selection objects is actually written in terms of at_keypaths, and exposes the keypaths for the currently selected nodes if you expand the last line!)
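A keypath is just a tuple of key objects, so you can also follow one by hand. The get_at_keypath helper below is hypothetical and only handles the key types JAX uses for dicts, sequences, and attributes (DictKey, SequenceKey, GetAttrKey):

```python
import jax
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def get_at_keypath(tree, path):
    """Hypothetical helper: follow a JAX keypath down from `tree`."""
    node = tree
    for key in path:
        if isinstance(key, DictKey):
            node = node[key.key]
        elif isinstance(key, SequenceKey):
            node = node[key.idx]
        elif isinstance(key, GetAttrKey):
            node = getattr(node, key.name)
        else:
            raise TypeError(f"Unsupported key type: {key!r}")
    return node


tree = {"c": [{"value": 1.5}, {"value": 2.5}, {"value": 3}]}
path = (DictKey(key="c"), SequenceKey(idx=2), DictKey(key="value"))
print(get_at_keypath(tree, path))  # 3

# Every keypath reported by JAX leads back to its own leaf.
flat, _ = jax.tree_util.tree_flatten_with_path(tree)
```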

Selecting based on PyTree structure

You can also build selections based on PyTree children or PyTree leaves:

pz.select(my_nested_object).at_children()
pz.select(my_nested_object).at_pytree_leaves()

Note that some PyTree nodes, like None or the empty tuple (), have no children and therefore contain no leaves, so jax.tree_util.tree_map never visits them and .at_pytree_leaves won’t select them:

pz.select([1, 2, (), 3, None, 4]).at_pytree_leaves()

If you want to select these too, you can use .at_childless:

pz.select([1, 2, (), 3, None, 4]).at_childless()
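You can confirm the underlying JAX behavior directly: None and () are PyTree nodes with no children, so they contribute no leaves when the list is flattened.

```python
import jax

mixed = [1, 2, (), 3, None, 4]

# Only the four integers are leaves; () and None disappear from the leaf list.
print(jax.tree_util.tree_leaves(mixed))  # [1, 2, 3, 4]
print(jax.tree_util.tree_structure(mixed).num_leaves)  # 4
```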

Filtering selections

Given an existing set of selected nodes, you can choose to filter down to only a subset that match a criterion:

pz.select(list(range(20))).at_instances_of(int).where(lambda x: x % 2 == 0)

You can also choose to keep only the nth node that you’ve selected (e.g. the third integer in the collection), zero-indexed:

pz.select([1, "a", 2, "b", 3, "c", 4, "d"]).at_instances_of(int).pick_nth_selected(2)
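As a mental model (not how Penzai works internally), these chains behave like filtering the selected values in order and then indexing, with pick_nth_selected counting from zero. In plain Python, for the same list:

```python
items = [1, "a", 2, "b", 3, "c", 4, "d"]

ints = [x for x in items if isinstance(x, int)]  # like .at_instances_of(int)
evens = [x for x in ints if x % 2 == 0]          # like .where(lambda x: x % 2 == 0)
third = ints[2]                                  # like .pick_nth_selected(2)

print(ints, evens, third)  # [1, 2, 3, 4] [2, 4] 3
```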

Inverting selections

Selections can be inverted. This produces a new Selection that selects the PyTree nodes that do NOT contain any of the originally selected nodes:

selection = pz.select([1, "a", 2, "b", 3, "c", 4, "d"]).at_instances_of(int).pick_nth_selected(2)

print("Original:")
selection.show_selection()
print("Inverted:")
selection.invert().show_selection()
selection = pz.select(my_nested_object).at_keypaths([
    (jax.tree_util.DictKey(key='c'), jax.tree_util.SequenceKey(idx=2), jax.tree_util.DictKey(key='value')),
    (jax.tree_util.DictKey(key='c'), jax.tree_util.SequenceKey(idx=0)),
])

print("Original:")
selection.show_selection()
print("Inverted:")
selection.invert().show_selection()

Building more complex selections

For highly dynamic selections, you can drop down to lower-level manipulation.

Selections provide a .refine method, which runs a function on each of the currently selected values, collects the Selection that the function returns for each one, and combines them into their disjoint union. This can be useful for dynamic, input-dependent branching:

# Select field "a" of Foo and field "b" of Bar

@pz.pytree_dataclass
class Foo(pz.Struct):
  a: typing.Any
  b: typing.Any

@pz.pytree_dataclass
class Bar(pz.Struct):
  a: typing.Any
  b: typing.Any

def refine_fn(value):
  if isinstance(value, Foo):
    return value.select().at(lambda x: x.a)
  else:
    return value.select().at(lambda x: x.b)
pz.select([
    Foo(1, 2),
    Bar(3, 4),
    Foo(5, 6),
    Bar(7, 8),
    None,
]).at_instances_of((Foo, Bar)).refine(refine_fn)
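To check which values the refined selection should contain, the plain-Python sketch below applies the same branching logic directly to values. It uses ordinary dataclasses in place of pz.Struct, and returns values rather than selections:

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Foo:
    a: Any
    b: Any


@dataclasses.dataclass
class Bar:
    a: Any
    b: Any


def pick(value):
    # Mirrors refine_fn above, but returns the field value itself.
    return value.a if isinstance(value, Foo) else value.b


items = [Foo(1, 2), Bar(3, 4), Foo(5, 6), Bar(7, 8), None]
# The isinstance filter plays the role of .at_instances_of((Foo, Bar)).
picked = [pick(v) for v in items if isinstance(v, (Foo, Bar))]
print(picked)  # [1, 4, 5, 8]
```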

Inspecting selections and retrieving values

Once you have a selection, there are various ways to inspect it and its contents.

Looking at selections

You can view a pretty-printed version of the selection using .show_selection() (the same pretty-printing we’ve been using throughout this notebook):

pz.select(my_nested_object).at_children().show_selection()

You can also view the selected values without the selection annotations using .show_value(). This is similar to just rendering the root object with treescope, but it automatically expands so that the nodes you selected are visible, and collapses all other nodes.

stuff = (
    [[["haystack"] * 5] * 5] * 5
    + [[["haystack"] * 5] * 5 + [["haystack", "needle", "haystack", "haystack", "haystack"]]]
    + [[["haystack"] * 5] * 5] * 5
)

pz.select(stuff).at_equal_to("needle").show_value()
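The keypath of the needle can also be found with plain JAX, since strings are PyTree leaves. This sketch flattens the nested lists together with their keypaths and keeps the ones whose leaf equals "needle":

```python
import jax

stuff = (
    [[["haystack"] * 5] * 5] * 5
    + [[["haystack"] * 5] * 5 + [["haystack", "needle", "haystack", "haystack", "haystack"]]]
    + [[["haystack"] * 5] * 5] * 5
)

flat, _ = jax.tree_util.tree_flatten_with_path(stuff)
needle_paths = [jax.tree_util.keystr(path) for path, leaf in flat if leaf == "needle"]
print(needle_paths)  # ['[5][5][1]']
```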

Counting selected objects

Selections have count and is_empty methods for inspecting their size:

pz.select([1, 2]).at_instances_of(str).count()
pz.select([1, 2]).at_instances_of(str).is_empty()
pz.select([1, 2]).at_instances_of(int).count()
pz.select([1, 2]).at_instances_of(int).is_empty()

If you already know how many objects should be selected, you can also add an assertion in the middle of a chain:

pz.select([1, 2]).at_instances_of(int).assert_count_is(2).apply(lambda x: x + 1)

Retrieving values

You can extract all of the selected values using .get_sequence():

selection = (
    pz.select(my_nested_object)
    .at_subtrees_where(
        lambda subtree: isinstance(subtree, jax.Array) and subtree.size <= 10)
)

selection.get_sequence()

Or get them in a dictionary form with get_by_path() (equivalent to just accessing the .selected_by_path attribute):

selection.get_by_path()

If you know there’s exactly one value, you can just call .get():

pz.select(my_nested_object).at(lambda root: root["c"][0]).get()

You can also get the selected objects in a dictionary form by accessing the selected_by_path attribute, which is how selections store their selected nodes internally:

(
    pz.select(my_nested_object)
    .at_subtrees_where(
        lambda subtree: isinstance(subtree, jax.Array) and subtree.size <= 10
    )
).selected_by_path

Partitioning PyTrees

You can use .partition() to split a selected object into two parts, one containing only the selected subtrees, and one containing everything else:

selected, rest = (
    pz.select(my_nested_object)
    .at_subtrees_where(
        lambda subtree: isinstance(subtree, jax.Array) and subtree.size <= 10)
).partition()
selected
rest

You can then process the two parts independently, and then recombine them into a single object using pz.combine:

pz.combine(
    jax.tree_util.tree_map(lambda x: x + 100, selected),
    rest,
)

Partitioning and combining are inspired by equinox.partition and equinox.combine. If you’re already familiar with those, the main differences are:

  • You generally don’t need to use partition and combine when running JAX transformations like jax.jit. By convention, Penzai models store all of their static metadata in dataclasses.field(metadata={"pytree_node": False}) fields which are not part of the PyTree traversal, which means partitioning isn’t as strictly necessary as it is in equinox workflows. Instead, partitioning is primarily useful if you want to apply different logic to two sets of leaves, e.g. taking a gradient only with respect to a specific subset of parameters, or defining different shardings for different array subsets.

  • Penzai uses a specific sentinel NotInThisPartition() to identify removed nodes, rather than None.

  • Penzai partitions are built to support manipulation at the subtree level, rather than at the leaf level. It’s OK to build and combine partitions even when neither partition is a strict PyTree prefix of the other, as long as the overlapping parts don’t conflict.

  • Penzai partitions are always created with selectors, rather than being built by a standalone function.

Modifying selected values

An important feature of selections is that they allow you to perform detailed modifications to (copies of) large trees. Selections expose a number of methods for this purpose.

Replacing selected values

The simplest modification you can make is to replace each selected subtree with another subtree or value:

pz.select(my_nested_object).at_instances_of(jax.Array).set("hello world!")

You can optionally provide different values for each selected node:

pz.select(my_nested_object).at_instances_of(jax.Array).set_sequence([f"replacement {i}" for i in range(3)])

Or provide a mapping based on the selected keypaths:

pz.select(my_nested_object).at_instances_of(jax.Array).set_by_path({
    (jax.tree_util.DictKey(key='b'),): "A",
    (jax.tree_util.DictKey(key='c'), jax.tree_util.SequenceKey(idx=0), jax.tree_util.DictKey(key='value')): "B",
    (jax.tree_util.DictKey(key='c'), jax.tree_util.SequenceKey(idx=1), jax.tree_util.DictKey(key='value')): "C",
})
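set_sequence pairs replacement values with selected nodes in order. For selections made up of leaves, one way to picture this, assuming the selection order matches JAX’s flattening order (dict keys sorted, sequences front to back), is the set_leaf_sequence helper below. It is a hypothetical sketch, not Penzai’s implementation:

```python
import jax
import jax.numpy as jnp


def set_leaf_sequence(tree, predicate, replacements):
    """Hypothetical helper: replace matching leaves, in flattening order."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    replacements = list(replacements)
    targets = [i for i, leaf in enumerate(leaves) if predicate(leaf)]
    if len(targets) != len(replacements):
        raise ValueError(
            f"Expected {len(targets)} replacements, got {len(replacements)}"
        )
    for i, new_value in zip(targets, replacements):
        leaves[i] = new_value
    return jax.tree_util.tree_unflatten(treedef, leaves)


my_nested_object = {
    "a": 1,
    "b": jnp.arange(10),
    "c": [{"value": jnp.arange(12)}, {"value": jnp.zeros([7])}, {"value": 3}],
}
result = set_leaf_sequence(
    my_nested_object,
    lambda leaf: isinstance(leaf, jax.Array),
    [f"replacement {i}" for i in range(3)],
)
```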

set_by_path accepts the same input that get_by_path produces, which is convenient for modifying some values and then putting them back, or for taking gradients with respect to only a subset of values:

selection = pz.select(my_nested_object).at_instances_of(jax.Array)
selection.set_by_path(
    {key: f"A JAX array: {value}" for key, value in selection.get_by_path().items()}
)
def my_loss(obj):
  return obj["a"] + jnp.sum(obj["b"]**2) + jnp.sum(obj["c"][0]["value"])

my_float_object = jax.tree_util.tree_map(lambda leaf: jnp.array(leaf, dtype=jnp.float32), my_nested_object)

# Take gradients w.r.t. non-scalars only, ignoring my_float_object["a"] and my_float_object["c"][2]["value"]
gradient_selection = pz.select(my_float_object).at_instances_of(jax.Array).where(lambda arr: arr.size > 1)

jax.grad(
    # Swap in the version of the values that JAX is taking gradients for:
    lambda vectors_by_path: my_loss(gradient_selection.set_by_path(vectors_by_path))
)(gradient_selection.get_by_path())
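The same “swap in the values being differentiated” pattern works without selectors: close over the full structure and differentiate with respect to the part you substitute. A minimal sketch with a hypothetical params dictionary:

```python
import jax
import jax.numpy as jnp

params = {"scale": jnp.float32(2.0), "weights": jnp.arange(3, dtype=jnp.float32)}


def loss(p):
    return p["scale"] * jnp.sum(p["weights"] ** 2)


def loss_wrt_weights(weights):
    # Only "weights" is an argument of the differentiated function;
    # "scale" is captured from the enclosing dict and treated as a constant.
    return loss({**params, "weights": weights})


# d/dw of scale * sum(w**2) is 2 * scale * w = [0, 4, 8] here.
grad_weights = jax.grad(loss_wrt_weights)(params["weights"])
```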

You can also use select_and_set_by_path to infer the selection from the input paths instead:

def my_loss(obj):
  return obj["a"] + jnp.sum(obj["b"]**2) + jnp.sum(obj["c"][0]["value"])

my_float_object = jax.tree_util.tree_map(lambda leaf: jnp.array(leaf, dtype=jnp.float32), my_nested_object)

# Take gradients w.r.t. non-scalars only, ignoring my_float_object["a"] and my_float_object["c"][2]["value"]
# No need to store the gradient selection itself, since it can be inferred from vectors_by_path.
vectors_by_path = pz.select(my_float_object).at_instances_of(jax.Array).where(lambda arr: arr.size > 1).get_by_path()

jax.grad(
    lambda vectors_by_path: my_loss(pz.select(my_float_object).select_and_set_by_path(vectors_by_path))
)(vectors_by_path)

Applying functions to selected values

You can use .apply to apply a function to every selected object, similar to jax.tree_util.tree_map except that it applies to the selected subtrees rather than to the leaves:

pz.select(my_nested_object).at_instances_of(jax.Array).apply(lambda x: x**2 + 100)

You can also do apply(fn, with_keypath=True) to get access to the key paths as well:

(
    pz.select(my_nested_object)
    .at_instances_of(jax.Array)
    .apply(lambda key, value: f"key={key}, value={value}", with_keypath=True)
)
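Passing with_keypath=True is analogous to jax.tree_util.tree_map_with_path, which passes each leaf’s keypath to your function. A small example on an arbitrary tree, formatting the paths with jax.tree_util.keystr:

```python
import jax

tree = {"x": 1, "y": [2, 3]}

# The function receives (keypath, leaf) and its result replaces the leaf.
labels = jax.tree_util.tree_map_with_path(
    lambda path, leaf: f"{jax.tree_util.keystr(path)}={leaf}", tree
)
print(labels)  # {'x': "['x']=1", 'y': ["['y'][0]=2", "['y'][1]=3"]}
```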

By default, applying a function also removes the selection. If you want to do further processing, you can pass keep_selected=True to replace the values but keep them selected:

(
    pz.select(my_nested_object)
    .at_instances_of(jax.Array)
    .apply(lambda key, value: f"key={key}, value={value}", with_keypath=True, keep_selected=True)
)

Sometimes it’s also useful to pass the index of the selected node relative to the selection (e.g. “this is the third selected node”) instead of the absolute keypath. For this, there’s .apply_with_selected_index:

(
    pz.select(my_nested_object)
    .at_instances_of(jax.Array)
    .apply_with_selected_index(lambda index, value: f"index={index}, value={value}", keep_selected=True)
)

Manipulating selected elements of lists and tuples

When the selected nodes are elements of a list or tuple, there are a few other options for selector manipulation. For instance, you can insert values before or after the selected nodes:

(
    pz.select({"a": list(range(10)), "b": list(range(10, 20))})
      .at_instances_of(int)
      .where(lambda x: x % 4 == 0)
      .insert_before("before a multiple of 4")
)
(
    pz.select({"a": list(range(10)), "b": list(range(10, 20))})
      .at_instances_of(int)
      .where(lambda x: x % 4 == 0)
      .insert_after("after a multiple of 4")
)

You can also just remove the selected nodes:

(
    pz.select({"a": list(range(10)), "b": list(range(10, 20))})
      .at_instances_of(int)
      .where(lambda x: x % 4 == 0)
      .remove_from_parent()
)

If you want more control, or if you want to use the selected value to determine what to insert, you can use the method .apply_and_inline. This works like .apply, except that your function should return a sequence of values, and those values will be spliced into the original list or tuple in the same position as the original selected nodes.

(
    pz.select({"a": list(range(10)), "b": list(range(10, 20))})
      .at_instances_of(int)
      .where(lambda x: x % 4 == 0)
      .apply_and_inline(lambda x: ["before", f"the value was {x}", "after"])
)
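For a flat list, the splicing behavior can be pictured with the hypothetical helper below: each matching element is replaced by the sequence your function returns. In this model, insert_before corresponds to returning [new_value, x], insert_after to [x, new_value], and remove_from_parent to an empty list. It is only a sketch for flat lists and does not handle tuples or nested trees.

```python
def apply_and_inline_list(items, predicate, fn):
    """Hypothetical helper: splice fn(x) in place of each matching x."""
    result = []
    for x in items:
        if predicate(x):
            result.extend(fn(x))  # zero or more values take x's place
        else:
            result.append(x)
    return result


def is_multiple_of_4(x):
    return x % 4 == 0


before = apply_and_inline_list(list(range(10)), is_multiple_of_4, lambda x: ["before", x])
removed = apply_and_inline_list(list(range(10)), is_multiple_of_4, lambda x: [])
```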

Taking advantage of key paths

Many of the selector functions allow you to use the PyTree path to each selected node in addition to its value:

# Selecting by keypath directly
pz.select(my_nested_object).at_keypaths([
    (jax.tree_util.DictKey(key='c'), jax.tree_util.SequenceKey(idx=2), jax.tree_util.DictKey(key='value')),
    (jax.tree_util.DictKey(key='c'), jax.tree_util.SequenceKey(idx=0)),
])
# Selecting by keypath and value
(
    pz.select(my_nested_object)
    .at_subtrees_where(
        lambda path, subtree: len(path) == 3 and isinstance(subtree, jax.Array),
        with_keypath=True)
)
# Setting values by keypath
pz.select(my_nested_object).at_instances_of(jax.Array).set_by_path(lambda path: str(path))
# Setting values by keypath and original value
pz.select(my_nested_object).at_instances_of(jax.Array).apply(
    lambda path, value: str((path, value)), with_keypath=True
)

You can use jax.tree_util.keystr or pz.pretty_keystr to turn keypaths into readable strings (sometimes useful for referring to parts of a tree by name):

mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config(feature_sizes=[32, 64, 64, 16]),
    jax.random.PRNGKey(0),
)
mlp
[
  key for key in
  pz.select(mlp).at_instances_of(jax.Array).selected_by_path.keys()
]
[
  jax.tree_util.keystr(key) for key in
  pz.select(mlp).at_instances_of(jax.Array).selected_by_path.keys()
]
[
  pz.pretty_keystr(key, mlp) for key in
  pz.select(mlp).at_instances_of(jax.Array).selected_by_path.keys()
]
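jax.tree_util.keystr builds its string by concatenating the string form of each key, so you can check it on a hand-built keypath. The attribute and key names below are made up for illustration:

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, keystr

# Hypothetical path: attribute "sublayers", then list index 0, then dict key "weights".
path = (GetAttrKey("sublayers"), SequenceKey(idx=0), DictKey(key="weights"))
print(keystr(path))  # .sublayers[0]['weights']
```

pz.pretty_keystr, as used in the last cell above, additionally takes the root object, so it can consult the tree when rendering the path.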