Copyright 2024 The Penzai Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Data Effects - Scoped Side Effects and State in Penzai Models#

Penzai neural networks are conventionally written in terms of simpler combinators like pz.nn.Sequential, pz.nn.Residual, or pz.nn.BranchAndAddTogether, which run their child sublayers in order and route inputs and outputs between them. To make everything compose together, every Penzai layer takes a single input argument, which is usually the output of the previous layer in the model.

On its own, this has a few limitations. There are a number of common patterns in neural networks that cannot be expressed easily as compositions of single-input single-output functions. For instance:

  • Attention masks and positional embeddings need to know the positions of each input token in addition to receiving inputs from their previous layers.

  • Stochastic layers like Dropout need to generate fresh random numbers.

  • Key-value caching in Transformer decoders needs to save keys and values into a stateful cache.

  • Saving intermediate activations requires somehow retrieving those activations from the middle of the model.

  • Since models own their parameters as attributes, two layers that share parameters need to reference the same parameter value, so that gradients are properly shared between them.

There are a few existing solutions to this, used by other JAX neural network frameworks:

  • You could have your layers be represented by mutable Python objects that store references to each other and to mutable variables, and transform these objects into pure functions when called.

    • This approach is taken by Flax and Haiku.

    • A disadvantage of this approach is that it becomes difficult to inspect or manipulate smaller parts of your model, and almost impossible to insert new logic without editing the model code.

  • You could require every module to “thread through” the necessary state and PRNG keys as arguments to each layer, and use custom __call__ logic to match up these arguments and handle parameter sharing.

    • This is the approach taken by Equinox.

    • A disadvantage of this approach is that every submodule has to know about every argument needed by any of its children. So you can’t easily insert e.g. new Dropout layers without changing every containing class to pass around random keys.

Penzai takes a different approach, building on the “structure-encodes-computation” principle of Penzai’s neural networks, and loosely inspired by effect systems in functional programming languages. The key idea is to:

  • Represent requests for state, random numbers, or other “effectful” operations as typed nodes in the model’s PyTree

  • And handle those requests using ordinary PyTree traversals, powered by pz.select.

This system makes effect handling modular, composable, and fully under your control. It’s also fully opt-in. If your model doesn’t use effects, then you don’t have to think about the system, and it can’t affect your model’s behavior at all.

This notebook explains how the system works, and shows how you can use it to flexibly pass data into and out of model layers while still making use of Penzai’s ordinary combinators and utilities. It assumes familiarity with Penzai models and pz.select; if those are new to you, you might want to start with the “How to Think in Penzai” tutorial first.

Setup#

Let’s start by setting up the environment.

Imports#

To run this notebook, you need a Python environment with penzai and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import penzai
except ImportError:
  !pip install penzai[notebook]
from __future__ import annotations

from typing import Any
import traceback
import jax
import jax.numpy as jnp
import penzai
from penzai import pz

Setting up Penzai#

For this tutorial, we’ll enable Treescope (Penzai’s pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment. Treescope also has special-purpose handlers that summarize effects to make it easier to understand how effectful models work.

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()

Why Data Effects?#

Before explaining Penzai’s data effect system in detail, we’ll start with a motivating example problem, and show how the data effects system emerges as a step-by-step solution to this problem. The goal is to explain the core principles behind the system in a form you could have implemented yourself.

If you’d like to dive straight into an explanation of the system as implemented in Penzai proper, feel free to skip to the next section.

Suppose we start with the following MLP:

from penzai.example_models import simple_mlp

mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([64, 128, 128, 128, 64]),
    jax.random.key(1)
)
mlp
%%autovisualize
example_input = pz.nx.wrap(
    jnp.arange(64 * 4, dtype=jnp.float32).reshape(4, 64)
).tag("batch", "features")
mlp(example_input)

Now further suppose that we want to extract the intermediate activations of this network. How could we do that?

Luckily, Penzai is designed to make it easy to insert new logic into a network. We can easily print intermediate values, for example, by inserting new layers into the model:

@pz.pytree_dataclass
class ShowMe(pz.Layer):
  def __call__(self, x):
    pz.show("Intermediate:", x)
    return x
patched_model = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(ShowMe())
)
patched_model
%%autovisualize
patched_model(example_input)

But we want to actually save the values instead of just printing them out. A straightforward but somewhat fragile approach would be to add the values to some global list:

my_unsafe_mutable_accumulator = []

@pz.pytree_dataclass
class AddToMyUnsafeMutableAccumulator(pz.Layer):
  def __call__(self, x):
    my_unsafe_mutable_accumulator.append(x)
    return x
unsafely_patched_model = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(AddToMyUnsafeMutableAccumulator())
)
%%autovisualize
unsafely_patched_model(example_input)
%%autovisualize
my_unsafe_mutable_accumulator

We’ve suggestively used the name “unsafe” here because, while this technically works in this case if you’re careful, it’s usually not safe to do this kind of thing in JAX codebases. The reason is that JAX assumes that any function it is transforming has no side effects. If we try to run the model under JIT compilation, we’ll end up accidentally putting JAX “Tracers” into our mutable list instead of actually extracting the intermediates.
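
To see the problem concretely, here is a minimal sketch of the failure mode (an illustration only, not a pattern to copy): if we trace the patched model under jax.jit, the global list ends up holding abstract tracers instead of the concrete intermediate values we wanted.

my_unsafe_mutable_accumulator.clear()
_ = jax.jit(lambda model, x: model(x))(unsafely_patched_model, example_input)
# The appended values wrap abstract Tracer objects from the trace, not real arrays:
print(jax.tree_util.tree_leaves(my_unsafe_mutable_accumulator))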

Another disadvantage of this design is that AddToMyUnsafeMutableAccumulator references the global mutable variable my_unsafe_mutable_accumulator. So if we want to capture intermediates from more than one run, we’d have to clear this accumulator manually each time we wanted to use it.

We can avoid both of these problems if we replace the global mutable variable with a local one, like this:

def capture_intermediates_after_elementwise(model, example_input):
  my_local_mutable_accumulator = []

  # Locally-defined class!
  @pz.pytree_dataclass
  class AddToMyLocalMutableAccumulator(pz.Layer):
    def __call__(self, x):
      my_local_mutable_accumulator.append(x)
      return x

  locally_patched_model = (
      pz.select(model)
      .at_instances_of(pz.nn.Elementwise)
      .insert_after(AddToMyLocalMutableAccumulator())
  )
  result = locally_patched_model(example_input)
  return result, my_local_mutable_accumulator
%%autovisualize
capture_intermediates_after_elementwise(mlp, example_input)

This function takes a model and an example input, and runs the model while saving the intermediate activations into a local accumulator. It then returns the final output along with the list of all accumulated values.

Since the accumulator is only used inside capture_intermediates_after_elementwise, and the function doesn’t modify any external state, this function itself is pure from JAX’s perspective. This means it’s OK to JIT-compile it:

%%autovisualize
jax.jit(capture_intermediates_after_elementwise)(mlp, example_input)

A disadvantage of this, however, is that it only supports collecting intermediate values in a very specific place: immediately after the Elementwise activation functions.

To get around this, we could allow the user to specify where they want to collect intermediates themselves, without hard-coding it in this way. We could do this by using a special sentinel type that indicates where the append function is needed, and have the helper function “inject” the mutable destination for those values:

@pz.pytree_dataclass
class ReplaceMeWithAnAppendFunction(pz.Struct):
  def treescope_color(self):
    return "yellow"

@pz.pytree_dataclass
class CollectIntermediatesHere(pz.Layer):
  append_fn: ReplaceMeWithAnAppendFunction | None = ReplaceMeWithAnAppendFunction()
  def __call__(self, x):
    self.append_fn(x)
    return x
def capture_intermediates_where_requested(model, example_input):
  my_local_mutable_accumulator = []

  def _append_fn(x):
    my_local_mutable_accumulator.append(x)

  locally_patched_model = (
      pz.select(model)
      .at_instances_of(ReplaceMeWithAnAppendFunction)
      .set(_append_fn)
  )
  result = locally_patched_model(example_input)
  return result, my_local_mutable_accumulator

Now we can collect intermediates at different places:

# Three intermediates, one after each Elementwise block:
mlp_with_requests_1 = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(CollectIntermediatesHere())
)
mlp_with_requests_1
%%autovisualize
capture_intermediates_where_requested(mlp_with_requests_1, example_input)
%%autovisualize
# Four intermediates, one after each Linear block:
mlp_with_requests_2 = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Linear)
    .insert_after(CollectIntermediatesHere())
)
capture_intermediates_where_requested(mlp_with_requests_2, example_input)

Unfortunately, we need to call the model using the special capture_intermediates_where_requested helper in order to get the intermediate values out, because otherwise the placeholder never gets replaced with a real append function:

try:
  mlp_with_requests_2(example_input)
except Exception:
  traceback.print_exc()

We could avoid this by making capture_intermediates_where_requested a type of pz.Layer instead of a standalone function:

@pz.pytree_dataclass
class CaptureIntermediatesWhereRequested(pz.Layer):
  body: pz.LayerLike

  def __call__(self, x):
    my_local_mutable_accumulator = []

    def _append_fn(x):
      my_local_mutable_accumulator.append(x)

    locally_patched_body = (
        pz.select(self.body)
        .at_instances_of(ReplaceMeWithAnAppendFunction)
        .set(_append_fn)
    )
    result = locally_patched_body(x)
    return result, my_local_mutable_accumulator

Then we can build a version of our model that also collects side outputs:

model_with_intermediates = CaptureIntermediatesWhereRequested(
    body=(
        pz.select(mlp)
        .at_instances_of(pz.nn.Elementwise)
        .insert_after(CollectIntermediatesHere())
    )
)
model_with_intermediates

And we can call it directly, just like we called our original model:

model_with_intermediates(example_input)

This is the essence of the data effects system:

  • We can identify where a side effect should happen by inserting some node into the model tree with a special type (ReplaceMeWithAnAppendFunction).

  • We can then wrap the model tree in a wrapper object (CaptureIntermediatesWhereRequested) that handles the effect by

    • creating its own temporary local mutable Python variables,

    • substituting them into the model,

    • running the model,

    • and then putting together a pure result.

This means the wrapper object looks just like an ordinary Penzai model to JAX, since it has a normal PyTree structure, and produces a pure functional output without accessing global state.
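
Because of this, the wrapper can be JIT-compiled just like the earlier helper function; here is a minimal sketch (using a small lambda wrapper, since jax.jit operates on functions rather than layers):

%%autovisualize
jax.jit(lambda model, x: model(x))(model_with_intermediates, example_input)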

Next, we’ll describe the full data effects system, which abstracts this approach into a more general pattern, and imposes a few more rules to make it easier to understand in the presence of multiple effects.

How are Data Effects Defined?#

This section explains the core building blocks of the data effects system, and the different types that you’ll need to think about when using it. In short:

  • Each effect is associated with an effect protocol that determines what functions that effect provides.

  • To request that an effect be performed in a given layer, you can add effect requests to your model tree.

  • To handle those requests, you can wrap your entire model tree with an effect handler, which will replace those effect requests with effect references that are tagged as belonging to this handler.

  • You can then call the effect handler like an ordinary Penzai layer. When it runs, it will inject effect implementations into the model, which are temporary mutable objects that allow the effectful layers to communicate with the handler. You usually don’t have to worry about this unless you are implementing your own handler or working on the Penzai core systems.

The data effect system is defined in penzai.data_effects, and the builtin effect types are aliased to pz.de for easier use.
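
As a quick preview, here is a minimal end-to-end sketch using the built-in side-output effect; each of the pieces it uses (requests, handlers, and the call itself) is explained in the sections below:

# An effect request marks where the side output should be produced:
sketch_model = pz.nn.Sequential([
    pz.nn.Elementwise(jax.nn.relu),
    pz.de.TellIntermediate.from_config(tag="post_relu"),
])
# The handler takes ownership of the request and replaces it with a reference:
sketch_handled = pz.de.CollectingSideOutputs.handling(sketch_model)
# Calling the handler injects a temporary implementation and returns the side outputs:
sketch_handled(pz.nx.ones({"features": 4}))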

Effect Protocols#

Effect protocols define the interface that each effect supports. These are ordinary Python Protocol classes and are primarily used for type annotations and static typechecking.

For instance, the protocol for the randomness effect is

class RandomEffect(Protocol):
  """Protocol for the random number generation effect."""
  def next_key(self) -> jax.Array:
    """Returns a new random key."""

and the protocol for the state effect is

class LocalStateEffect(Protocol[_T]):
  """Protocol for a local state effect."""
  def get(self) -> _T:
    """Gets the current state of the local variable."""

  def set(self, value: _T):
    """Sets the current state of the local variable."""

If you want to use an effect in a layer, you should define one of its attributes to have the appropriate effect protocol as its type annotation. For instance, a dropout layer can be defined in terms of the random effect by storing an attribute of type RandomEffect. Then, in __call__, you can call methods on that attribute according to the protocol.

@pz.pytree_dataclass
class SimpleStochasticDropoutV1(pz.Layer):
  drop_rate: float
  rng: pz.de.RandomEffect

  def __call__(self, x: pz.nx.NamedArray) -> pz.nx.NamedArray:
    # Get a key.
    key = self.rng.next_key()
    # Use it to perform the layer logic.
    mask = pz.nx.nmap(jax.random.bernoulli)(
        pz.nx.random_split(key, x.named_shape),
        p=self.drop_rate
    )
    return pz.nx.nmap(jnp.where)(mask, 0.0, x/(1 - self.drop_rate))

Effect Requests#

To identify a location where you want an effect to occur, you can add an effect request node to your model tree. Effect requests are temporary markers that will eventually be replaced with concrete implementations of the effect.

Every effect request is a subclass of pz.de.EffectRequest, which is used to track whether or not effects have been handled, and also allows you to identify the effect protocol for each request:

pz.de.RandomRequest()
pz.de.RandomRequest().effect_protocol()

You will usually add effect requests when you build the model. For instance, you could build one of our dropout layers using an effect request:

SimpleStochasticDropoutV1(drop_rate=0.1, rng=pz.de.RandomRequest())

Best practice: It is often useful to either configure the effect request as a default value for the attribute, or provide a class method that configures the effect request. For instance, you can do something like this:

@pz.pytree_dataclass
class SimpleStochasticDropoutV2(pz.Layer):
  drop_rate: float
  rng: pz.de.RandomEffect

  def __call__(self, x: pz.nx.NamedArray) -> pz.nx.NamedArray:
    key = self.rng.next_key()
    mask = pz.nx.nmap(jax.random.bernoulli)(
        pz.nx.random_split(key, x.named_shape),
        p=self.drop_rate
    )
    return pz.nx.nmap(jnp.where)(mask, 0.0, x/(1 - self.drop_rate))

  @classmethod
  def from_config(cls, drop_rate: float):
    return cls(drop_rate=drop_rate, rng=pz.de.RandomRequest())

This makes it easy to construct instances of your effectful layer while building a larger model:

layer = SimpleStochasticDropoutV2.from_config(drop_rate=0.1)
layer

Since a random request doesn’t actually include an implementation of the effect, you can’t call the layer (or any model containing it) while it has unhandled effects:

try:
  layer(pz.nx.ones({"features": 8}))
except Exception:
  traceback.print_exc()

Before you can actually run the effect, you need to handle these requests using a handler.

Effect Handlers and Effect References#

Effect handlers are wrapper layers that take ownership of the effect requests in your model, and are responsible for providing concrete implementations of those effects.

Each effect handler is a subclass of pz.de.EffectHandler, and must define two attributes: a handler_id which uniquely identifies the handler, and a body which contains the rest of your model.

You usually won’t need to provide a handler ID yourself, since it is inferred for you based on the structure of your model when you build the handler. Most handlers provide a builder classmethod for this purpose:

effectful_model = pz.nn.initialize_parameters(
    pz.nn.Sequential([
        pz.nn.add_parameter_prefix(
            "Linear_0",
            pz.nn.Linear.from_config(
                input_axes={"features": 8}, output_axes={"features": 8}
            ),
        ),
        pz.nn.Elementwise(jax.nn.relu),
        SimpleStochasticDropoutV2.from_config(drop_rate=0.1),
        pz.nn.add_parameter_prefix(
            "Linear_1",
            pz.nn.Linear.from_config(
                input_axes={"features": 8}, output_axes={"features": 8}
            ),
        ),
        pz.nn.Elementwise(jax.nn.relu),
        SimpleStochasticDropoutV2.from_config(drop_rate=0.1),
        pz.nn.add_parameter_prefix(
            "Linear_2",
            pz.nn.Linear.from_config(
                input_axes={"features": 8}, output_axes={"features": 8}
            ),
        ),
    ]),
    jax.random.key(42),
)

handled_model = pz.de.WithRandomKeyFromArg.handling(effectful_model)

When a handler is built, it finds all of the requests it can handle and swaps them out for “effect references”. These are like effect requests, but they identify the handler that is responsible for handling them. You can see them in the model tree (in this case as HandledRandomRef nodes), and Treescope links them back to the handler with the same ID:

handled_model

Effect references are subclasses of pz.de.HandledEffectRef. Each effect reference knows its own handler ID, and also defines an effect_protocol method to identify what effect it is supposed to provide.
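
For instance, here is a minimal sketch that collects all of the references owned by the handler using an ordinary pz.select traversal:

pz.select(handled_model).at_instances_of(pz.de.HandledEffectRef).get_sequence()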

Effect Implementations#

Effect handlers can be called like ordinary layers, and behave like pure functions without any external side effects. Depending on the handler, the expected structure of the input or output may differ from that of the wrapped model. In this case, the WithRandomKeyFromArg handler expects to be called with a tuple of two values, the first being the input to the model, and the second being a random key.

handled_model((pz.nx.ones({"features": 8}), jax.random.key(1)))

Handlers still expect to be called with a single input argument, rather than multiple arguments, so that they compose with each other and with other Penzai wrappers. For instance, you can easily jit-compile the model:

from penzai.toolshed import jit_wrapper

jitted_handled_model = jit_wrapper.Jitted(handled_model)
jitted_handled_model
jitted_handled_model((pz.nx.ones({"features": 8}), jax.random.key(1)))

Internally, when called, effect handlers are responsible for substituting effect implementations for all of the effect references they own. Effect implementations are always subclasses of pz.de.EffectRuntimeImpl, and they are NOT usually JAX PyTree nodes. This is because they usually either have mutable attributes or include references to some external state that isn’t safe to manipulate across JAX transformation boundaries.

You shouldn’t need to think about effect implementations unless you are implementing an effect or higher-order model wrapper. But if you do run across them, they can usually be inspected and manipulated using ordinary Penzai tooling.

Here’s a contrived layer wrapper that lets you see this process in action:

@pz.pytree_dataclass
class DebugShowModelStructure(pz.Layer):
  body: pz.LayerLike
  def __call__(self, x):
    pz.show("Model structure when called:", self.body)
    return self.body(x)
debug_jitted_handled_model = (
    pz.select(jitted_handled_model)
    .at(lambda root: root.body.body)
    .apply(lambda body: DebugShowModelStructure(body))
)
debug_jitted_handled_model
debug_jitted_handled_model((pz.nx.ones({"features": 8}), jax.random.key(1)))

If you look at the SimpleStochasticDropoutV2 blocks above, you’ll see that they contain temporary RandomEffectImpl objects in place of the HandledRandomRef references. These implementations hold onto a “RandomStream” object, which is a mutable helper class that generates random numbers one at a time.

Built-in Effects#

Penzai includes four basic effects: side inputs, side outputs, random streams, and local state. This section gives a brief overview of each of these effects.

Side Inputs#

Side inputs allow you to pass inputs into layers that need them without disrupting the ordinary data flow. This is useful for providing information like attention masks or token positions, which are only required by specific types of layer.

The side input protocol defines a single method ask:

class SideInputEffect(Protocol[_T]):
  """Protocol for a side input effect."""

  def ask(self) -> _T:
    """Retrieves the value for the side input."""

Side input requests are associated with a “tag” that identifies what value should be provided:

pz.nn.ApplyAttentionMask.from_config(mask_tag="attn_mask")

You can handle SideInputEffect using WithSideInputsFromInputTuple, which routes extra entries of the model’s input tuple to the corresponding side inputs:

pz.de.WithSideInputsFromInputTuple.handling(pz.nn.Sequential([
    pz.nn.ApplyAttentionMask.from_config(mask_tag="attn_mask"),
    # in a real network you'd have more logic here
    pz.nn.ApplyAttentionMask.from_config(mask_tag="attn_mask"),
]), tags=["attn_mask"])

This can also be used to provide multiple side inputs at once:

@pz.pytree_dataclass
class MyLayerWithSideInputs(pz.Layer):
  side_arg: pz.de.SideInputEffect[Any]
  def __call__(self, x):
    print("Got side input:", repr(self.side_arg.ask()))
    return (x, self.side_arg.ask())
unhandled_example = pz.nn.Sequential([
    MyLayerWithSideInputs(pz.de.SideInputRequest("foo")),
    MyLayerWithSideInputs(pz.de.SideInputRequest("bar")),
    MyLayerWithSideInputs(pz.de.SideInputRequest("foo")),
])
unhandled_example
handled_example = pz.de.WithSideInputsFromInputTuple.handling(
    unhandled_example, tags=["foo", "bar"]
)
handled_example
handled_example(("main input", "value for foo", "value for bar"))

You can also provide a constant value for side inputs:

handled_example_2 = pz.de.WithConstantSideInputs.handling(
    unhandled_example, {"foo": "value for foo", "bar": "value for bar"}
)
handled_example_2
handled_example_2("main input")

Side Outputs#

Side outputs allow you to produce outputs while your model runs, without threading them through the rest of the layers. This can be useful for collecting intermediate activations or auxiliary losses.

The side output protocol defines a method tell:

class SideOutputEffect(Protocol[_T]):
  """Protocol for a side output effect."""

  def tell(self, value: _T, /):
    """Writes a value to the side output."""

Side outputs are associated with a tag that identifies what type of side output they are.

mlp_with_side_outputs = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(pz.de.TellIntermediate.from_config(tag="intermediate"))
)
mlp_with_side_outputs

For convenience, SideOutputRequest implements the tell method as a no-op, so if you don’t care about the side outputs, you can still call your model.

mlp_with_side_outputs(pz.nx.ones({"features": 64}))

To handle side outputs, you wrap your model with a CollectingSideOutputs handler:

mlp_with_side_outputs_handled = pz.de.CollectingSideOutputs.handling(
    mlp_with_side_outputs,
    tag="intermediate",  # <- Optional; if omitted, collects outputs for all tags
)
mlp_with_side_outputs_handled

Calling it produces a list of side outputs along with the ordinary outputs, and those side outputs also remember their original tag and location within the PyTree. This can be used to match up the side outputs with the part of the model that produced them.

mlp_with_side_outputs_handled(pz.nx.ones({"features": 64}))

Randomness#

Layers that need random numbers can request them by inserting a random effect, defined by the protocol

class RandomEffect(Protocol):
  """Protocol for the random number generation effect."""

  def next_key(self) -> jax.Array:
    """Returns a new random key."""

We’ve seen an example of this in the previous section:

effectful_model

Random number effects are usually handled using WithRandomKeyFromArg:

handled_model = pz.de.WithRandomKeyFromArg.handling(effectful_model)
handled_model
handled_model((pz.nx.ones({"features": 8}), jax.random.key(1)))

Alternatively, you can freeze the random state to a specific value to get a deterministic model:

handled_model_frozen = pz.de.WithFrozenRandomState.handling(effectful_model, jax.random.key(1))
handled_model_frozen
handled_model_frozen(pz.nx.ones({"features": 8}))

Local State#

Finally, Penzai includes a local state handler that allows your model to hold onto and update “state variables” in a functional way. The local state effect is defined as

class LocalStateEffect(Protocol[_T]):
  """Protocol for a local state effect."""

  def get(self) -> _T:
    """Gets the current state of the local variable."""

  def set(self, value: _T):
    """Sets the current state of the local variable."""

There are three different request types for a local state effect. Usually, you will use InitialLocalStateRequest when building a stateful model, which requires you to specify a state initializer function:

pz.de.InitialLocalStateRequest(
    state_initializer=lambda: pz.nx.zeros({"foo": 10, "bar": 10}),
    category="example_state"
)

For instance, you might configure a stateful layer like this:

@pz.pytree_dataclass
class ExampleLayerWithAState(pz.Layer):
  accumulator: pz.de.LocalStateEffect

  def __call__(self, x):
    self.accumulator.set(self.accumulator.get() + x)
    return x

  @classmethod
  def from_config(cls, category="example_state"):
    return cls(accumulator=pz.de.InitialLocalStateRequest(
        state_initializer=lambda: 0.0,
        category=category,
    ))
ExampleLayerWithAState.from_config("example_state")

You can also use FrozenLocalStateRequest, which requires an actual value for the state. This can be built directly, but it’s also produced by helper functions that re-insert the states into your model (discussed later).

pz.de.FrozenLocalStateRequest(
    state=pz.nx.zeros({"foo": 10, "bar": 10}),
    category="example_state"
)

Finally, if you want multiple state variables to share the same value, you can use SharedLocalStateRequest. This requires the shared request to use the same name as the initial request it shares with. (Names are otherwise optional.)

my_shared_state_model = pz.nn.Sequential([
    ExampleLayerWithAState(accumulator=pz.de.InitialLocalStateRequest(
        state_initializer=lambda: 0.0,
        category="example_state",
        name="shared",
    )),
    ExampleLayerWithAState(accumulator=pz.de.InitialLocalStateRequest(
        state_initializer=lambda: 0.0,
        category="example_state",
    )),
    # Shared with the *first* state variable above:
    ExampleLayerWithAState(accumulator=pz.de.SharedLocalStateRequest(
        name="shared",
        category="example_state",
    )),
])
my_shared_state_model

To handle the state effect, you can use the function pz.de.handle_local_states. This unzips your model’s state requests and returns two things: a handled version of your model (wrapped with a state handler), and an initial state dict:

my_test_model = pz.nn.Sequential([
    ExampleLayerWithAState.from_config("example_state"),
    pz.nn.ConstantRescale(by=2.0),
    ExampleLayerWithAState.from_config("example_state"),
])
handled_model, initial_state_dict = pz.de.handle_local_states(my_test_model, category="example_state")
handled_model
initial_state_dict

You can then call the handled model with its input and state dict to get outputs and an updated state dict:

output, new_state_dict = handled_model((10.0, initial_state_dict))
output, new_state_dict

If you have shared state variables, you need to opt in to state sharing. This is to prevent sharing state variables by accident.

shared_handled, shared_initial_state = pz.de.handle_local_states(
    my_shared_state_model, category="example_state", state_sharing="allowed"
)
shared_handled, shared_initial_state
shared_handled((10.0, shared_initial_state))

Given a state handler and a state dict, you can use freeze_local_states to put those state variables back into the model pytree (as FrozenLocalStateRequest instances). This can be useful if you want to extract parts of a stateful model or make more complex transformations without manually manipulating the state dict.

pz.de.freeze_local_states(handled_model, new_state_dict)

Parameter Sharing as an Effect in Penzai Models#

Penzai uses the data effect system to implement parameter sharing in a flexible way. This section describes how this works and shows you how to build models that have shared parameters.

The challenge of parameter sharing in Penzai is:

  • Since layers own their own parameters as attributes, if multiple layers need to use the same parameter, it seems like they would each need a copy of the parameter.

  • But since Penzai models are just pytrees, and shared Python object identity is ignored by JAX, we only want to include the value for the parameter once.

Penzai resolves this by using the SideInputEffect to implement parameter sharing, in combination with a helper type SharedParameterLookup. A model with shared parameters will look something like this:

model_with_shared_params = pz.de.WithConstantSideInputs.handling(
    body=pz.nn.Sequential([
        # Contrived example: Repeat the same bias twice
        pz.nn.AddBias(
            bias=pz.nn.SharedParameterLookup(
                pz.de.SideInputRequest("shared_param"),
                value_structure=pz.chk.ArraySpec(named_shape={"features": 10}),
            ),
            new_axis_names=(),
        ),
        pz.nn.AddBias(
            bias=pz.nn.SharedParameterLookup(
                pz.de.SideInputRequest("shared_param"),
                value_structure=pz.chk.ArraySpec(named_shape={"features": 10}),
            ),
            new_axis_names=(),
        ),
    ]),
    side_inputs={
        "shared_param": pz.nn.Parameter(pz.nx.ones({"features": 10}), name="AddBias-shared")
    }
)
model_with_shared_params

SharedParameterLookup acts like a Parameter, but accessing its value attribute reads the value from the side input. This means that, when the model runs, both copies of the AddBias layer use the same parameter:

model_with_shared_params(pz.nx.zeros({"features": 10}))

But there’s only one actual instance of Parameter in the tree:

pz.select(model_with_shared_params).at_instances_of(pz.nn.Parameter).get_sequence()

This is how models with shared parameters are represented when they are built. However, Penzai also includes some helpers to make it easier to set up this parameter sharing: mark_shareable and attach_shared_parameters. These just identify which parameters need to be shared and set up the correct side input handler for you:

shareable_bias = pz.nn.mark_shareable(
    pz.nn.AddBias.from_config(biased_axes={"features": 10})
)
shareable_bias
model_def = pz.nn.attach_shared_parameters(pz.nn.Sequential([
    shareable_bias,
    shareable_bias,
]))
model_def
pz.nn.initialize_parameters(model_def, jax.random.key(123))

Composing and Patching Effects#

Since each handler has its own ID, and each effect has its own request and reference types, it’s straightforward to combine effects with each other. And since the data effects system is entirely encoded inside your model’s PyTree structure, it is easy to patch models that use effects, and sometimes even to insert new effects!

Combining Multiple Effects#

Since each effect is independent, and each handler is an ordinary single-input single-output layer, you are free to combine multiple effects in the same model:

multi_effect_model = pz.nn.Sequential([
    ExampleLayerWithAState.from_config(category="example_state"),
    pz.de.TellIntermediate.from_config(tag="intermediate"),
    MyLayerWithSideInputs(pz.de.SideInputRequest("foo")),
])
multi_effect_model
handled_multi_effect_model, initial_state = pz.de.handle_local_states(
    pz.de.WithSideInputsFromInputTuple.handling(
        pz.de.CollectingSideOutputs.handling(multi_effect_model),
        tags=["foo"],
    ),
    category="example_state",
)
handled_multi_effect_model
((result, side_outputs), new_state) = handled_multi_effect_model(((100.0, "value for foo"), initial_state))
print()
pz.show("result:", result)
pz.show("side_outputs:", side_outputs)
pz.show("new_state:", new_state)

You can also have multiple copies of the same effect with different handlers:

multi_side_input_model = pz.nn.Sequential([
    MyLayerWithSideInputs(pz.de.SideInputRequest("foo")),
    MyLayerWithSideInputs(pz.de.SideInputRequest("bar")),
])
handled_multi_side_input_model = pz.de.WithSideInputsFromInputTuple.handling(
    pz.de.WithConstantSideInputs.handling(
        multi_side_input_model,
        side_inputs={"foo": "value_for_foo"}
    ),
    tags=["bar"],
)
handled_multi_side_input_model(("input", "value_for_bar"))

Note that if you try to handle the same effect request twice, the innermost handler will replace it with a reference, so the outer handler won’t handle it. You can always see which handler is going to handle an effect by printing it out in Treescope:

some_model = pz.nn.Sequential([
    MyLayerWithSideInputs(pz.de.SideInputRequest("foo")),
])
rehandled_model = pz.de.WithConstantSideInputs.handling(
    pz.de.WithConstantSideInputs.handling(
        some_model,
        side_inputs={"foo": "from inner handler"},
    ),
    side_inputs={"foo": "from outer handler"},
)
rehandled_model
rehandled_model(100)

In some cases, it can be useful to convert one effect into another. For instance, there’s a handler for the random effect that updates its random state using the local state effect:

random_and_state_model = pz.nn.Sequential([
    ExampleLayerWithAState.from_config(category="example_state"),
    SimpleStochasticDropoutV2.from_config(drop_rate=0.1)
])
random_and_state_model
stateful_random_model = pz.de.WithStatefulRandomKey.handling(
    random_and_state_model,
    initial_key=jax.random.key(123),
)
stateful_random_model
pure_random_model, initial_state = pz.de.handle_local_states(
    stateful_random_model,
    category_predicate=lambda _: True,
)
pure_random_model
initial_state
pure_random_model((pz.nx.zeros({"foo": 10}), initial_state))

Patching Models With Effects#

Because of the modularity of the system, you are usually free to insert new logic into models that already have effects, and that new logic can even include new effects. For instance, we can take a model that has random effects:

dropout_mlp = pz.nn.initialize_parameters(
    simple_mlp.DropoutMLP.from_config([8, 32, 32, 8], drop_rate=0.1),
    jax.random.key(123),
)
dropout_mlp

Handle the random effects:

%%autovisualize
dropout_mlp_handled = pz.de.WithRandomKeyFromArg.handling(dropout_mlp)
dropout_mlp_handled((pz.nx.ones({"features": 8}), jax.random.key(0)))

Then inject new logic that requires a new side-output effect to capture intermediates:

patched_dropout_mlp = (
    pz.select(dropout_mlp_handled)
    .at_instances_of(pz.nn.StochasticDropout)
    .insert_after(pz.de.TellIntermediate.from_config(tag="intermediate"))
)
patched_dropout_mlp

And finally handle that new effect:

%%autovisualize
handled_patched_dropout_mlp = pz.de.CollectingSideOutputs.handling(
    patched_dropout_mlp
)
handled_patched_dropout_mlp((pz.nx.ones({"features": 8}), jax.random.key(0)))

If you’re ever unsure about which effects are handled by each handler, you can always figure it out by just printing your model with Treescope. One of the key design goals of Penzai’s data effects system is that you should always be able to figure out what even an effectful model is doing just by looking at it.

handled_patched_dropout_mlp

Sharp Edges of the Effect System#

The effect system has a few sharp edges that you should be aware of.

Broken effect references#

Once you’ve installed a handler for a given effect, the effect requests in your model are replaced with effect references that are specific to the particular handler’s handler ID. This means that if you remove those effectful parts from the model and try to use them on their own, it probably won’t work:

handled_patched_dropout_mlp.body.body.sublayers[1]
try:
  handled_patched_dropout_mlp.body.body.sublayers[1](pz.nx.ones({"features": 8}))
except Exception:
  traceback.print_exc()

For this reason, if you want to pull out individual components from a larger model, it’s usually a good idea to do so before you wrap it with handlers, so that you can handle the effect requests separately after removing the submodel.

If you really need to, however, you are free to manually replace the broken effect references with new effect requests, or even manually re-write the handler IDs. All of the handlers just use string identifiers to determine which refs they should handle, so as long as you set things up consistently, it should work.
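
For instance, here is a minimal sketch (assuming the extracted layer only uses the random effect): pull out one of the StochasticDropout layers, swap its stale HandledRandomRef back to a fresh RandomRequest, and install a new handler for it.

# Extract a layer whose reference points at a handler that is no longer present:
broken_dropout = (
    pz.select(handled_patched_dropout_mlp)
    .at_instances_of(pz.nn.StochasticDropout)
    .get_sequence()[0]
)
# Replace the stale reference with a fresh request, then handle it again:
repaired_dropout = pz.de.WithRandomKeyFromArg.handling(
    pz.select(broken_dropout)
    .at_instances_of(pz.de.HandledRandomRef)
    .set(pz.de.RandomRequest())
)
repaired_dropout((pz.nx.ones({"features": 8}), jax.random.key(0)))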

Penzai also includes a utility penzai.toolshed.isolate_submodel that can pull out part of a model while also rewriting the built-in effects:

from penzai.toolshed import isolate_submodel
isolate_submodel.call_and_extract_submodel(
    pz.select(handled_patched_dropout_mlp)
      .at(lambda root: root.body.body.sublayers[1]),
    (pz.nx.ones({"features": 8}), jax.random.key(0))
)

JAX transformations inside your model#

Effect implementations are not usually safe to cross JAX transformation boundaries, because they often contain mutable state or external references. This is usually fine, because the only place that effect implementations appear is inside a model object that is actively being called. You can still use Penzai models inside JAX transformations, because the handler always wraps the effects into a pure functional interface.

The one exception is when you want to apply JAX transformations to a small part of your model, but handle the effect outside this transformation. This will usually result in an error. For instance, you can’t do this:

bad = pz.de.WithRandomKeyFromArg.handling(jit_wrapper.Jitted(dropout_mlp))

try:
  bad((pz.nx.ones({"features": 8}), jax.random.key(0)))
except Exception:
  traceback.print_exc()

Instead, you should do this:

dropout_mlp_jitted = jit_wrapper.Jitted(
    pz.de.WithRandomKeyFromArg.handling(dropout_mlp)
)
dropout_mlp_jitted((pz.nx.ones({"features": 8}), jax.random.key(0)))

Effects inside higher-order transformations in Penzai are not yet supported, and the details of this may change in future releases.