pz: Penzai’s alias namespace

Structs

Most objects in Penzai models are subclasses of pz.Struct and decorated with pz.pytree_dataclass, which makes them into frozen Python dataclasses that are also JAX PyTrees.

pz.pytree_dataclass([cls, ...])

Alias of penzai.core.struct.pytree_dataclass: Decorator for constructing a frozen PyTree dataclass.

pz.Struct

Alias of penzai.core.struct.Struct: Base class for penzai PyTree structures.

PyTree Manipulation

Penzai provides a number of utilities for making targeted modifications to PyTrees. Since Penzai models are PyTrees, you can use these utilities to insert new layers into models or to modify the configuration of existing layers.

pz.select(tree)

Alias of penzai.core.selectors.select: Wraps a tree in a singleton selection for processing.

pz.Selection

Alias of penzai.core.selectors.Selection: A selected subset of nodes within a larger PyTree.

pz.combine(*partitions)

Alias of penzai.core.partitioning.combine: Combines leaves from multiple partitions.

pz.NotInThisPartition

Alias of penzai.core.partitioning.NotInThisPartition: Sentinel object that identifies subtrees removed by partition.

pz.pretty_keystr(keypath, tree)

Alias of penzai.core.tree_util.pretty_keystr: Constructs a pretty name from a keypath and an object.

Named Axes

pz.nx is an alias for penzai.core.named_axes, which contains Penzai’s named axis system. Some commonly used attributes on pz.nx:

pz.nx.NamedArray

Alias of penzai.core.named_axes.NamedArray: A multidimensional array with a combination of positional and named axes.

pz.nx.nmap(fun)

Alias of penzai.core.named_axes.nmap: Automatically vectorizes fun over named axes of NamedArray inputs.

pz.nx.wrap(array, *names)

Alias of penzai.core.named_axes.NamedArray.wrap: Wraps a positional array as a NamedArray.

See penzai.core.named_axes for documentation of all of the methods and classes accessible through the pz.nx alias.

To simplify slicing named axes, Penzai also provides a helper object:

pz.slice

Builds a slice when sliced (e.g. pz.slice[1:3] == slice(1, 3, None)).

Parameters and State Variables

Penzai handles mutable state by embedding stateful parameters and variables into JAX pytrees. It provides a number of utilities to manipulate these stateful components and support passing them across JAX transformation boundaries.

pz.Parameter

Alias of penzai.core.variables.Parameter: A model parameter variable.

pz.ParameterValue

Alias of penzai.core.variables.ParameterValue: The value of a Parameter, as a frozen JAX pytree.

pz.ParameterSlot

Alias of penzai.core.variables.ParameterSlot: A slot for a parameter in a model.

pz.StateVariable

Alias of penzai.core.variables.StateVariable: A mutable state variable.

pz.StateVariableValue

Alias of penzai.core.variables.StateVariableValue: The value of a StateVariable, as a frozen JAX pytree.

pz.StateVariableSlot

Alias of penzai.core.variables.StateVariableSlot: A slot for a state variable in a model.

pz.unbind_variables()

Alias of penzai.core.variables.unbind_variables: Unbinds variables from a pytree, inserting variable slots in their place.

pz.bind_variables(tree, variables[, ...])

Alias of penzai.core.variables.bind_variables: Binds variables (mutable or frozen) into the variable slots in a pytree.

pz.freeze_variables(tree[, predicate])

Alias of penzai.core.variables.freeze_variables: Replaces each variable in a pytree with a frozen copy.

pz.variable_jit(fun, *[, donate_variables])

Alias of penzai.core.variables.variable_jit: Variable-aware version of jax.jit.

pz.unbind_params()

Alias of penzai.core.variables.unbind_params: Version of unbind_variables that only extracts Parameters.

pz.freeze_params(tree[, predicate])

Alias of penzai.core.variables.freeze_params: Version of freeze_variables that only freezes Parameters.

pz.unbind_state_vars()

Alias of penzai.core.variables.unbind_state_vars: Version of unbind_variables that only extracts StateVariables.

pz.freeze_state_vars(tree[, predicate])

Alias of penzai.core.variables.freeze_state_vars: Version of freeze_variables that only freezes StateVariables.

