SelectionHole
- class penzai.core.selectors.SelectionHole
Bases: Struct
A hole in a structure, taking the place of a selected subtree.
When building a selection, the selected nodes are moved out of the original tree for easier processing. Each one is replaced by a SelectionHole, which records the position of the node that was originally there.
A SelectionHole is a PyTree with no children. This ensures that the selected elements are actually “removed” from the tree from JAX’s point of view.
Users should not need to create a SelectionHole directly, and should instead use the select(...) function and other selector traversals. However, you may see a SelectionHole when inspecting the contents of a Selection object.
Note that the Selection machinery assumes that selected PyTree nodes do not require their children to be of a specific type, so that a SelectionHole can be inserted anywhere in the tree. If a node makes strong assumptions about the types of its children, it may not be possible to select those children, since rebuilding that node with a SelectionHole in place of a child may fail. See https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization for guidance on implementing custom PyTrees that avoid this problem.
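As a rough illustration of the mechanism described above, the following toy model in plain Python (no JAX or Penzai required) moves matching subtrees of a nested dict/list structure into a path-keyed dictionary and leaves a hole marker in their place. The names ToyHole, toy_select, and leaves are hypothetical and are not part of Penzai's API; paths are assumed to be plain tuples of keys rather than JAX key-path entries. Because leaves treats a hole as a node with no children, the selected values no longer appear when the remainder is flattened.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyHole:
    """Hypothetical stand-in for SelectionHole: records only where a subtree was."""
    path: tuple  # assumed: a tuple of dict keys / sequence indices


def toy_select(tree: Any, predicate: Callable[[Any], bool], path: tuple = ()):
    """Returns (remainder, selected_by_path).

    Every subtree matching `predicate` is moved into `selected_by_path`
    and replaced in the remainder by a ToyHole carrying its path.
    """
    if predicate(tree):
        return ToyHole(path), {path: tree}
    selected = {}
    if isinstance(tree, dict):
        out = {}
        for key, child in tree.items():
            new_child, sub = toy_select(child, predicate, path + (key,))
            out[key] = new_child
            selected.update(sub)
        return out, selected
    if isinstance(tree, (list, tuple)):
        items = []
        for i, child in enumerate(tree):
            new_child, sub = toy_select(child, predicate, path + (i,))
            items.append(new_child)
            selected.update(sub)
        return type(tree)(items), selected
    return tree, selected  # an unselected leaf stays in place


def leaves(tree: Any) -> list:
    """Flattens a tree; a ToyHole has no children, so it contributes no leaves."""
    if isinstance(tree, ToyHole):
        return []
    if isinstance(tree, dict):
        return [leaf for child in tree.values() for leaf in leaves(child)]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in leaves(child)]
    return [tree]


tree = {"w": [1.0, 2.0], "b": 3.0, "name": "layer"}
remainder, selected = toy_select(tree, lambda x: isinstance(x, list))
print(remainder)          # {'w': ToyHole(path=('w',)), 'b': 3.0, 'name': 'layer'}
print(selected)           # {('w',): [1.0, 2.0]}
print(leaves(remainder))  # [3.0, 'layer']
```

In Penzai itself you would not build holes by hand; as the text says, select(...) and the selector traversals do this for you.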
- Variables:
path (KeyPath) – Key path to this hole, used to index back into the selected components.
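The path field is what allows the selected components to be put back. A minimal sketch, assuming the same kind of toy representation (a hypothetical ToyHole holding a tuple path, a remainder tree of dicts, lists, and tuples, and a dict mapping paths to selected values): each hole is replaced by the value stored under its path, optionally transformed first, which is roughly what a functional-style update has to do. fill_holes is a hypothetical helper, not a Penzai function.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyHole:
    """Hypothetical stand-in for SelectionHole: holds only a key path."""
    path: tuple


def fill_holes(remainder: Any, selected_by_path: dict,
               fn: Callable[[Any], Any] = lambda x: x) -> Any:
    """Rebuilds a tree, replacing each ToyHole with fn(selected_by_path[hole.path])."""
    if isinstance(remainder, ToyHole):
        # The hole's path is the key used to look up the value it replaced.
        return fn(selected_by_path[remainder.path])
    if isinstance(remainder, dict):
        return {k: fill_holes(v, selected_by_path, fn) for k, v in remainder.items()}
    if isinstance(remainder, (list, tuple)):
        return type(remainder)(fill_holes(c, selected_by_path, fn) for c in remainder)
    return remainder  # ordinary leaves are left untouched


remainder = {"w": ToyHole(("w",)), "b": [ToyHole(("b", 0)), 5]}
selected_by_path = {("w",): 2, ("b", 0): 10}
print(fill_holes(remainder, selected_by_path))                    # {'w': 2, 'b': [10, 5]}
print(fill_holes(remainder, selected_by_path, lambda x: x * 10))  # {'w': 20, 'b': [100, 5]}
```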
Methods
__init__(path)
Attributes
path
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
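The inherited tree_flatten and tree_unflatten methods follow JAX's custom-PyTree protocol of (children, aux_data) pairs. The sketch below mimics that protocol in plain Python, without registering anything with JAX, to illustrate two points from the class description: a hole flattens to zero children with its path carried as auxiliary data, and a node whose constructor validates child types cannot be rebuilt with a hole in place of a child. ToyHole, LenientPair, StrictPair, and rebuild_with_hole are all hypothetical names, and this is not how Struct is actually implemented.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyHole:
    """Hypothetical SelectionHole stand-in following the (children, aux_data) protocol."""
    path: tuple

    def tree_flatten(self):
        # No children: the path travels as aux_data, so flattening yields no leaves.
        return (), self.path

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


class LenientPair:
    """Accepts any child values, so a hole can stand in for either child."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def tree_flatten(self):
        return (self.left, self.right), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


class StrictPair(LenientPair):
    """Validates child types in __init__, so rebuilding it around a hole fails."""

    def __init__(self, left, right):
        if not all(isinstance(x, float) for x in (left, right)):
            raise TypeError("StrictPair children must be floats")
        super().__init__(left, right)


def rebuild_with_hole(node, index):
    """Flattens `node`, swaps child `index` for a ToyHole, and unflattens it again."""
    children, aux = node.tree_flatten()
    children = list(children)
    children[index] = ToyHole((index,))
    return type(node).tree_unflatten(aux, children)


print(ToyHole(("w",)).tree_flatten())                    # ((), ('w',))
print(rebuild_with_hole(LenientPair(1.0, 2.0), 0).left)  # ToyHole(path=(0,))
try:
    rebuild_with_hole(StrictPair(1.0, 2.0), 0)
except TypeError as err:
    print("failed:", err)  # failed: StrictPair children must be floats
```

Keeping type checks out of the constructor, as in LenientPair, is the pattern the linked JAX guidance recommends for custom PyTrees.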