annotate_shapes

Utility to add shape annotations throughout a model.

This module provides a utility that walks through a model and adds shape annotations at every point in its evaluation. This is useful for debugging and understanding how values flow through the model.

The resulting model is specialized to the shape of the dummy input used to annotate it. It behaves identically to the original model as long as it is called with inputs of that same shape.
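The idea in the two paragraphs above can be illustrated with a small, self-contained sketch. It is not this module's implementation. It assumes a model that is just a flat list of callables acting on nested Python lists, and the names CheckShape, annotate_sequential and run are hypothetical. Evaluating the model once on a dummy input records the shape at every point, and a pass-through shape check is inserted at each of those points:

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple

Shape = Tuple[int, ...]


def shape_of(value: Any) -> Shape:
    """Shape of a rectangular nested list; non-list values are scalars with shape ()."""
    if isinstance(value, list):
        return (len(value),) + (shape_of(value[0]) if value else ())
    return ()


@dataclass(frozen=True)
class CheckShape:
    """Hypothetical annotation layer: checks the shape, then passes the value through."""
    expected: Shape

    def __call__(self, value: Any) -> Any:
        actual = shape_of(value)
        if actual != self.expected:
            raise ValueError(f"expected shape {self.expected}, got {actual}")
        return value


def run(layers: List[Callable[[Any], Any]], value: Any) -> Any:
    """Applies a flat list of layers in order."""
    for layer in layers:
        value = layer(value)
    return value


def annotate_sequential(layers, dummy_input):
    """Runs the layers once on dummy_input, inserting a CheckShape before the first
    layer and after every layer, each holding the shape observed at that point."""
    value = dummy_input
    annotated = [CheckShape(shape_of(value))]
    for layer in layers:
        value = layer(value)
        annotated += [layer, CheckShape(shape_of(value))]
    return annotated


# Example layers acting on nested lists.
def transpose(m):
    return [list(row) for row in zip(*m)]


def flatten(m):
    return [x for row in m for x in row]


def double(v):
    return [2 * x for x in v]


model = [transpose, flatten, double]
annotated = annotate_sequential(model, [[0, 0, 0], [0, 0, 0]])
print([a.expected for a in annotated if isinstance(a, CheckShape)])
# [(2, 3), (3, 2), (6,), (6,)]
```

Calling run(annotated, x) on any 2x3 input gives the same result as run(model, x), because every inserted check returns its input unchanged. A 2x2 input, by contrast, is rejected by the very first check, which is the sense in which the annotated model is specialized to the dummy input's shape.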

Classes

CalledWithManyStructures

No-op annotation that indicates values of many different shapes.
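One plausible situation for such a marker (an assumption; this page does not say how the annotation is chosen): a single shared layer is invoked several times on differently shaped values, so no single fixed-shape annotation fits. The sketch below uses hypothetical names (Recorder, annotation_for, and ManyShapes as a stand-in for CalledWithManyStructures) to show that decision in plain Python, with a real check when exactly one shape was seen and a pure no-op marker otherwise:

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple

Shape = Tuple[int, ...]


def shape_of(value: Any) -> Shape:
    """Shape of a rectangular nested list; non-list values have shape ()."""
    if isinstance(value, list):
        return (len(value),) + (shape_of(value[0]) if value else ())
    return ()


@dataclass(frozen=True)
class CheckShape:
    """Hypothetical annotation: checks one fixed shape, passes the value through."""
    expected: Shape

    def __call__(self, value: Any) -> Any:
        if shape_of(value) != self.expected:
            raise ValueError(f"expected {self.expected}, got {shape_of(value)}")
        return value


@dataclass(frozen=True)
class ManyShapes:
    """Hypothetical stand-in for CalledWithManyStructures: returns its input as-is."""

    def __call__(self, value: Any) -> Any:
        return value


class Recorder:
    """Wraps a (possibly shared) layer and records the input shape of every call."""

    def __init__(self, fn: Callable[[Any], Any]):
        self.fn = fn
        self.seen: List[Shape] = []

    def __call__(self, value: Any) -> Any:
        self.seen.append(shape_of(value))
        return self.fn(value)


def annotation_for(seen: List[Shape]):
    """Exactly one distinct shape gives a real check; anything else gets the no-op marker."""
    if len(set(seen)) == 1:
        return CheckShape(seen[0])
    return ManyShapes()


shared = Recorder(lambda v: [x + 1 for x in v])
shared([0, 0, 0])  # called once with shape (3,)
shared([0, 0])     # and again with shape (2,)
print(annotation_for(shared.seen))  # ManyShapes()
```

Because the marker never inspects or alters its input, inserting it keeps the model's behavior unchanged while still recording that this point saw more than one shape.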

Static

Wraps a value so that it is treated as an empty PyTree.
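In JAX, a PyTree is a nested container whose leaves are the values that tree utilities and transformations operate on. A value treated as an empty PyTree contributes no leaves, so those utilities pass over it. The toy flattener below mimics that effect in plain Python. It is not JAX's mechanism (JAX lets custom types declare their children through jax.tree_util.register_pytree_node), and the Static class here is a hypothetical stand-in for the one documented on this page:

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass(frozen=True)
class Static:
    """Hypothetical wrapper: contributes no leaves when a tree is flattened."""
    value: Any


def leaves(tree: Any) -> List[Any]:
    """Toy flattener: lists, tuples and dicts are containers, Static is an empty
    subtree, and everything else is a single leaf."""
    if isinstance(tree, Static):
        return []  # the wrapped value is carried along but yields no leaves
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in leaves(child)]
    if isinstance(tree, dict):
        # Visit keys in sorted order so the leaf order is deterministic.
        return [leaf for key in sorted(tree) for leaf in leaves(tree[key])]
    return [tree]


print(leaves({"w": [1.0, 2.0], "meta": Static("debug info")}))  # [1.0, 2.0]
```

The string wrapped in Static is still present in the structure, but a flattening pass sees only the two numeric leaves.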

Functions

annotate_shapes(root, dummy_input)

Annotates shapes in a model or layer, for inputs with the same structure as dummy_input.
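One way to read "the same structure as dummy_input" is the nesting of containers in the input together with the shape of each array leaf. A minimal sketch of that comparison follows, assuming hypothetical FakeArray leaves that carry only a shape and ignoring dtypes. The helper names structure_of and same_structure are illustrative, not part of this module:

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical array stand-in that carries only a shape."""
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class ShapeLeaf:
    """Marks an array position, so a () shape never compares equal to an empty tuple."""
    shape: Tuple[int, ...]


def structure_of(tree: Any) -> Any:
    """Same container nesting as tree, with every array replaced by its shape."""
    if isinstance(tree, dict):
        return {key: structure_of(val) for key, val in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(structure_of(val) for val in tree)
    return ShapeLeaf(tuple(tree.shape))


def same_structure(a: Any, b: Any) -> bool:
    """True when both inputs nest the same way and every array has the same shape."""
    return structure_of(a) == structure_of(b)


dummy = {"tokens": FakeArray((8, 16)), "mask": [FakeArray((8,))]}
print(same_structure(dummy, {"tokens": FakeArray((8, 16)), "mask": [FakeArray((8,))]}))  # True
print(same_structure(dummy, {"tokens": FakeArray((4, 16)), "mask": [FakeArray((8,))]}))  # False
print(same_structure(dummy, {"tokens": FakeArray((8, 16))}))  # False
```

Under this reading, a changed leaf shape and a missing key both count as a different input structure, so an annotated model would only be guaranteed to behave like the original on inputs for which same_structure returns True.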