Simplified V2 API (penzai.experimental.v2)

penzai.experimental.v2 is a redesign of Penzai’s neural network system, intended to simplify the user experience and remove boilerplate from the original design. Eventually, it will move out of the “experimental” namespace and replace the original neural network components.

You can read about the V2 design in the guide “How to Think in Penzai (v2 API)”, or in the document comparing the two APIs.

To use the V2 API, we suggest importing the pz alias namespace from penzai.experimental.v2.pz:

from penzai.experimental.v2 import pz

The rest of this page lists the main components used in the V2 API.

Specific to the V2 Neural Network API

Parameters and State Variables

The V2 API introduces stateful parameters and state variables, which simplify working with shared parameters and interventions with side effects.

pz.Parameter

Alias of penzai.experimental.v2.core.variables.Parameter: A model parameter variable.

pz.ParameterValue

Alias of penzai.experimental.v2.core.variables.ParameterValue: The value of a Parameter, as a frozen JAX pytree.

pz.ParameterSlot

Alias of penzai.experimental.v2.core.variables.ParameterSlot: A slot for a parameter in a model.

pz.StateVariable

Alias of penzai.experimental.v2.core.variables.StateVariable: A mutable state variable.

pz.StateVariableValue

Alias of penzai.experimental.v2.core.variables.StateVariableValue: The value of a StateVariable, as a frozen JAX pytree.

pz.StateVariableSlot

Alias of penzai.experimental.v2.core.variables.StateVariableSlot: A slot for a state variable in a model.

pz.unbind_variables()

Alias of penzai.experimental.v2.core.variables.unbind_variables: Unbinds variables from a pytree, inserting variable slots in their place.

pz.bind_variables(tree, variables[, ...])

Alias of penzai.experimental.v2.core.variables.bind_variables: Binds variables (mutable or frozen) into the variable slots in a pytree.

pz.freeze_variables(tree[, predicate])

Alias of penzai.experimental.v2.core.variables.freeze_variables: Replaces each variable in a pytree with a frozen copy.

pz.variable_jit(fun, *[, donate_variables])

Alias of penzai.experimental.v2.core.variables.variable_jit: Variable-aware version of jax.jit.

pz.unbind_params()

Alias of penzai.experimental.v2.core.variables.unbind_params: Version of unbind_variables that only extracts Parameters.

pz.freeze_params(tree[, predicate])

Alias of penzai.experimental.v2.core.variables.freeze_params: Version of freeze_variables that only freezes Parameters.

pz.unbind_state_vars()

Alias of penzai.experimental.v2.core.variables.unbind_state_vars: Version of unbind_variables that only extracts StateVariables.

pz.freeze_state_vars(tree[, predicate])

Alias of penzai.experimental.v2.core.variables.freeze_state_vars: Version of freeze_variables that only freezes StateVariables.

pz.VariableConflictError

Alias of penzai.experimental.v2.core.variables.VariableConflictError: Raised when a Variable label is used by multiple Variables.

pz.UnboundVariableError

Alias of penzai.experimental.v2.core.variables.UnboundVariableError: Raised when attempting to access the value of an unbound variable.

pz.VariableLabel

Alias of typing.Hashable: The type of a variable label, which can be any hashable value.

pz.AbstractVariable

Alias of penzai.experimental.v2.core.variables.AbstractVariable: Base class for all variables.

pz.AbstractVariableValue

Alias of penzai.experimental.v2.core.variables.AbstractVariableValue: Base class for all frozen variables.

pz.AbstractVariableSlot

Alias of penzai.experimental.v2.core.variables.AbstractVariableSlot: Base class for all variable slots.

pz.AutoStateVarLabel

Alias of penzai.experimental.v2.core.variables.AutoStateVarLabel: A label for a StateVariable that is unique based on its Python object ID.

pz.ScopedStateVarLabel

Alias of penzai.experimental.v2.core.variables.ScopedStateVarLabel: A label for a StateVariable that is unique within some scope.

pz.scoped_auto_state_var_labels([group])

Alias of penzai.experimental.v2.core.variables.scoped_auto_state_var_labels: Context manager for using scoped auto-generated StateVariable labels.

pz.RandomStream

Alias of penzai.experimental.v2.core.random_stream.RandomStream: A stateful random stream object.
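The unbind/bind pattern listed above can be hard to picture from one-line summaries, so here is a minimal plain-Python sketch of the idea. It does not use Penzai and is not Penzai's implementation: the names Var, Slot, unbind, and bind are hypothetical stand-ins, the tree is limited to nested dicts and lists, and the return types are assumptions chosen for brevity. It shows variables being swapped for labeled slots, a shared variable being collected once, and a label conflict being rejected.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Var:
    """Hypothetical mutable variable, identified by its label."""
    label: str
    value: float


@dataclass(frozen=True)
class Slot:
    """Hypothetical placeholder left behind where a variable was unbound."""
    label: str


def unbind(tree):
    """Swap every Var in nested dicts/lists for a Slot; return (tree, vars)."""
    found = {}

    def go(node):
        if isinstance(node, Var):
            # The same Var object may appear many times (a shared parameter),
            # but two different Var objects must not reuse one label.
            if found.setdefault(node.label, node) is not node:
                raise ValueError(f"label {node.label!r} used by two variables")
            return Slot(node.label)
        if isinstance(node, dict):
            return {key: go(child) for key, child in node.items()}
        if isinstance(node, list):
            return [go(child) for child in node]
        return node

    return go(tree), found


def bind(tree, variables):
    """Put variables back into the matching slots of a nested dict/list tree."""
    if isinstance(tree, Slot):
        return variables[tree.label]
    if isinstance(tree, dict):
        return {key: bind(child, variables) for key, child in tree.items()}
    if isinstance(tree, list):
        return [bind(child, variables) for child in tree]
    return tree


shared = Var("w", 1.0)
model = {"encoder": shared, "decoder": [shared, Var("b", 2.0)]}

slotted, variables = unbind(model)
variables["w"].value = 5.0  # one update is visible everywhere "w" is bound
rebound = bind(slotted, variables)

try:
    unbind([Var("x", 1.0), Var("x", 2.0)])
    conflict_detected = False
except ValueError:
    conflict_detected = True
```

Because the slotted tree holds only immutable placeholders, it can be passed around as ordinary data while the mutable variables are handled separately, which is the property the summaries above rely on.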

Layers and Parameter Utilities

pz.nn.Layer

Alias of penzai.experimental.v2.nn.layer.Layer: Abstract base class for neural network layers and other 1-arg callables.

pz.nn.ParameterLike

Alias of penzai.experimental.v2.nn.parameters.ParameterLike: Protocol for a parameter-like object.

pz.nn.derive_param_key(base_rng, name)

Alias of penzai.experimental.v2.nn.parameters.derive_param_key: Derives a PRNG key for a parameter from a base key and a name.

pz.nn.make_parameter(name, init_base_rng, ...)

Alias of penzai.experimental.v2.nn.parameters.make_parameter: Makes a parameter variable (or slot) with a given name and initializer.

pz.nn.assert_no_parameter_slots(model)

Alias of penzai.experimental.v2.nn.parameters.assert_no_parameter_slots: Asserts that the given model has no ParameterSlot subtrees.

Basic Combinators

pz.nn.Sequential

Alias of penzai.experimental.v2.nn.grouping.Sequential: A group of layers to call sequentially.

pz.nn.NamedGroup

Alias of penzai.experimental.v2.nn.grouping.NamedGroup: A layer that names an activation or a sequence of layers.

pz.nn.CheckedSequential

Alias of penzai.experimental.v2.nn.grouping.CheckedSequential: A group of layers to call sequentially, with known input/output types.

pz.nn.Residual

Alias of penzai.experimental.v2.nn.combinators.Residual: A residual block, which runs its sublayers then adds the input.

pz.nn.BranchAndAddTogether

Alias of penzai.experimental.v2.nn.combinators.BranchAndAddTogether: A data-flow branch with additive interactions between branches.

pz.nn.BranchAndMultiplyTogether

Alias of penzai.experimental.v2.nn.combinators.BranchAndMultiplyTogether: A data-flow branch with multiplicative interactions between branches.

pz.nn.inline_anonymous_sequentials(tree)

Alias of penzai.experimental.v2.nn.grouping.inline_anonymous_sequentials: Inlines instances of Sequential (not subclasses) into parent groups.

pz.nn.inline_groups(tree, parent_filter, ...)

Alias of penzai.experimental.v2.nn.grouping.inline_groups: Inlines sequential nodes into their parents if possible.

pz.nn.is_anonymous_sequential(tree)

Alias of penzai.experimental.v2.nn.grouping.is_anonymous_sequential: Checks if the type of a node is exactly Sequential, not a named subclass.

pz.nn.is_sequential_or_named(tree)

Alias of penzai.experimental.v2.nn.grouping.is_sequential_or_named: Checks if a tree is a subclass of Sequential or a NamedGroup.
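As a rough illustration of the Sequential and Residual summaries above, the sketch below treats a layer as any one-argument callable and builds the two combinators in plain Python. The function names sequential and residual are hypothetical and this is not Penzai's code; it only mirrors the described behavior ("call sequentially" and "runs its sublayers then adds the input").

```python
def sequential(*layers):
    """Return a callable that feeds its input through each layer in order."""
    def run(value):
        for layer in layers:
            value = layer(value)
        return value
    return run


def residual(*layers):
    """Return a callable that runs the layers, then adds the original input."""
    body = sequential(*layers)

    def run(value):
        return value + body(value)
    return run


double = lambda v: v * 2
add_one = lambda v: v + 1

plain_out = sequential(double, add_one)(3)    # (3 * 2) + 1
residual_out = residual(double, add_one)(3)   # 3 + ((3 * 2) + 1)
```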

Basic Operations

pz.nn.Elementwise

Alias of penzai.experimental.v2.nn.basic_ops.Elementwise: A layer that runs an elementwise operation on its NamedArray argument.

pz.nn.Softmax

Alias of penzai.experimental.v2.nn.basic_ops.Softmax: Layer that applies a softmax along a given set of axes.

pz.nn.CheckStructure

Alias of penzai.experimental.v2.nn.grouping.CheckStructure: A layer that checks the structure of the value passing through it.

pz.nn.Identity

Alias of penzai.experimental.v2.nn.grouping.Identity: A layer that returns its input unchanged, without any side effects.

pz.nn.CastToDType

Alias of penzai.experimental.v2.nn.basic_ops.CastToDType: Casts an input to a given dtype.

Linear and Affine Layers

pz.nn.Linear

Alias of penzai.experimental.v2.nn.linear_and_affine.Linear: A generalized linear (not affine) operator, for named arrays.

pz.nn.RenameAxes

Alias of penzai.experimental.v2.nn.linear_and_affine.RenameAxes: Convenience layer that renames axes of its input.

pz.nn.AddBias

Alias of penzai.experimental.v2.nn.linear_and_affine.AddBias: Shifts an input by a learnable offset (a bias).

pz.nn.Affine

Alias of penzai.experimental.v2.nn.linear_and_affine.Affine: Affine layer: combination of Linear and AddBias.

pz.nn.ConstantRescale

Alias of penzai.experimental.v2.nn.linear_and_affine.ConstantRescale: Applies a constant scaling factor.

pz.nn.NamedEinsum

Alias of penzai.experimental.v2.nn.linear_and_affine.NamedEinsum: An Einsum operation that contracts based on axis names.

pz.nn.LinearInPlace

Alias of penzai.experimental.v2.nn.linear_and_affine.LinearInPlace: Container for "in-place" linear operators that preserve axis names.

pz.nn.LinearOperatorWeightInitializer

Alias of penzai.experimental.v2.nn.linear_and_affine.LinearOperatorWeightInitializer: Protocol for an initializer for a general linear NamedArray weight.

pz.nn.contract(names, left, right)

Alias of penzai.experimental.v2.nn.linear_and_affine.contract: Contracts two named arrays along the given axis names.

pz.nn.variance_scaling_initializer(key, *, ...)

Alias of penzai.experimental.v2.nn.linear_and_affine.variance_scaling_initializer: Generic variance scaling initializer.

pz.nn.xavier_normal_initializer(key, *[, ...])

A variance scaling initializer preconfigured with Xavier (Glorot) normal settings; see pz.nn.variance_scaling_initializer.

pz.nn.xavier_uniform_initializer(key, *[, ...])

A variance scaling initializer preconfigured with Xavier (Glorot) uniform settings; see pz.nn.variance_scaling_initializer.

pz.nn.constant_initializer(value)

Alias of penzai.experimental.v2.nn.linear_and_affine.constant_initializer: Returns an initializer that uses a constant value.

pz.nn.zero_initializer(key, *, input_axes, ...)

Zeros initializer for named arrays.

Standardization

pz.nn.LayerNorm

Alias of penzai.experimental.v2.nn.standardization.LayerNorm: Layer normalization layer.

pz.nn.Standardize

Alias of penzai.experimental.v2.nn.standardization.Standardize: Standardization layer.

pz.nn.RMSLayerNorm

Alias of penzai.experimental.v2.nn.standardization.RMSLayerNorm: Root-mean-squared layer normalization layer.

pz.nn.RMSStandardize

Alias of penzai.experimental.v2.nn.standardization.RMSStandardize: Root-mean-squared standardization layer.

Dropout

pz.nn.StochasticDropout

Alias of penzai.experimental.v2.nn.dropout.StochasticDropout: Stochastic dropout layer.

pz.nn.DisabledDropout

Alias of penzai.experimental.v2.nn.dropout.DisabledDropout: A no-op layer taking the place of a disabled StochasticDropout layer.

pz.nn.maybe_dropout(drop_rate[, ...])

Alias of penzai.experimental.v2.nn.dropout.maybe_dropout: Constructs either a stochastic or disabled dropout layer.

Language Modeling

pz.nn.Attention

Alias of penzai.experimental.v2.nn.attention.Attention: A basic attention combinator.

pz.nn.KVCachingAttention

Alias of penzai.experimental.v2.nn.attention.KVCachingAttention: Key/value caching variant of Attention.

pz.nn.ApplyExplicitAttentionMask

Alias of penzai.experimental.v2.nn.attention.ApplyExplicitAttentionMask: Applies an explicit attention mask to its input logit array.

pz.nn.ApplyCausalAttentionMask

Alias of penzai.experimental.v2.nn.attention.ApplyCausalAttentionMask: Builds and applies a causal attention mask based on token positions.

pz.nn.ApplyCausalSlidingWindowAttentionMask

Alias of penzai.experimental.v2.nn.attention.ApplyCausalSlidingWindowAttentionMask: Builds and applies a sliding-window attention mask based on token positions.

pz.nn.EmbeddingTable

Alias of penzai.experimental.v2.nn.embeddings.EmbeddingTable: A table of embedding vectors for a vocabulary of tokens.

pz.nn.EmbeddingLookup

Alias of penzai.experimental.v2.nn.embeddings.EmbeddingLookup: Looks up token IDs in an embedding table.

pz.nn.EmbeddingDecode

Alias of penzai.experimental.v2.nn.embeddings.EmbeddingDecode: Uses an embedding table to map embeddings back to token scores.

pz.nn.ApplyRoPE

Alias of penzai.experimental.v2.nn.embeddings.ApplyRoPE: Adjusts input embeddings using rotary position embeddings (RoPE).

Layer Stacks

pz.nn.LayerStack

Alias of penzai.experimental.v2.nn.layer_stack.LayerStack: A sequence of layers with identical structure, called under jax.lax.scan.

pz.nn.LayerStackVarBehavior

Alias of penzai.experimental.v2.nn.layer_stack.LayerStackVarBehavior: Behavior of a variable in a layer stack.

pz.nn.layerstack_axes_from_keypath(keypath)

Alias of penzai.experimental.v2.nn.layer_stack.layerstack_axes_from_keypath: Extracts the stacked axes from a keypath.

pz.nn.LayerStackGetAttrKey

Alias of penzai.experimental.v2.nn.layer_stack.LayerStackGetAttrKey: GetAttrKey for LayerStack with extra metadata.

Core utilities, shared with the V1 API

Structs and Layers

Most objects in Penzai models are subclasses of pz.Struct and decorated with pz.pytree_dataclass, which makes them into frozen Python dataclasses that are also JAX PyTrees.

pz.pytree_dataclass([cls, ...])

Alias of penzai.core.struct.pytree_dataclass: Decorator for constructing a frozen PyTree dataclass.

pz.Struct

Alias of penzai.core.struct.Struct: Base class for penzai PyTree structures.
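To make "frozen dataclasses that are also PyTrees" more concrete, here is a standard-library-only sketch of the underlying idea: an immutable dataclass whose fields can be flattened into children plus static metadata and then rebuilt. It uses neither Penzai nor JAX. The helpers is_child_field, flatten, and unflatten are hypothetical, and the "pytree_node" metadata key is an assumption made for this illustration.

```python
import dataclasses


def is_child_field(field):
    """Treat a field as a tree child unless its metadata opts out (assumed key)."""
    return field.metadata.get("pytree_node", True)


@dataclasses.dataclass(frozen=True)
class Scale:
    factor: float
    name: str = dataclasses.field(default="scale", metadata={"pytree_node": False})


def flatten(obj):
    """Split a dataclass instance into (children, static data needed to rebuild)."""
    children, static = [], {}
    for field in dataclasses.fields(obj):
        if is_child_field(field):
            children.append(getattr(obj, field.name))
        else:
            static[field.name] = getattr(obj, field.name)
    return children, (type(obj), static)


def unflatten(aux, children):
    """Rebuild an instance from static data and a new list of children."""
    cls, static = aux
    names = [f.name for f in dataclasses.fields(cls) if is_child_field(f)]
    return cls(**dict(zip(names, children)), **static)


layer = Scale(factor=2.0, name="s")
children, aux = flatten(layer)
doubled = unflatten(aux, [child * 2 for child in children])

# Frozen dataclasses reject in-place mutation, so updates build new objects.
try:
    layer.factor = 3.0
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False
```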

PyTree Manipulation

Penzai provides a number of utilities to make targeted modifications to PyTrees. Since Penzai models are PyTrees, you can use them to insert new layers into models, or modify the configuration of existing layers.

pz.select(tree)

Alias of penzai.core.selectors.select: Wraps a tree in a singleton selection for processing.

pz.Selection

Alias of penzai.core.selectors.Selection: A selected subset of nodes within a larger PyTree.

pz.combine(*partitions)

Alias of penzai.core.partitioning.combine: Combines leaves from multiple partitions.

pz.NotInThisPartition

Alias of penzai.core.partitioning.NotInThisPartition: Sentinel object that identifies subtrees removed by partition.

pz.pretty_keystr(keypath, tree)

Alias of penzai.core.tree_util.pretty_keystr: Constructs a pretty name from a keypath and an object.

Named Axes

pz.nx is an alias for penzai.core.named_axes, which contains Penzai’s named axis system. Some commonly-used attributes on pz.nx:

pz.nx.NamedArray

Alias of penzai.core.named_axes.NamedArray: A multidimensional array with a combination of positional and named axes.

pz.nx.nmap(fun)

Alias of penzai.core.named_axes.nmap: Automatically vectorizes fun over named axes of NamedArray inputs.

pz.nx.wrap(array, *names)

Alias of penzai.core.named_axes.NamedArray.wrap: Wraps a positional array as a NamedArray.

See penzai.core.named_axes for documentation of all of the methods and classes accessible through the pz.nx alias.

To simplify slicing named axes, Penzai also provides a helper object:

pz.slice

Builds a slice when sliced (e.g. pz.slice[1:3] == slice(1, 3, None)).
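The behavior described for pz.slice can be reproduced with a very small object whose __getitem__ returns the index it receives. The sketch below is a plain-Python stand-in with a hypothetical class name, not Penzai's own definition; it relies only on the fact that Python turns the syntax obj[1:3] into a call with slice(1, 3, None).

```python
class SliceBuilder:
    """Hypothetical stand-in: indexing it returns the index expression itself."""

    def __getitem__(self, index):
        return index


slice_builder = SliceBuilder()

built = slice_builder[1:3]          # Python passes slice(1, 3, None)
strided = slice_builder[::2]        # Python passes slice(None, None, 2)
items = list(range(10))[built]      # the result works like any slice object
```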

Visualization

pz.ts is an alias namespace for Penzai’s interactive pretty printer Treescope. Some commonly-used attributes on pz.ts:

pz.ts.register_as_default([...])

Alias of penzai.treescope.treescope_ipython.register_as_default: Registers treescope as the default IPython renderer.

pz.ts.register_autovisualize_magic()

Alias of penzai.treescope.treescope_ipython.register_autovisualize_magic: Registers the %%autovisualize magic.

pz.ts.render_array(array, *[, columns, ...])

Alias of penzai.treescope.arrayviz.arrayviz.render_array: Renders an array (positional or named) to a displayable HTML object.

See the documentation for pz.ts to view all of the methods and classes accessible through this alias namespace.

Penzai also provides a utility for quickly showing a value with Treescope in an IPython notebook, using syntax similar to ordinary print:

pz.show(*args[, wrap, space_separated])

Alias of penzai.treescope.treescope_ipython.show: Shows a list of objects inline, like python print, but with rich display.

Shape-Checking

pz.chk is an alias for penzai.core.shapecheck, which contains utilities for checking the shapes of PyTrees of positional and named arrays. Some commonly-used attributes on pz.chk:

pz.chk.ArraySpec

Alias of penzai.core.shapecheck.ArraySpec: A non-leaf marker for a (named) array structure.

pz.chk.var(name)

Alias of penzai.core.shapecheck.var: Creates a variable for an axis shape.

pz.chk.vars_for_axes(var_name, ...)

Alias of penzai.core.shapecheck.vars_for_axes: Creates variables for a known collection of named axes.

See penzai.core.shapecheck for documentation of all of the methods and classes accessible through the pz.chk alias.

Context Management

pz.disable_interactive_context()

Alias of penzai.core.context.disable_interactive_context: Clears the global interactive context stack and disables interactive mode.

pz.enable_interactive_context()

Alias of penzai.core.context.enable_interactive_context: Enables the global interactive context stack.

pz.ContextualValue

Alias of penzai.core.context.ContextualValue: A global value which can be modified in a scoped context.
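The phrase "a global value which can be modified in a scoped context" describes a common pattern that can be sketched with contextlib. The class ScopedValue and its methods get and set_scoped are hypothetical names chosen for this illustration and should not be read as the documented interface of pz.ContextualValue.

```python
import contextlib


class ScopedValue:
    """Hypothetical holder for a value that can be overridden temporarily."""

    def __init__(self, initial):
        self._value = initial

    def get(self):
        return self._value

    @contextlib.contextmanager
    def set_scoped(self, new_value):
        """Override the value inside a with-block, restoring it on exit."""
        old_value = self._value
        self._value = new_value
        try:
            yield
        finally:
            # Restore even if the body raised an exception.
            self._value = old_value


verbosity = ScopedValue(0)
seen = [verbosity.get()]
with verbosity.set_scoped(2):
    seen.append(verbosity.get())
seen.append(verbosity.get())
```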

Dataclass and Struct Utilities

pz.dataclass_from_attributes(cls, **field_values)

Alias of penzai.core.dataclass_util.dataclass_from_attributes: Directly instantiates a dataclass given all of its fields.

pz.init_takes_fields(cls)

Alias of penzai.core.dataclass_util.init_takes_fields: Returns True if cls.__init__ takes exactly one argument per field.

pz.is_pytree_dataclass_type(cls)

Alias of penzai.core.struct.is_pytree_dataclass_type: Checks if a class was wrapped in the pytree_dataclass decorator.

pz.is_pytree_node_field(field)

Alias of penzai.core.struct.is_pytree_node_field: Returns True if this field is treated as a PyTree child node by Struct.

pz.StructStaticMetadata

Alias of penzai.core.struct.StructStaticMetadata: Container for a struct's static fields.

pz.PyTreeDataclassSafetyError

Alias of penzai.core.struct.PyTreeDataclassSafetyError: Error raised due to pytree dataclass safety checks.