pz.VariableConflictError

Alias of penzai.core.variables.VariableConflictError: Raised when a Variable label is used by multiple Variables.

pz.UnboundVariableError

Alias of penzai.core.variables.UnboundVariableError: Raised when attempting to access the value of an unbound variable.

pz.VariableLabel

Alias of typing.Hashable: The type of a variable label; any hashable value can be used as a label.

pz.AbstractVariable

Alias of penzai.core.variables.AbstractVariable: Base class for all variables.

pz.AbstractVariableValue

Alias of penzai.core.variables.AbstractVariableValue: Base class for all frozen variables.

pz.AbstractVariableSlot

Alias of penzai.core.variables.AbstractVariableSlot: Base class for all variable slots.

pz.AutoStateVarLabel

Alias of penzai.core.variables.AutoStateVarLabel: A label for a StateVariable that is unique based on its Python object ID.

pz.ScopedStateVarLabel

Alias of penzai.core.variables.ScopedStateVarLabel: A label for a StateVariable that is unique within some scope.

pz.scoped_auto_state_var_labels([group])

Alias of penzai.core.variables.scoped_auto_state_var_labels: Context manager for using scoped auto-generated StateVariable labels.

pz.RandomStream

Alias of penzai.core.random_stream.RandomStream: A stateful random stream object.

Neural Networks

pz.nn is an alias namespace for Penzai’s declarative neural network system, which uses a combinator-based design to expose all of your model’s operations as nodes in your model PyTree. pz.nn re-exports layers from submodules of penzai.nn in a single convenient namespace.

See the documentation for pz.nn to view all of the methods and classes accessible through this alias namespace.

Shape-Checking

pz.chk is an alias for penzai.core.shapecheck, which contains utilities for checking the shapes of PyTrees of positional and named arrays. Some commonly used attributes on pz.chk:

pz.chk.ArraySpec

Alias of penzai.core.shapecheck.ArraySpec: A non-leaf marker for a (named) array structure.

pz.chk.var(name)

Alias of penzai.core.shapecheck.var: Creates a variable for an axis shape.

pz.chk.vars_for_axes(var_name, ...)

Alias of penzai.core.shapecheck.vars_for_axes: Creates variables for a known collection of named axes.

See penzai.core.shapecheck for documentation of all of the methods and classes accessible through the pz.chk alias.

Dataclass and Struct Utilities

pz.is_pytree_dataclass_type(cls)

Alias of penzai.core.struct.is_pytree_dataclass_type: Checks if a class was wrapped in the pytree_dataclass decorator.

pz.is_pytree_node_field(field)

Alias of penzai.core.struct.is_pytree_node_field: Returns True if this field is treated as a PyTree child node by Struct.

pz.StructStaticMetadata

Alias of penzai.core.struct.StructStaticMetadata: Container for a struct's static fields.

pz.PyTreeDataclassSafetyError

Alias of penzai.core.struct.PyTreeDataclassSafetyError: Error raised due to pytree dataclass safety checks.

Rendering and Global Configuration Management

These utilities are available in the pz namespace for backwards compatibility. However, they have been moved to the separate Treescope pretty-printing package. See the Treescope documentation for more information.

pz.ts

Alias of penzai.pz.ts: Common aliases for treescope.

pz.show(*args[, wrap, space_separated])

Alias of penzai.treescope._compatibility_setup.show: Shows a list of objects inline, like Python's print, but with rich display.

pz.ContextualValue

Alias of treescope.context.ContextualValue: A global value which can be modified in a scoped context.

pz.oklch_color(lightness, chroma, hue[, alpha])

Alias of treescope.formatting_util.oklch_color: Constructs an OKLCH CSS color.

pz.color_from_string(key_string[, ...])

Alias of treescope.formatting_util.color_from_string: Derives a color whose hue is keyed by a string.

pz.dataclass_from_attributes(cls, **field_values)

Alias of treescope.dataclass_util.dataclass_from_attributes: Directly instantiates a dataclass given all of its fields.

pz.init_takes_fields(cls)

Alias of treescope.dataclass_util.init_takes_fields: Returns True if cls.__init__ takes exactly one argument per field.