Selection
- class penzai.core.selectors.Selection
Bases: Generic[SelectedSubtree], Struct
A selected subset of nodes within a larger PyTree.
Penzai selectors (such as .at(...)) return Selection objects, which indicate a specific subset of nodes within a larger PyTree, allowing those nodes to be pulled out and modified in a functional way.
Selected nodes are required to be non-overlapping: no selected node can be the ancestor of any other selected node in the same selection.
For convenience, a Selection is also a PyTree, and its leaves are the same as the leaves in the original PyTree, but they are likely to be in a different order.
Note
Aside for functional programming geeks: A Selection is conceptually related to an “optic”, specifically a “lens”. If you’re familiar with optics, you can think of a Selection as a partially-applied lens: it allows either retrieving the selected values, or setting the selected values in the structure. (If you’re not familiar with optics, you can ignore this.)
- Variables:
selected_by_path (collections.OrderedDict[KeyPath, SelectedSubtree]) – A mapping whose values are the selected parts from the original structure, and whose keys are the keypaths for those parts (as registered with JAX’s PyTree registry). This is an OrderedDict to prevent JAX from trying to sort the keys, which may be arbitrary hashable objects without an ordering.
remainder (Any) – The rest of the structure. The locations where the selected components were are marked with SelectionHole nodes. If the remainder also includes a Selection itself, the remainder may also include SelectionQuote nodes.
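The relationship between selected_by_path and remainder can be sketched with a toy model. This is not Penzai's implementation: it uses plain dicts, lists and tuples instead of JAX PyTrees, and the names Hole, split and deselect are hypothetical helpers introduced only for illustration.

```python
from collections import OrderedDict


class Hole:
    """Toy stand-in for SelectionHole: remembers the keypath it replaced."""

    def __init__(self, path):
        self.path = path


def split(tree, paths):
    """Return (selected_by_path, remainder) for the given set of keypaths."""
    selected_by_path = OrderedDict()

    def go(node, path):
        if path in paths:
            selected_by_path[path] = node
            return Hole(path)  # do not descend: selected nodes never overlap
        if isinstance(node, dict):
            return {k: go(v, path + (k,)) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(go(v, path + (i,)) for i, v in enumerate(node))
        return node

    return selected_by_path, go(tree, ())


def deselect(selected_by_path, remainder):
    """Rebuild a tree by filling each hole with its selected value."""
    if isinstance(remainder, Hole):
        return selected_by_path[remainder.path]
    if isinstance(remainder, dict):
        return {k: deselect(selected_by_path, v) for k, v in remainder.items()}
    if isinstance(remainder, (list, tuple)):
        return type(remainder)(deselect(selected_by_path, v) for v in remainder)
    return remainder


tree = {"a": [1, 2], "b": {"c": 3}}
selected_by_path, remainder = split(tree, {("a", 1), ("b",)})
selected_by_path[("a", 1)] = 20  # replace one selected value
new_tree = deselect(selected_by_path, remainder)
```

In this sketch the original tree is left untouched and new_tree is a rebuilt copy, which mirrors the functional style described above.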
Methods
__init__(selected_by_path, remainder)
apply() – Replaces each selected node with the result of applying this function.
apply_and_inline() – Replaces selected list/tuple items with a sequence of new items.
apply_with_selected_index(fn[, keep_selected]) – Applies a function, passing both the selected nodes and their indices.
assert_count_is(count) – Checks that an expected number of nodes is selected.
at(accessor_fn[, multiple]) – Selects a specific child of each selected node.
at_childless() – Selects all PyTree nodes with no children, including PyTree leaves.
at_children() – Selects all direct children of each selected subtree.
at_equal_to(template) – Selects subtrees that are equal to a particular object.
at_instances_of(cls[, innermost]) – Selects subtrees that are an instance of the given type.
at_keypaths(keypaths) – Selects nodes by their keypaths (relative to the current selection).
at_pytree_leaves() – Selects all PyTree leaves of each selected subtree.
at_subtrees_where() – Selects subtrees of selected nodes where a function evaluates to True.
count() – Returns the number of elements in the selection.
deselect() – Rebuilds the tree, forgetting which nodes were selected.
flatten_selected_selections() – Flattens a selection whose selected values are all selections.
get() – Returns the result of a singleton selection.
get_by_path() – Retrieves the selected subtree(s) based on their path(s).
Returns the collection of selected key paths.
get_sequence() – Gets the selected subtree(s) in order.
insert_after(value[, and_select]) – Inserts copies of value after each selected node.
insert_before(value[, and_select]) – Inserts copies of value before each selected node.
invert() – Inverts a selection, selecting subtrees with no selected children.
is_empty() – Returns True if the selection is empty.
partition([at_leaves]) – Partitions the tree into (selected_tree, remainder_tree) parts.
pick_nth_selected(n) – Filters a selection to only the nth selected node.
refine(selector_fn) – Refines a selection by selecting within each selected subtree.
remove_from_parent() – Removes selected nodes from their parent sequence (a list or tuple).
select_and_set_by_path(replacements_by_path) – Selects subtrees and replaces them based on relative keypaths.
set(replacement) – Replaces the selected subtree(s) with a fixed replacement.
set_by_path(replacer) – Replaces the selected subtree(s) based on their path(s).
set_sequence(replacements) – Replaces the selected subtree(s) in order.
show_selection([ignore_exceptions]) – Renders the selection in IPython.
show_value([ignore_exceptions]) – Renders the original tree in IPython, expanding up to the selected nodes.
where() – Filters to only a subset of selected nodes based on a condition.
Attributes
selected_by_path
remainder
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- __treescope_root_repr__()
Renders this selection as the root object in a treescope rendering.
- apply(fn: Callable[[SelectedSubtree], Any], *, keep_selected: Literal[False] = False, with_keypath: Literal[False] = False) -> Any
- apply(fn: Callable[[SelectedSubtree], OtherSubtree], *, keep_selected: Literal[True], with_keypath: Literal[False] = False) -> Selection[OtherSubtree]
- apply(fn: Callable[[tuple[Any, ...], SelectedSubtree], Any], *, with_keypath: Literal[True], keep_selected: Literal[False] = False) -> Any
- apply(fn: Callable[[tuple[Any, ...], SelectedSubtree], OtherSubtree], *, with_keypath: Literal[True], keep_selected: Literal[True]) -> Selection[OtherSubtree]
Replaces each selected node with the result of applying this function.
- Parameters:
fn – Function to apply to each selected node. This function should take a PyTree (if with_keypath=False) or a KeyPath and a PyTree (if with_keypath=True) and return a replacement PyTree.
with_keypath – Whether to pass the keypath as the first argument to the callable.
keep_selected – Whether to keep the nodes selected. If True, returns the modified selection; if False, rebuilds the tree after replacing.
- Returns:
Either a modified Selection (if keep_selected=True) or a rebuilt version of the original tree, each with the replacements applied.
- apply_and_inline(fn: Callable[[SelectedSubtree], Iterable[Any]], *, with_keypath: Literal[False] = False) -> Any
- apply_and_inline(fn: Callable[[tuple[Any, ...], SelectedSubtree], Iterable[Any]], *, with_keypath: Literal[True]) -> Any
Replaces selected list/tuple items with a sequence of new items.
This function should only be called when the selected elements are all children of lists or tuples. It removes the original selected element from the sequence, then inserts the results of calling fn at that location. This is similar to the “flatmap” operation in some other languages.
- Parameters:
fn – Function to apply to each selected node. This function should take a PyTree (if with_keypath=False) or a KeyPath and a PyTree (if with_keypath=True). It should return a replacement iterable of PyTrees, each of which will be inserted into the parent container in place of the selected node.
with_keypath – Whether to pass the keypath as the first argument to the callable.
- Returns:
A copy of the original tree with the selected elements replaced with the outputs of fn.
- Raises:
ValueError – If any selected node is not the child of a list or tuple.
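The "flatmap" behavior can be illustrated on a single flat sequence. This is a toy sketch rather than Penzai code: the helper inline_replace and the is_selected predicate are hypothetical, and a real Selection tracks selected nodes by keypath instead of by predicate.

```python
def inline_replace(items, is_selected, fn):
    """Replace each selected item with the items produced by fn(item)."""
    out = []
    for item in items:
        if is_selected(item):
            out.extend(fn(item))  # splice the results in place of the item
        else:
            out.append(item)
    return type(items)(out)


layers = ["embed", "block", "block", "head"]

# Expand every "block" into two sublayers.
expanded = inline_replace(
    layers, lambda x: x == "block", lambda x: [x + "_attn", x + "_mlp"]
)

# Returning an empty sequence removes the item from its parent, which is
# the idea behind remove_from_parent.
removed = inline_replace(layers, lambda x: x == "block", lambda x: ())
```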
- apply_with_selected_index(fn: Callable[[int, SelectedSubtree], Any], keep_selected: bool = False) -> Any | Selection
Applies a function, passing both the selected nodes and their indices.
Indices are taken relative to the linearized sequence of currently selected nodes. In other words, if there are five nodes selected, then fn will be called with the numbers 0 through 4 inclusive, regardless of the specific keypaths to the nodes.
- Parameters:
fn – Function to call. This function will be passed both the index of the selected node and the value, and should return something to replace the value with.
keep_selected – Whether to keep the node selected after the transformation.
- Returns:
The tree or modified selection after inserting the results of fn, depending on keep_selected.
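To make the index convention concrete, here is a small sketch that operates directly on a mapping shaped like selected_by_path. The helper name apply_with_index is hypothetical and the keypaths are made-up tuples; the point is only that the index comes from the position in the selection, not from the keypath.

```python
from collections import OrderedDict


def apply_with_index(selected_by_path, fn):
    """Call fn(index, value) for each selected value, in selection order."""
    return OrderedDict(
        (path, fn(index, value))
        for index, (path, value) in enumerate(selected_by_path.items())
    )


selected_by_path = OrderedDict([
    (("layers", 3, "w"), "w3"),
    (("layers", 7, "w"), "w7"),
    (("head", "w"), "wh"),
])

# The indices are 0, 1, 2 even though the keypaths mention 3 and 7.
tagged = apply_with_index(selected_by_path, lambda i, v: f"{v}#{i}")
```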
- assert_count_is(count: int) -> Selection
Checks that an expected number of nodes is selected.
- Parameters:
count – The expected number of nodes.
- Returns:
The original selection unchanged.
- Raises:
AssertionError – If the selection does not have this many nodes.
- at(accessor_fn: Callable[[SelectedSubtree], Any | Collection[Any]], multiple: bool | None = None) -> Selection
Selects a specific child of each selected node.
Selection.at(...) allows you to modify a tree with an almost-imperative style while maintaining a functional interface, similar to the Array.at[...] syntax for ordinary NDArrays. It takes a callable that picks out a subtree of the tree, and returns a new selection that selects the part that was picked out.
For instance, if you have an object
obj = Foo(bar=[1, 2, {"baz": 5}])
you could select the 5 using
pz.select(obj).at(lambda x: x.bar[2]["baz"])
- Parameters:
accessor_fn – A function which takes each element of the current selection and returns a single node within that selection (if multiple is False) or a collection of nodes (if multiple is True). This function must be structural; it must depend only on the PyTree structure of its input and not on the actual values or Python IDs of the leaves. It will be called with a copy of the object where every PyTree leaf and every empty PyTree node (e.g. an empty tuple or the None singleton) is wrapped with an internal wrapper object.
multiple – Whether accessor_fn returns a collection of nodes to select, rather than a single node. If None, first tries to find it as a single node, and if that fails, tries to find it as a collection of nodes but emits a warning.
- Returns:
A modified selection that selects the specific child of each node in the original selection (or the set of nodes if multiple was True).
- at_childless() -> Selection
Selects all PyTree nodes with no children, including PyTree leaves.
This differs from at_pytree_leaves in that it additionally selects PyTree nodes that are childless, e.g. empty lists, None, and structures without any PyTree children. Those nodes are not considered leaves according to JAX, but it may still be useful to select them, e.g. for visualization purposes.
- Returns:
A new selection that selects every childless node of each selected subtree.
- at_children() -> Selection
Selects all direct children of each selected subtree.
This can be used to implement recursive tree traversals in a generic way, using something like:
def traverse(subtree):
  # ... process the subtree before recursive call ...
  subtree = select(subtree).at_children().apply(traverse)
  # ... process the subtree after the recursive call ...
  return subtree

new_value = traverse(value)
- Returns:
A new selection that selects every direct child of a selected subtree. If any leaves were previously selected, those leaves will no longer be selected (since they have no children).
- at_equal_to(template: OtherSubtree) -> Selection[OtherSubtree]
Selects subtrees that are equal to a particular object.
Mostly a convenience wrapper for
.at_subtrees_where(lambda subtree: template == subtree)
but also skips jax.Array, np.ndarray, and penzai.core.named_axes.NamedArray, since they override == to return arrays.
- Parameters:
template – The object to select occurrences of.
- Returns:
A refined selection that selects subtrees that compare equal to template (evaluated as template == subtree, with template on the left).
- at_instances_of(cls: type[OtherSubtree] | tuple[type[OtherSubtree], ...], innermost: bool = False) -> Selection[OtherSubtree]
Selects subtrees that are an instance of the given type.
Convenience wrapper for:
.at_subtrees_where(lambda subtree: isinstance(subtree, cls))
- Parameters:
cls – The class (or tuple of classes) to retrieve instances of.
innermost – Whether to return the innermost instances of the class (instead of the outermost).
- Returns:
A refined selection that selects instances of this class within the original selection. If instances of this class are nested, only selects the outermost (if innermost=False) or the innermost (if innermost=True), but never both.
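The outermost versus innermost rule can be sketched with a toy search over nested Python containers. This is an illustration under simplifying assumptions, not Penzai's algorithm: the helpers children and subtrees_where are hypothetical, and they return keypaths as plain tuples instead of a Selection.

```python
def children(node):
    """List (key, child) pairs for dicts, lists and tuples; leaves have none."""
    if isinstance(node, dict):
        return list(node.items())
    if isinstance(node, (list, tuple)):
        return list(enumerate(node))
    return []


def subtrees_where(tree, pred, innermost=False, path=()):
    """Return keypaths of non-overlapping subtrees for which pred is true."""
    matches_here = pred(tree)
    if matches_here and not innermost:
        return [path]  # outermost: stop at the first match on this branch
    found = []
    for key, child in children(tree):
        found.extend(subtrees_where(child, pred, innermost, path + (key,)))
    if matches_here and not found:
        return [path]  # innermost: keep this node only if nothing deeper matched
    return found


tree = {"x": [1, [2, 3]], "y": [4]}
is_list = lambda node: isinstance(node, list)

outer = subtrees_where(tree, is_list)                  # nested list under "x" is skipped
inner = subtrees_where(tree, is_list, innermost=True)  # outer list under "x" is skipped
```

In both modes the list under "x" and the list nested inside it are never selected together, which is the "never both" guarantee described above.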
- at_keypaths(keypaths: Collection[KeyPath]) -> Selection
Selects nodes by their keypaths (relative to the current selection).
- Parameters:
keypaths – A collection of keypaths.
- Returns:
A new selection where any node whose keypath is in the given collection is selected. Note that if any path in keypaths is a prefix of another, only the shorter prefix will be used, since selected nodes cannot be nested.
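The prefix rule can be shown with a small standalone helper. This is a hypothetical sketch of the rule as stated, not the library's code: keypaths are modeled as plain tuples and drop_nested_keypaths is an illustrative name.

```python
def drop_nested_keypaths(keypaths):
    """Keep only keypaths that are not extensions of another kept keypath."""
    kept = []
    # Visit shorter paths first so that prefixes are kept before extensions.
    for path in sorted(keypaths, key=len):
        if not any(path[:len(prefix)] == prefix for prefix in kept):
            kept.append(path)
    return kept


requested = [("a", "b"), ("a",), ("c", 0)]
effective = drop_nested_keypaths(requested)
```

Here ("a", "b") is dropped because ("a",) already selects its parent.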
- at_pytree_leaves() -> Selection
Selects all PyTree leaves of each selected subtree.
This selects all of the leaves of the PyTree according to jax.tree_util, giving the most-specific selection expressible with a Selection object. (Note that, if any objects in the tree are not registered as JAX PyTree nodes, they will be selected in their entirety even if they contain children when printed out by treescope.)
- Returns:
A new selection that selects every leaf of each selected subtree.
- at_subtrees_where(filter_fn: Callable[[SelectedSubtree], bool], *, with_keypath: Literal[False] = False, absolute_keypath: bool = False, innermost: bool = False) -> Selection
- at_subtrees_where(filter_fn: Callable[[tuple[Any, ...], SelectedSubtree], bool], *, with_keypath: Literal[True], absolute_keypath: bool = False, innermost: bool = False) -> Selection
Selects subtrees of selected nodes where a function evaluates to True.
Note that a selection cannot contain a node that is the descendant of another selected node. If innermost=False, we return the outermost node, whereas if innermost=True we return the innermost.
If you want to apply a modification to all matches of a function, even if they are nested, you can use a pattern like
selection = select(value)
while not selection.is_empty():
  selection = selection.at_subtrees_where(foo).apply(
      bar, keep_selected=True)

new_value = selection.deselect()
More complex modifications can also be made using a manual traversal, e.g.:
def traverse(subtree):
  # ... process the subtree before recursive call ...
  subtree = select(subtree).at_children().apply(traverse)
  # ... process the subtree after the recursive call ...
  return subtree

new_value = traverse(value)
- Parameters:
filter_fn – A function determining which subtrees to select. Should be deterministic, and may be called more than once. This function should take a PyTree (if with_keypath=False) or a KeyPath and a PyTree (if with_keypath=True).
with_keypath – Whether to pass a keypath as the first argument to the callable.
absolute_keypath – Whether to pass the keypath relative to the root of the original tree (if True) or the keypath relative to the currently selected node (if False). Ignored if with_keypath is False.
innermost – Whether to select the innermost subtree(s) for which the filter function is true, instead of the first subtrees encountered.
- Returns:
A new selection that selects the desired subtrees.
- deselect() -> Any
Rebuilds the tree, forgetting which nodes were selected.
- Returns:
A copy of remainder with the holes filled by the values in selected_by_path. If called on an ordinary selection, this rebuilds the original tree.
- flatten_selected_selections() -> Selection[SelectedSubtree]
Flattens a selection whose selected values are all selections.
This function takes a selection for which all of the selected values are already selections, and merges them into a single selection that selects all of the values from each individual selection.
You can use this to build more complex selections by chaining your own logic. For instance, if you have written a function f that selects part of a tree, you can run selection.apply(f, keep_selected=True).flatten_selected_selections() to “broadcast” that logic across all of the already-selected subtrees in the original selection.
See also refine, which allows you to express similar transformations more directly.
- Returns:
A flattened selection object.
- get() -> SelectedSubtree
Returns the result of a singleton selection.
- Returns:
The selected subtree from this selection.
- Raises:
ValueError – If this selection does not have exactly one selected subtree.
- get_by_path() -> collections.OrderedDict[KeyPath, SelectedSubtree]
Retrieves the selected subtree(s) based on their path(s).
- Returns:
A dictionary of selected nodes, indexed by their path.
- get_sequence() -> tuple[SelectedSubtree, ...]
Gets the selected subtree(s) in order.
Convenience wrapper for .selected_by_path.values().
- Returns:
A tuple containing the selected subtrees.
- insert_after(value: Any, and_select: bool = False) -> Any
Inserts copies of value after each selected node.
- Parameters:
value – Value to insert after each selected node.
and_select – If True, selects the newly-inserted value for additional modification.
- Returns:
If and_select is False, a copy of the original tree, but with value inserted after each selected node. If and_select is True, a new selection selecting the inserted nodes.
- Raises:
ValueError – If any selected nodes are not children of a list or tuple.
- insert_before(value: Any, and_select: bool = False) -> Any
Inserts copies of value before each selected node.
- Parameters:
value – Value to insert before each selected node.
and_select – If True, selects the newly-inserted value for additional modification.
- Returns:
If and_select is False, a copy of the original tree, but with value inserted before each selected node. If and_select is True, a new selection selecting the inserted nodes.
- Raises:
ValueError – If any selected nodes are not children of a list or tuple.
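As a rough picture of what insert_after and insert_before do to a parent sequence, the sketch below works on one flat list. The function insert_around and its predicate argument are hypothetical; the real methods operate on selected nodes anywhere in a tree and can optionally return a selection of the inserted values.

```python
import copy


def insert_around(items, is_selected, value, before=False):
    """Insert a copy of value next to every selected item of a sequence."""
    out = []
    for item in items:
        if before and is_selected(item):
            out.append(copy.deepcopy(value))
        out.append(item)
        if not before and is_selected(item):
            out.append(copy.deepcopy(value))
    return type(items)(out)


is_odd = lambda x: x % 2 == 1
after = insert_around([1, 2, 3], is_odd, 0)
before = insert_around([1, 2, 3], is_odd, 0, before=True)
```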
- invert() -> Selection
Inverts a selection, selecting subtrees with no selected children.
selection.invert() selects the largest set of subtrees such that those subtrees do NOT contain any selected children in the original selection. In other words, it selects the common ancestors of all unselected nodes, without selecting any selected nodes.
- Returns:
An inverted selection.
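One way to read the inversion rule is as a top-down walk that stops at the first subtree containing no selected node. The sketch below follows that reading on plain Python containers; children and invert_paths are hypothetical helpers, and keypaths are plain tuples.

```python
def children(node):
    """List (key, child) pairs for dicts, lists and tuples; leaves have none."""
    if isinstance(node, dict):
        return list(node.items())
    if isinstance(node, (list, tuple)):
        return list(enumerate(node))
    return []


def invert_paths(tree, selected_paths, path=()):
    """Return keypaths of the largest subtrees that contain no selected node."""
    if path in selected_paths:
        return []  # selected nodes are excluded from the inverted selection
    if not any(s[:len(path)] == path for s in selected_paths):
        return [path]  # nothing selected at or below this node: take it whole
    inverted = []
    for key, child in children(tree):
        inverted.extend(invert_paths(child, selected_paths, path + (key,)))
    return inverted


tree = {"a": [1, 2], "b": 3}
inverted = invert_paths(tree, {("a", 0)})
```

With only ("a", 0) selected, the inverted selection consists of its sibling ("a", 1) and the untouched branch ("b",).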
- partition(at_leaves: bool = False) -> tuple[Any, Any]
Partitions the tree into (selected_tree, remainder_tree) parts.
This function can be used to separate out the selected components of a tree into their own separate tree, so that JAX functions and other JAX libraries can process them like ordinary PyTrees. It splits its input into two disjoint trees (selected_tree, remainder_tree), where selected_tree only contains the leaves that were selected, and remainder_tree only contains the remainder. The parts that were removed are identified using a sentinel pz.NotInThisPartition object, which has no PyTree children.
The main use case for partition is to identify subsets of models that should be treated in different ways by JAX API functions. For instance, if you want to take a gradient with respect to a specific subset of parameters, you can select those parameters, call partition to separate them from the rest, then call jax.grad and use argnums to identify the partition of interest. Similarly, if you want to donate only a subset of the state to jax.jit, you can partition it and then use JAX’s donate_argnums argument to jax.jit to identify the parts you want to donate. Inside the function, you can then use pz.combine to rebuild the original tree.
It is possible to repeatedly call partition to split a tree into more than two parts. In particular, you can select the remainder_tree, target some additional nodes, and call .partition() again, repeating this process as needed. All of the partitioned trees can then be re-combined using a single call to pz.combine.
Note that NotInThisPartition is a PyTree node with no children, which means that partitioned trees are safe to pass through JAX transformations, and the set of leaves in the two partitioned trees together is the same as the set of leaves in the original selected tree.
This function is inspired by Equinox’s equinox.partition, but is designed to work with Penzai’s selector system. Unlike equinox.partition, missing nodes are identified with the pz.NotInThisPartition sentinel, and can replace arbitrary PyTree subtrees instead of just leaves. (Partitioning is also somewhat less important in Penzai than in Equinox because all PyTree leaves are arraylike by convention; partitioning is only necessary when different parts of the tree need special treatment e.g. for argnums or donate_argnums parameters.)
- Parameters:
at_leaves – Whether to do the partitioning at the leaf level, so that the returned trees have exactly the same structure. (Note that pz.combine is OK with entire subtrees missing, so this is not necessary, but can make the partitions easier to manipulate manually if desired.) If False, the entire selected subtrees will be replaced by NotInThisPartition in the remainder tree.
- Returns:
A tuple (selected_tree, remainder_tree), where both trees have the same structure (if at_leaves=True) or the same prefix (if at_leaves=False) except that NotInThisPartition is used to replace parts that are in the other partition.
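The partition-then-combine round trip can be sketched with a toy model on nested dicts. This is not the pz.NotInThisPartition or pz.combine implementation: toy_partition, toy_combine and the MISSING sentinel are hypothetical stand-ins, the example only handles dicts, and the exact placement of sentinels in the real selected_tree may differ.

```python
class _Missing:
    """Toy stand-in for the NotInThisPartition sentinel."""

    def __repr__(self):
        return "MISSING"


MISSING = _Missing()


def toy_partition(tree, selected_paths, path=()):
    """Split a nested dict into (selected_tree, remainder_tree)."""
    if path in selected_paths:
        return tree, MISSING  # whole selected subtree leaves the remainder
    if isinstance(tree, dict):
        parts = {
            k: toy_partition(v, selected_paths, path + (k,))
            for k, v in tree.items()
        }
        selected = {k: part[0] for k, part in parts.items()}
        remainder = {k: part[1] for k, part in parts.items()}
        return selected, remainder
    return MISSING, tree  # unselected leaf stays in the remainder


def toy_combine(first, second):
    """Merge two partitions by filling each sentinel from the other tree."""
    if first is MISSING:
        return second
    if second is MISSING:
        return first
    return {k: toy_combine(first[k], second[k]) for k in first}


params = {"encoder": {"w": 1.0, "b": 0.5}, "head": {"w": 2.0}}
selected_tree, remainder_tree = toy_partition(params, {("head",)})
rebuilt = toy_combine(selected_tree, remainder_tree)
```

Each partition can then be passed to a function as a separate argument (for example, so that only one of them is differentiated), and the function can merge them back before using the full tree.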
- pick_nth_selected(n: int | Sequence[int]) -> Selection
Filters a selection to only the nth selected node.
n is taken relative to the linearized sequence of currently selected nodes, in the sense that my_selection.get_sequence()[n] == my_selection.pick_nth_selected(n).get()
- Parameters:
n – The index of the selected nodes that should remain selected. If this is a sequence, all nodes in the sequence will be selected.
- Returns:
A new selection that includes only the nth node from the original selection. For instance, if the original selection has 5 nodes selected, setting n = 1 would produce a new selection with only the node at index 1 selected.
- refine(selector_fn: Callable[[Any], Selection[OtherSubtree]]) -> Selection[OtherSubtree]
Refines a selection by selecting within each selected subtree.
Although selectors can already be refined by making additional calls, chained calls generally treat all selected subtrees the same way. In contrast, this method allows each selected node to be processed independently. Additionally, similar to apply, the additional logic is free to modify the subtree as it goes.
- Parameters:
selector_fn – A function that takes a selected subtree from this selection and returns a new selection object, usually a selection of some nodes in the input subtree.
- Returns:
A new selection that selects every node selected by selector_fn, but in the context of the original tree rather than the individual selected subtrees.
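The phrase "in the context of the original tree" can be pictured as re-rooting each inner keypath under the keypath of the subtree it came from. The sketch below models that idea only: refine_paths is a hypothetical helper, and its selector_fn returns relative keypaths (plain tuples) rather than a Selection as the real method expects.

```python
from collections import OrderedDict


def refine_paths(selected_by_path, selector_fn):
    """Re-root the keypaths chosen inside each selected subtree."""
    refined = OrderedDict()
    for outer_path, subtree in selected_by_path.items():
        for inner_path in selector_fn(subtree):
            value = subtree
            for key in inner_path:  # walk down to the inner node
                value = value[key]
            refined[outer_path + inner_path] = value
    return refined


selected_by_path = OrderedDict([
    (("layers", 0), {"w": 1, "b": 2}),
    (("layers", 1), {"w": 3}),
])

# Each subtree is handled independently; here every one picks its "w" entry.
refined = refine_paths(selected_by_path, lambda subtree: [("w",)])
```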
- remove_from_parent() -> Any
Removes selected nodes from their parent sequence (a list or tuple).
This is a convenience wrapper for .apply_and_inline(lambda x: ()).
- Returns:
A copy of the original tree, but with all selected nodes removed from their parents.
- Raises:
ValueError – If any selected nodes are not children of a list or tuple.
- select_and_set_by_path(replacements_by_path: dict[KeyPath, Any]) -> Any
Selects subtrees and replaces them based on relative keypaths.
Convenience method that combines at_keypaths and set_by_path.
- Parameters:
replacements_by_path – A mapping from key paths to replacements. Key paths are relative to the current selected nodes.
- Returns:
A modified version of the original tree, with replacements taken from replacements_by_path.
- set(replacement: Any) -> Any
Replaces the selected subtree(s) with a fixed replacement.
- Parameters:
replacement – The PyTree to use as the replacement.
- Returns:
A modified version of the original tree, with this replacement in place of any selected subtrees.
- set_by_path(replacer: Mapping[KeyPath, Any] | Callable[[KeyPath], Any]) -> Any
Replaces the selected subtree(s) based on their path(s).
If you need both the value and the key, see .apply(fn, with_keypath=True).
- Parameters:
replacer – A mapping from key paths to replacements, or a function that computes the replacement for a given key path. Passing self.selected_by_path will return the original tree unchanged.
- Returns:
A modified version of the original tree, with replacements taken from the replacer.
- set_sequence(replacements: Iterable[Any]) -> Any
Replaces the selected subtree(s) in order.
- Parameters:
replacements – An iterable of PyTrees to insert at the selected locations, in order.
- Returns:
A modified version of the original tree, with replacements taken from the iterable.
- show_selection(ignore_exceptions: bool = False)
Renders the selection in IPython.
This method is intended to visualize the selection object itself, and renders boxes around the selected nodes.
This method should only be used when IPython is available.
- Parameters:
ignore_exceptions – Whether to catch errors during rendering and show a fallback for those subtrees.
- show_value(ignore_exceptions: bool = False)
Renders the original tree in IPython, expanding up to the selected nodes.
This method is intended to visualize a value while emphasizing the selected parts: the selection determines what to focus on by default, but is not itself the object of interest.
This method should only be used when IPython is available.
- Parameters:
ignore_exceptions – Whether to catch errors during rendering and show a fallback for those subtrees.
- where(filter_fn: Callable[[SelectedSubtree], bool], *, with_keypath: Literal[False] = False) -> Selection[SelectedSubtree]
- where(filter_fn: Callable[[tuple[Any, ...], SelectedSubtree], bool], *, with_keypath: Literal[True]) -> Selection[SelectedSubtree]
Filters to only a subset of selected nodes based on a condition.
- Parameters:
filter_fn – Function to determine whether to keep a node in the selection. This function should take a PyTree (if with_keypath=False) or a KeyPath and a PyTree (if with_keypath=True).
with_keypath – Whether to pass the keypath as the first argument to the callable.
- Returns:
A new selection that includes only the selected parts where filter_fn evaluates to true.