Utility to insert intermediate values into a model’s PyTree.

This module provides a utility that runs a model, collects all intermediate values, and then inserts those intermediate values into the model PyTree itself for ease of visualization and analysis.

The inserted intermediates are represented as identity layers that additionally hold the captured intermediate array. These identity layers return their input unchanged, so the resulting model has the same behavior as the original model.
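To illustrate why inserting these layers leaves behavior unchanged, here is a minimal sketch in plain Python. It assumes a model is simply a list of callables applied in sequence, and it uses lists in place of arrays. The class `SavedActivationIdentity` and the helpers `run_sequential`, `double`, and `add_one` are hypothetical names chosen for illustration; they are not part of this module's API.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class SavedActivationIdentity:
    """Hypothetical no-op layer that carries a saved intermediate value."""

    saved_activation: Any  # the extra value captured during the original run

    def __call__(self, x):
        # Return the input unchanged; the saved value is only kept for inspection.
        return x


def double(x):
    return [2 * v for v in x]


def add_one(x):
    return [v + 1 for v in x]


def run_sequential(layers, x):
    # Apply each layer in order, feeding each output into the next layer.
    for layer in layers:
        x = layer(x)
    return x


original = [double, add_one]
# Same model with an identity layer holding the output of `double` for input [1, 2, 3].
annotated = [double, SavedActivationIdentity([2, 4, 6]), add_one]

print(run_sequential(original, [1, 2, 3]))   # [3, 5, 7]
print(run_sequential(annotated, [1, 2, 3]))  # [3, 5, 7]
```

Because the identity layer ignores its stored value when called, the annotated list computes exactly the same function as the original one.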

Note that storing all of the intermediate activations for a large model may use a large amount of memory. This utility is intended primarily for debugging and analyzing small models and small parts of larger models.
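As a rough sense of scale, the memory needed to keep intermediates can be estimated as the number of elements times the bytes per element, summed over every saved array. The sketch below uses hypothetical sizes (24 layers, one float32 activation of shape (8, 1024, 4096) per layer) that are not taken from this module; real models usually produce several intermediates per layer, so this counts only a lower bound.

```python
import math


def activation_bytes(shape, bytes_per_element=4):
    """Bytes needed to store one activation array of the given shape (float32 by default)."""
    return math.prod(shape) * bytes_per_element


# Hypothetical model: 24 layers, each saving one float32 activation of
# shape (batch=8, seq_len=1024, d_model=4096).
per_layer = activation_bytes((8, 1024, 4096))
total = 24 * per_layer

print(per_layer / 2**20, "MiB per layer")  # 128.0 MiB per layer
print(total / 2**30, "GiB in total")       # 3.0 GiB in total
```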



No-op annotation that holds onto intermediate activations.


run_and_interleave_intermediates(root, argument)

Interleaves intermediate values into a model.
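The following is a simplified sketch of the interleaving idea, assuming the model is a flat list of single-argument callables, with `layers` playing the role of `root` and `argument` being the model input (both interpretations are assumptions). The function `run_and_interleave_sequential` and the class `SavedActivationIdentity` are hypothetical; this is not the module's implementation, which works on the model's PyTree rather than a flat list. The sketch runs the model once and inserts, after each layer, an identity layer holding that layer's output.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class SavedActivationIdentity:
    """Hypothetical no-op layer holding the value produced just before it."""

    saved_activation: Any

    def __call__(self, x):
        return x  # identity: the stored value does not affect the output


def run_and_interleave_sequential(layers: Sequence[Callable], argument):
    """Runs `layers` on `argument` and returns (output, interleaved_layers).

    After each original layer, a SavedActivationIdentity holding that layer's
    output is appended, so the new list computes the same function.
    """
    interleaved = []
    x = argument
    for layer in layers:
        x = layer(x)
        interleaved.append(layer)
        interleaved.append(SavedActivationIdentity(x))
    return x, interleaved


def square(x):
    return x * x


def negate(x):
    return -x


output, model = run_and_interleave_sequential([square, negate], 3)
print(output)  # -9

# The saved intermediates can be read back directly from the new model.
saved = [l.saved_activation for l in model if isinstance(l, SavedActivationIdentity)]
print(saved)  # [9, -9]

# Re-running the interleaved model gives the same result as the original.
y = 3
for layer in model:
    y = layer(y)
print(y)  # -9
```

Storing each intermediate next to the layer that produced it means the values can be found by walking the model itself, rather than matching them up against a separate log of activations.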