Selection

class penzai.core.selectors.Selection

Bases: Generic[SelectedSubtree], Struct

A selected subset of nodes within a larger PyTree.

Penzai selectors (such as .at(...)) return Selection objects, which indicate a specific subset of nodes within a larger PyTree, allowing those nodes to be pulled out and modified in a functional way.

Selected nodes are required to be non-overlapping: no selected node can be the ancestor of any other selected node in the same selection.

For convenience, a Selection is also a PyTree, and its leaves are the same as the leaves in the original PyTree, but they are likely to be in a different order.

Note

Aside for functional programming geeks: A Selection is conceptually related to an “optic”, specifically a “lens”. If you’re familiar with optics, you can think of a Selection as a partially-applied lens: it allows either retrieving the selected values, or setting the selected values in the structure. (If you’re not familiar with optics, you can ignore this.)

Variables:
  • selected_by_path (collections.OrderedDict[KeyPath, SelectedSubtree]) – A mapping whose values are the selected parts from the original structure, and whose keys are the keypaths for those parts (as registered with JAX’s PyTree registry). This is an OrderedDict to prevent JAX from trying to sort the keys, which may be arbitrary hashable objects with no defined ordering.

  • remainder (Any) – The rest of the structure. Each location where a selected component was removed is marked with a SelectionHole node. If the original structure itself contains a Selection, the remainder may also include SelectionQuote nodes.
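To make the two fields concrete, here is a toy pure-Python model of the same idea. It is a hypothetical sketch, not penzai's implementation: `Hole`, `toy_select`, and `toy_deselect` are illustrative names, keypaths are plain tuples of list indices and dict keys rather than JAX key objects, and only lists, tuples, and dicts are treated as containers. Replacing the values in the mapping and rebuilding is the "set" half of the lens analogy from the note above.

```python
from collections import OrderedDict


class Hole:
    """Hypothetical stand-in for SelectionHole: marks where a selected subtree was."""

    def __init__(self, path):
        self.path = path

    def __repr__(self):
        return f"Hole({self.path!r})"


def toy_select(tree, predicate):
    """Splits `tree` into (selected_by_path, remainder) for outermost matches."""
    selected = OrderedDict()

    def walk(node, path):
        if predicate(node):
            selected[path] = node  # Record the subtree under its keypath...
            return Hole(path)  # ...and leave a hole in the remainder.
        if isinstance(node, (list, tuple)):
            return type(node)(walk(c, path + (i,)) for i, c in enumerate(node))
        if isinstance(node, dict):
            return {k: walk(v, path + (k,)) for k, v in node.items()}
        return node

    remainder = walk(tree, ())
    return selected, remainder


def toy_deselect(selected_by_path, remainder):
    """Rebuilds a tree by filling each Hole with the value stored at its path."""
    if isinstance(remainder, Hole):
        return selected_by_path[remainder.path]
    if isinstance(remainder, (list, tuple)):
        return type(remainder)(toy_deselect(selected_by_path, c) for c in remainder)
    if isinstance(remainder, dict):
        return {k: toy_deselect(selected_by_path, v) for k, v in remainder.items()}
    return remainder


tree = {"a": [1, "x", 2], "b": "y"}
selected_by_path, remainder = toy_select(tree, lambda n: isinstance(n, str))
print(selected_by_path)  # keys ('a', 1) and ('b',), in traversal order
print(remainder)  # {'a': [1, Hole(('a', 1)), 2], 'b': Hole(('b',))}

# "Setting" through the selection: replace the selected values, then rebuild.
updated = OrderedDict((p, v.upper()) for p, v in selected_by_path.items())
print(toy_deselect(updated, remainder))  # {'a': [1, 'X', 2], 'b': 'Y'}
```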

Methods

__init__(selected_by_path, remainder)

apply()

Replaces each selected node with the result of applying a function.

apply_and_inline()

Replaces selected list/tuple items with a sequence of new items.

apply_with_selected_index(fn[, keep_selected])

Applies a function, passing both the selected nodes and their indices.

assert_count_is(count)

Checks that an expected number of nodes is selected.

at(accessor_fn)

Selects a specific child of each selected node.

at_childless()

Selects all PyTree nodes with no children, including PyTree leaves.

at_children()

Selects all direct children of each selected subtree.

at_equal_to(template)

Selects subtrees that are equal to a particular object.

at_instances_of(cls[, innermost])

Selects subtrees that are an instance of the given type.

at_keypaths(keypaths)

Selects nodes by their keypaths (relative to the current selection).

at_pytree_leaves()

Selects all PyTree leaves of each selected subtree.

at_subtrees_where()

Selects subtrees of selected nodes where a function evaluates to True.

count()

Returns the number of elements in the selection.

deselect()

Rebuilds the tree, forgetting which nodes were selected.

flatten_selected_selections()

Flattens a selection whose selected values are all selections.

get()

Returns the result of a singleton selection.

get_by_path()

Retrieves the selected subtree(s) based on their path(s).

get_keypaths()

Returns the collection of selected key paths.

get_sequence()

Gets the selected subtree(s) in order.

insert_after(value[, and_select])

Inserts copies of value after each selected node.

insert_before(value[, and_select])

Inserts copies of value before each selected node.

invert()

Inverts a selection, selecting subtrees with no selected children.

is_empty()

Returns True if the selection is empty.

partition([at_leaves])

Partitions the tree into (selected_tree, remainder_tree) parts.

pick_nth_selected(n)

Filters a selection to only the nth selected node.

refine(selector_fn)

Refines a selection by selecting within each selected subtree.

remove_from_parent()

Removes selected nodes from their parent sequence (a list or tuple).

select_and_set_by_path(replacements_by_path)

Selects subtrees and replaces them based on relative keypaths.

set(replacement)

Replaces the selected subtree(s) with a fixed replacement.

set_by_path(replacer)

Replaces the selected subtree(s) based on their path(s).

set_sequence(replacements)

Replaces the selected subtree(s) in order.

show_selection()

Renders the selection in IPython.

show_value()

Renders the original tree in IPython, expanding up to the selected nodes.

where()

Filters to only a subset of selected nodes based on a condition.

Attributes

selected_by_path

remainder

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__len__() -> int

Returns the number of elements in the selection.

apply(fn: Callable[[SelectedSubtree], Any], *, keep_selected: Literal[False] = False, with_keypath: Literal[False] = False) -> Any
apply(fn: Callable[[SelectedSubtree], OtherSubtree], *, keep_selected: Literal[True], with_keypath: Literal[False] = False) -> Selection[OtherSubtree]
apply(fn: Callable[[tuple[Any, ...], SelectedSubtree], Any], *, with_keypath: Literal[True], keep_selected: Literal[False] = False) -> Any
apply(fn: Callable[[tuple[Any, ...], SelectedSubtree], OtherSubtree], *, with_keypath: Literal[True], keep_selected: Literal[True]) -> Selection[OtherSubtree]

Replaces each selected node with the result of applying fn to it.

Parameters:
  • fn – Function to apply to each selected node. This function should take a PyTree (if with_keypath=False) or a KeyPath and a PyTree (if with_keypath=True) and return a replacement PyTree.

  • with_keypath – Whether to pass the keypath as the first argument to the callable.

  • keep_selected – Whether to keep the nodes selected. If True, returns the modified selection; if False, rebuilds the tree after replacing.

Returns:

Either a modified Selection (if keep_selected=True) or a rebuilt version of the original tree, each with the replacements applied.

apply_and_inline(fn: Callable[[SelectedSubtree], Iterable[Any]], *, with_keypath: Literal[False] = False) -> Any
apply_and_inline(fn: Callable[[tuple[Any, ...], SelectedSubtree], Iterable[Any]], *, with_keypath: Literal[True]) -> Any

Replaces selected list/tuple items with a sequence of new items.

This function should only be called when the selected elements are all children of lists or tuples. It removes the original selected element from the sequence, then inserts the results of calling fn at that location. This is similar to the “flatmap” operation in some other languages.

Parameters:
  • fn – Function to apply to each selected node. This function should take a PyTree (if with_keypath=False) or a KeyPath and a PyTree (if with_keypath=True). It should return a replacement iterable of PyTrees, each of which will be inserted into the parent container in place of the selected node.

  • with_keypath – Whether to pass the keypath as the first argument to the callable.

Returns:

A copy of the original tree in which each selected element is replaced by the items returned by fn.

Raises:

ValueError – If any selected node is not the child of a list or tuple.
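The flatmap behavior can be pictured with a toy version that works on a single Python list or tuple. This is a hypothetical sketch, not penzai's code: `toy_apply_and_inline` and the `is_selected` predicate are illustrative names, and the selection is modeled as a predicate over the items of one sequence rather than a set of keypaths into a larger tree.

```python
def toy_apply_and_inline(items, is_selected, fn):
    """Replaces each selected item with the items of fn(item), keeping order."""
    out = []
    for item in items:
        if is_selected(item):
            out.extend(fn(item))  # Splice in zero or more replacements.
        else:
            out.append(item)
    return type(items)(out)  # Preserve list vs. tuple.


def is_dense(x):
    return x == "dense"


layers = ["dense", "dropout", "dense", "relu"]

# Duplicate each selected item: one input item becomes two.
print(toy_apply_and_inline(layers, is_dense, lambda x: (x, x)))
# ['dense', 'dense', 'dropout', 'dense', 'dense', 'relu']

# Returning an empty iterable deletes the item (compare remove_from_parent).
print(toy_apply_and_inline(layers, is_dense, lambda x: ()))
# ['dropout', 'relu']

# Returning (x, value) places a new item after each selected one.
print(toy_apply_and_inline(tuple(layers), is_dense, lambda x: (x, "norm")))
# ('dense', 'norm', 'dropout', 'dense', 'norm', 'relu')
```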

apply_with_selected_index(fn: Callable[[int, SelectedSubtree], Any], keep_selected: bool = False) -> Any | Selection

Applies a function, passing both the selected nodes and their indices.

Indices are taken relative to the linearized sequence of currently selected nodes. In other words, if there are five nodes selected, then fn will be called with the numbers 0 through 4 inclusive, regardless of the specific keypaths to the nodes.

Parameters:
  • fn – Function to call. This function will be passed both the index of the selected node and the value, and should return something to replace the value with.

  • keep_selected – Whether to keep the nodes selected after the transformation.

Returns:

The tree or modified selection after inserting the results of fn, depending on keep_selected.
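The index passed to fn depends only on a node's position in the selection order, not on its keypath. The toy sketch below shows this using an ordered mapping from keypaths to values as a hypothetical stand-in for selected_by_path; `toy_apply_with_selected_index` is an illustrative name, and it returns the updated mapping rather than a rebuilt tree.

```python
from collections import OrderedDict


def toy_apply_with_selected_index(selected_by_path, fn):
    """Calls fn(index, value), where index is the position in selection order."""
    return OrderedDict(
        (path, fn(i, value))
        for i, (path, value) in enumerate(selected_by_path.items())
    )


# The keypaths are sparse and unrelated to the indices 0, 1, 2.
selected_by_path = OrderedDict([
    (("layers", 3), "a"),
    (("layers", 7), "b"),
    (("head",), "c"),
])
result = toy_apply_with_selected_index(selected_by_path, lambda i, v: f"{v}{i}")
print(list(result.values()))  # ['a0', 'b1', 'c2']
```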

assert_count_is(count: int) -> Selection

Checks that an expected number of nodes is selected.

Parameters:

count – The expected number of nodes.

Returns:

The original selection unchanged.

Raises:

AssertionError – If the selection does not have this many nodes.

at(accessor_fn: Callable[[SelectedSubtree], Any | tuple[Any, ...]]) -> Selection

Selects a specific child of each selected node.

Selection.at(...) allows you to modify a tree with an almost-imperative style while maintaining a functional interface, similar to the Array.at[...] syntax for ordinary NDArrays. It takes a callable that picks out a subtree of the tree, and returns a new selection that selects the part that was picked out.

For instance, if you have an object

obj = Foo(bar=[1, 2, {"baz": 5}])

you could select the 5 using

pz.select(obj).at(lambda x: x.bar[2]["baz"])

Selection.at is implemented using equinox.tree_at.

Parameters:

accessor_fn – A function which takes each element of the current selection and returns a node or tuple of nodes within that selection. This function must be structural; it must depend only on the PyTree structure of its input and not on the actual values of the leaves. See equinox.tree_at for the full set of requirements.

Returns:

A modified selection that selects the specific child of each node in the original selection.

at_childless() -> Selection

Selects all PyTree nodes with no children, including PyTree leaves.

This differs from at_pytree_leaves in that it also selects PyTree nodes that have no children, e.g. empty lists, None, and structures without any PyTree children. JAX does not consider those nodes to be leaves, but it may still be useful to select them, e.g. for visualization purposes.

Returns:

A new selection that selects every childless node of each selected subtree.
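The difference from at_pytree_leaves shows up on empty containers and None. The toy sketch below is hypothetical and not penzai's implementation: it treats lists, tuples, dicts, and None as containers (JAX likewise treats None as a node with no children), keeps dict keys in insertion order (JAX sorts them), and collects the keypaths each rule would select.

```python
def toy_children(node):
    """Returns (key, child) pairs for toy containers, or None for leaves."""
    if node is None:
        return []  # None is a container with no children, as in JAX.
    if isinstance(node, (list, tuple)):
        return list(enumerate(node))
    if isinstance(node, dict):
        return list(node.items())
    return None  # Anything else is a leaf.


def toy_leaves(node, path=()):
    kids = toy_children(node)
    if kids is None:
        return [path]
    return [p for k, c in kids for p in toy_leaves(c, path + (k,))]


def toy_childless(node, path=()):
    kids = toy_children(node)
    if not kids:  # Leaves (None) and empty containers ([]) both qualify.
        return [path]
    return [p for k, c in kids for p in toy_childless(c, path + (k,))]


tree = {"w": [1.0, 2.0], "empty": [], "bias": None}
print(toy_leaves(tree))     # [('w', 0), ('w', 1)]
print(toy_childless(tree))  # [('w', 0), ('w', 1), ('empty',), ('bias',)]
```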

at_children() -> Selection

Selects all direct children of each selected subtree.

This can be used to implement recursive tree traversals in a generic way, using something like:

from penzai import pz

def traverse(subtree):
  # ... process the subtree before the recursive call ...
  subtree = pz.select(subtree).at_children().apply(traverse)
  # ... process the subtree after the recursive call ...
  return subtree

value = {"a": [1, 2], "b": (3,)}  # Any PyTree.
new_value = traverse(value)

Returns:

A new selection that selects every direct child of a selected subtree. If any leaves were previously selected, those leaves will no longer be selected (since they have no children).

at_equal_to(template: OtherSubtree) -> Selection[OtherSubtree]

Selects subtrees that are equal to a particular object.

Mostly a convenience wrapper for

.at_subtrees_where(lambda subtree: template == subtree)

but also skips jax.Array, np.ndarray, and penzai.core.named_axes.NamedArray, since they override == to return arrays.

Parameters:

template – The object to select occurrences of.

Returns:

A refined selection that selects the subtrees that compare equal to template (with template on the left-hand side of ==).

at_instances_of(cls: type[OtherSubtree] | tuple[type[OtherSubtree], ...], innermost: bool = False) -> Selection[OtherSubtree]

Selects subtrees that are an instance of the given type.

Convenience wrapper for:

.at_subtrees_where(lambda subtree: isinstance(subtree, cls))

Parameters:
  • cls – The class (or tuple of classes) to retrieve instances of.

  • innermost – Whether to return the innermost instances of the class (instead of the outermost).

Returns:

A refined selection that selects instances of this class within the original selection. If instances of this class are nested, only selects the outermost (if innermost=False) or the innermost (if innermost=True), but never both.
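To see which matches survive when instances are nested, here is a toy traversal over a hypothetical Block class (not part of penzai) whose children live in a `.children` list. `toy_instances_of` is an illustrative name; it returns keypaths of the outermost matches by default and of the innermost matches when innermost=True, and in neither case returns both a node and one of its descendants.

```python
from dataclasses import dataclass, field


@dataclass
class Block:  # Hypothetical container type, for illustration only.
    name: str
    children: list = field(default_factory=list)


def toy_instances_of(node, cls, innermost=False, path=()):
    """Returns keypaths of non-overlapping instances of cls in a Block tree."""
    kids = node.children if isinstance(node, Block) else []
    below = [
        p
        for i, child in enumerate(kids)
        for p in toy_instances_of(child, cls, innermost, path + ("children", i))
    ]
    if isinstance(node, cls):
        if innermost and below:
            return below  # Prefer the deeper matches over this node.
        return [path]  # Select this node; its descendants are excluded.
    return below


tree = Block("root", [Block("a", [Block("a1"), "leaf"]), Block("b")])
print(toy_instances_of(tree, Block))
# [()]  (the root itself is the outermost Block)
print(toy_instances_of(tree, Block, innermost=True))
# [('children', 0, 'children', 0), ('children', 1)]
```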

at_keypaths(keypaths: Collection[KeyPath]) -> Selection

Selects nodes by their keypaths (relative to the current selection).

Parameters:

keypaths – A collection of keypaths.

Returns:

A new selection in which every node whose keypath is in keypaths is selected. Note that if any path in keypaths is a prefix of another, only the shorter prefix will be used, since selected nodes cannot be nested.
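The prefix rule can be illustrated with a small helper that keeps only the keypaths not nested under another requested keypath. Keypaths are modeled as plain tuples here, and `drop_nested_keypaths` is a hypothetical name, not a penzai function.

```python
def drop_nested_keypaths(keypaths):
    """Keeps only keypaths that do not extend another keypath in the collection."""
    kept = []
    # Process shorter paths first, so every prefix is seen before its extensions.
    for path in sorted(set(keypaths), key=len):
        if not any(path[: len(p)] == p for p in kept):
            kept.append(path)
    return kept


requested = [("layers", 0, "w"), ("layers", 0), ("head", "b"), ("layers", 1)]
print(sorted(drop_nested_keypaths(requested)))
# [('head', 'b'), ('layers', 0), ('layers', 1)]
```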

at_pytree_leaves() -> Selection

Selects all PyTree leaves of each selected subtree.

This selects all of the leaves of the PyTree according to jax.tree_util, giving the most-specific selection expressible with a Selection object. (Note that, if any objects in the tree are not registered as JAX PyTree nodes, they will be selected in their entirety even if they contain children when printed out by treescope.)

Returns:

A new selection that selects every leaf of each selected subtree.

at_subtrees_where(filter_fn: Callable[[SelectedSubtree], bool], *, with_keypath: Literal[False] = False, absolute_keypath: bool = False, innermost: bool = False) -> Selection
at_subtrees_where(filter_fn: Callable[[tuple[Any, ...], SelectedSubtree], bool], *, with_keypath: Literal[True], absolute_keypath: bool = False, innermost: bool = False) -> Selection

Selects subtrees of selected nodes where a function evaluates to True.

Note that a selection cannot contain a node that is a descendant of another selected node. If innermost=False, the outermost matching node is selected, whereas if innermost=True the innermost one is.

If you want to apply a modification to all matches of a function, even if they are nested, you can use a pattern like

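from penzai import pz

# `value`, `foo`, and `bar` are placeholders for your tree, filter, and update.
selection = pz.select(value)
while not selection.is_empty():
  selection = selection.at_subtrees_where(foo).apply(
      bar, keep_selected=True)
# The loop ends once no selected subtree contains a match for `foo`.
new_value = selection.deselect()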

More complex modifications can also be made using a manual traversal, e.g.:

from penzai import pz

def traverse(subtree):
  # ... process the subtree before the recursive call ...
  subtree = pz.select(subtree).at_children().apply(traverse)
  # ... process the subtree after the recursive call ...
  return subtree

value = {"a": [1, 2], "b": (3,)}  # Any PyTree.
new_value = traverse(value)

Parameters:
  • filter_fn – A function determining which subtrees to select. Should be deterministic, and may be called more than once. This function should take a PyTree (if with_keypath=False) or a KeyPath and a PyTree (if with_keypath=True).

  • with_keypath – Whether to pass a keypath as the first argument to the callable.

  • absolute_keypath – Whether to pass the keypath relative to the root of the original tree (if True) or the keypath relative to the currently selected node (if False). Ignored if with_keypath is False.

  • innermost – Whether to select the innermost subtree(s) for which the filter function is true, instead of the first subtrees encountered.

Returns:

A new selection that selects the desired subtrees.

count() -> int

Returns the number of elements in the selection.

deselect() -> Any

Rebuilds the tree, forgetting which nodes were selected.

Returns:

A copy of remainder with the holes filled by the values in selected_by_path. If called on an ordinary selection, this rebuilds the original tree.

flatten_selected_selections() -> Selection[SelectedSubtree]

Flattens a selection whose selected values are all selections.

This function takes a selection for which all of the selected values are already selections, and merges them into a single selection that selects all of the values from each individual selection.

You can use this to build more complex selections by chaining your own logic. For instance, if you have written a function f that selects part of a tree, you can run

selection.apply(f, keep_selected=True).flatten_selected_selections()

to “broadcast” that logic across all of the already-selected subtrees in the original selection.

See also refine, which allows you to express similar transformations more directly.

Returns:

A flattened selection object.
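Conceptually, flattening composes keypaths: each inner selection's keypaths are relative to its own subtree, so they are appended to the outer keypath at which that subtree sits. The toy sketch below uses nested ordered dicts as hypothetical stand-ins for selections; `toy_flatten` is an illustrative name, not penzai's data structure or implementation.

```python
from collections import OrderedDict


def toy_flatten(outer):
    """Maps {outer_path: {inner_path: value}} to {outer_path + inner_path: value}."""
    flat = OrderedDict()
    for outer_path, inner in outer.items():
        for inner_path, value in inner.items():
            flat[outer_path + inner_path] = value
    return flat


outer = OrderedDict([
    (("blocks", 0), OrderedDict([(("attn", "w"), "W0"), (("mlp", "w"), "M0")])),
    (("blocks", 1), OrderedDict([(("attn", "w"), "W1")])),
])
print(list(toy_flatten(outer)))
# [('blocks', 0, 'attn', 'w'), ('blocks', 0, 'mlp', 'w'), ('blocks', 1, 'attn', 'w')]
```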

get() -> SelectedSubtree

Returns the result of a singleton selection.

Returns:

The selected subtree from this selection.

Raises:

ValueError – If this selection does not have exactly one selected subtree.

get_by_path() -> collections.OrderedDict[KeyPath, SelectedSubtree]

Retrieves the selected subtree(s) based on their path(s).

Returns:

An OrderedDict of the selected nodes, keyed by their keypaths.

get_keypaths() -> tuple[KeyPath, ...]

Returns the collection of selected key paths.

get_sequence() -> tuple[SelectedSubtree, ...]

Gets the selected subtree(s) in order.

Convenience wrapper for tuple(self.selected_by_path.values()).

Returns:

A tuple containing the selected subtrees.

insert_after(value: Any, and_select: bool = False) -> Any

Inserts copies of value after each selected node.

Parameters:
  • value – Value to insert after each selected node.

  • and_select – If True, selects the newly-inserted value for additional modification.

Returns:

If and_select is False, a copy of the original tree, but with value inserted after each selected node. If and_select is True, a new selection selecting the inserted nodes.

Raises:

ValueError – If any selected nodes are not children of a list or tuple.

insert_before(value: Any, and_select: bool = False) -> Any

Inserts copies of value before each selected node.

Parameters:
  • value – Value to insert before each selected node.

  • and_select – If True, selects the newly-inserted value for additional modification.

Returns:

If and_select is False, a copy of the original tree, but with value inserted before each selected node. If and_select is True, a new selection selecting the inserted nodes.

Raises:

ValueError – If any selected nodes are not children of a list or tuple.

invert() -> Selection

Inverts a selection, selecting subtrees with no selected children.

selection.invert() selects the largest set of subtrees such that those subtrees do NOT contain any selected children in the original selection. In other words, it selects the common ancestors of all unselected nodes, without selecting any selected nodes.

Returns:

An inverted selection.
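One way to picture invert is as a top-down walk: a subtree that contains no selected node is taken whole, a selected node is skipped, and any other subtree is descended into. The sketch below implements that rule on nested lists and dicts with tuple keypaths; `toy_invert` is a hypothetical helper under these assumptions, not penzai's implementation.

```python
def toy_invert(node, selected, path=()):
    """Returns keypaths of the largest subtrees containing no selected keypath."""
    if path in selected:
        return []  # Selected nodes are never part of the inverted selection.
    if not any(s[: len(path)] == path for s in selected):
        return [path]  # No selected descendants: take the whole subtree.
    if isinstance(node, (list, tuple)):
        items = enumerate(node)
    elif isinstance(node, dict):
        items = node.items()
    else:
        items = []
    return [p for k, child in items for p in toy_invert(child, selected, path + (k,))]


tree = {"enc": [10, 20], "dec": {"w": 30, "b": 40}, "head": 50}
print(toy_invert(tree, {("enc", 1), ("dec", "w")}))
# [('enc', 0), ('dec', 'b'), ('head',)]
```

With an empty selection, the walk returns the root keypath (), i.e. the whole tree.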

is_empty() -> bool

Returns True if the selection is empty.

partition(at_leaves: bool = False) -> tuple[Any, Any]

Partitions the tree into (selected_tree, remainder_tree) parts.

This function can be used to separate out the selected components of a tree into their own separate tree, so that JAX functions and other JAX libraries can process them like ordinary PyTrees. It splits its input into two disjoint trees (selected_tree, remainder_tree), where selected_tree only contains the leaves that were selected, and remainder_tree only contains the remainder. The parts that were removed are identified using a sentinel pz.NotInThisPartition object, which has no PyTree children.

The main use case for partition is to identify subsets of models that should be treated in different ways by JAX API functions. For instance, if you want to take a gradient with respect to a specific subset of parameters, you can select those parameters, call partition to separate them from the rest, then call jax.grad and use argnums to identify the partition of interest. Similarly, if you want to donate only a subset of the state to jax.jit, you can partition it and then use JAX’s donate_argnums argument to jax.jit to identify the parts you want to donate. Inside the function, you can then use pz.combine to rebuild the original tree.

It is possible to repeatedly call partition to split a tree into more than two parts. In particular, you can select the remainder_tree, target some additional nodes, and call .partition() again, repeating this process as needed. All of the partitioned trees can then be re-combined using a single call to pz.combine.

Note that NotInThisPartition is a PyTree node with no children. This means that partitioned trees are safe to pass through JAX transformations, and that the leaves of the two partitioned trees together are exactly the leaves of the original tree.

This function is inspired by Equinox’s equinox.partition, but is designed to work with Penzai’s selector system. Unlike equinox.partition, missing nodes are identified with the pz.NotInThisPartition sentinel, and can replace arbitrary PyTree subtrees instead of just leaves. (Partitioning is also somewhat less important in Penzai than in Equinox because all PyTree leaves are arraylike by convention; partitioning is only necessary when different parts of the tree need special treatment e.g. for argnums or donate_argnums parameters.)

Parameters:

at_leaves – Whether to do the partitioning at the leaf level, so that the returned trees have exactly the same structure. (Note that pz.combine is OK with entire subtrees missing, so this is not necessary, but can make the partitions easier to manipulate manually if desired.) If False, the entire selected subtrees will be replaced by NotInThisPartition in the remainder tree.

Returns:

A tuple (selected_tree, remainder_tree), where both trees have the same structure (if at_leaves=True) or the same prefix (if at_leaves=False) except that NotInThisPartition is used to replace parts that are in the other partition.
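The round trip can be sketched on nested dicts without JAX. In this hypothetical toy, `NOT_IN_THIS_PARTITION` stands in for pz.NotInThisPartition and `toy_combine` for pz.combine; a selected subtree moves whole into the first tree and becomes the sentinel in the second, while in the first tree each unselected leaf becomes the sentinel (this toy's choice, not a statement about penzai's exact layout). The jax.grad or donate_argnums step described above is omitted to keep the sketch dependency-free.

```python
class _NotInThisPartition:
    """Hypothetical sentinel standing in for pz.NotInThisPartition."""

    def __repr__(self):
        return "NotInThisPartition"


NOT_IN_THIS_PARTITION = _NotInThisPartition()


def toy_partition(tree, selected_paths, path=()):
    """Splits a nested dict into (selected_tree, remainder_tree)."""
    if path in selected_paths:
        return tree, NOT_IN_THIS_PARTITION  # Move the whole subtree.
    if not isinstance(tree, dict):
        return NOT_IN_THIS_PARTITION, tree  # Unselected leaf stays behind.
    pairs = {k: toy_partition(v, selected_paths, path + (k,)) for k, v in tree.items()}
    selected = {k: s for k, (s, _) in pairs.items()}
    remainder = {k: r for k, (_, r) in pairs.items()}
    return selected, remainder


def toy_combine(*trees):
    """Merges partitions, taking the one non-sentinel value at each position."""
    present = [t for t in trees if t is not NOT_IN_THIS_PARTITION]
    if len(present) == 1 or not all(isinstance(t, dict) for t in present):
        return present[0]
    return {k: toy_combine(*(t[k] for t in present)) for k in present[0]}


params = {"enc": {"w": 1.0, "b": 0.0}, "head": {"w": 2.0}}
trainable, frozen = toy_partition(params, {("head",)})
print(trainable)  # {'enc': {'w': NotInThisPartition, 'b': NotInThisPartition}, 'head': {'w': 2.0}}
print(frozen)     # {'enc': {'w': 1.0, 'b': 0.0}, 'head': NotInThisPartition}
print(toy_combine(trainable, frozen) == params)  # True
```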

pick_nth_selected(n: int | Sequence[int]) -> Selection

Filters a selection to only the nth selected node.

n is taken relative to the linearized sequence of currently selected nodes, in the sense that

my_selection.get_sequence()[n] == my_selection.pick_nth_selected(n).get()

Parameters:

n – The index of the selected nodes that should remain selected. If this is a sequence, all nodes in the sequence will be selected.

Returns:

A new selection that includes only the nth node from the original selection. For instance, if the original selection has 5 nodes selected, setting n = 1 would produce a new selection with only the node at index 1 selected.
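The identity above can be checked on a toy model in which a selection is just an ordered mapping from keypaths to values. `toy_pick_nth` is a hypothetical helper, not penzai's implementation; like the n parameter, it also accepts a sequence of indices.

```python
from collections import OrderedDict


def toy_pick_nth(selected_by_path, n):
    """Keeps only the selected entries whose position is n (or is in n)."""
    wanted = {n} if isinstance(n, int) else set(n)
    return OrderedDict(
        item for i, item in enumerate(selected_by_path.items()) if i in wanted
    )


selection = OrderedDict([(("a",), "x"), (("b", 0), "y"), (("c",), "z")])
picked = toy_pick_nth(selection, 1)
# Mirrors get_sequence()[n] == pick_nth_selected(n).get():
assert tuple(selection.values())[1] == list(picked.values())[0] == "y"
print(list(toy_pick_nth(selection, [0, 2])))  # [('a',), ('c',)]
```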

refine(selector_fn: Callable[[Any], Selection[OtherSubtree]]) -> Selection[OtherSubtree]

Refines a selection by selecting within each selected subtree.

Although selectors can already be refined by making additional calls, chained calls generally treat all selected subtrees the same way. In contrast, this method allows each selected node to be processed independently. Additionally, similar to apply, the additional logic is free to modify the subtree as it goes.

Parameters:

selector_fn – A function that takes a selected subtree from this selection and returns a new selection object, usually a selection of some nodes in the input subtree.

Returns:

A new selection that selects every node selected by selector_fn, but in the context of the original tree rather than the individual selected subtrees.

remove_from_parent() -> Any

Removes selected nodes from their parent sequence (a list or tuple).

This is a convenience wrapper for .apply_and_inline(lambda x: ()).

Returns:

A copy of the original tree, but with all selected nodes removed from their parents.

Raises:

ValueError – If any selected nodes are not children of a list or tuple.

select_and_set_by_path(replacements_by_path: dict[KeyPath, Any]) -> Any

Selects subtrees and replaces them based on relative keypaths.

Convenience method that combines at_keypaths and set_by_path.

Parameters:

replacements_by_path – A mapping from key paths to replacements. Key paths are relative to the current selected nodes.

Returns:

A modified version of the original tree, with replacements taken from replacements_by_path.

set(replacement: Any) -> Any

Replaces the selected subtree(s) with a fixed replacement.

Parameters:

replacement – The pytree to replace with.

Returns:

A modified version of the original tree, with this replacement in place of any selected subtrees.

set_by_path(replacer: Mapping[KeyPath, Any] | Callable[[KeyPath], Any]) -> Any

Replaces the selected subtree(s) based on their path(s).

If you need both the value and the key, see .apply(fn, with_keypath=True).

Parameters:

replacer – A mapping from key paths to replacements, or a function that takes a key path and returns its replacement. Passing self.selected_by_path will return the original tree unchanged.

Returns:

A modified version of the original tree, with replacements taken from the replacer.
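Because replacer may be either a mapping or a callable, an implementation has to dispatch on its type. The toy sketch below shows one way to do that, with a plain ordered dict standing in for selected_by_path; `toy_set_by_path` is hypothetical, and it returns the updated mapping rather than a rebuilt tree.

```python
from collections import OrderedDict
from collections.abc import Mapping


def toy_set_by_path(selected_by_path, replacer):
    """Looks up (mapping) or computes (callable) a replacement for each keypath."""
    if isinstance(replacer, Mapping):
        return OrderedDict((p, replacer[p]) for p in selected_by_path)
    return OrderedDict((p, replacer(p)) for p in selected_by_path)


selected = OrderedDict([(("w",), 1.0), (("b",), 0.5)])
print(toy_set_by_path(selected, {("w",): 0.0, ("b",): 0.0}))
print(toy_set_by_path(selected, lambda path: "/".join(path)))
# Passing the selection's own mapping leaves every value unchanged.
assert toy_set_by_path(selected, selected) == selected
```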

set_sequence(replacements: Iterable[Any]) -> Any

Replaces the selected subtree(s) in order.

Parameters:

replacements – An iterable of PyTrees to insert at the selected locations, in order.

Returns:

A modified version of the original tree, with replacements taken from the iterable.

show_selection()

Renders the selection in IPython.

This method is intended to visualize the selection object itself, and renders boxes around the selected nodes.

This method should only be used when IPython is available.

show_value()

Renders the original tree in IPython, expanding up to the selected nodes.

This method is intended to visualize the underlying value while emphasizing the selected parts. The selection only determines what to focus on by default; it is not itself the object of interest.

This method should only be used when IPython is available.

where(filter_fn: Callable[[SelectedSubtree], bool], *, with_keypath: Literal[False] = False) -> Selection[SelectedSubtree]
where(filter_fn: Callable[[tuple[Any, ...], SelectedSubtree], bool], *, with_keypath: Literal[True]) -> Selection[SelectedSubtree]

Filters to only a subset of selected nodes based on a condition.

Parameters:
  • filter_fn – Function to determine whether to keep a node in the selection. This function should take a PyTree (if with_keypath=False) or a KeyPath and a PyTree (if with_keypath=True).

  • with_keypath – Whether to pass the keypath as the first argument to the callable.

Returns:

A new selection that includes only the selected nodes for which filter_fn returns True.