annotate_shapes

penzai.toolshed.annotate_shapes.annotate_shapes(root: pz.Layer, dummy_input: Any) → pz.Layer

Annotates shapes in a model or layer for inputs with the same structure as dummy_input.

Parameters:
  • root – The layer or model to annotate.

  • dummy_input – A dummy input with the same structure as the actual input. Its leaves can be anything with shape and dtype attributes, e.g. jax.ShapeDtypeStruct.
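A dummy input is usually derived from a real one by replacing each array with a placeholder that keeps only its shape and dtype. With JAX, for a pytree of arrays, this can be written as jax.tree_util.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), real_input). The sketch below shows the same idea using only the standard library, so ShapeDtype, FakeArray, and make_dummy are hypothetical stand-ins (for jax.ShapeDtypeStruct, a real array, and tree_map respectively), and it assumes the input is built from nested dicts, lists, and tuples.

```python
from collections import namedtuple

# Hypothetical stand-in for jax.ShapeDtypeStruct: records only shape and dtype.
ShapeDtype = namedtuple("ShapeDtype", ["shape", "dtype"])


class FakeArray:
    """Hypothetical stand-in for a real array that also carries data."""

    def __init__(self, shape, dtype, data=None):
        self.shape = tuple(shape)
        self.dtype = dtype
        self.data = data


def make_dummy(tree):
    """Replace each leaf that has shape/dtype attributes with a ShapeDtype,
    keeping the surrounding dict/list/tuple structure intact."""
    if isinstance(tree, dict):
        return {key: make_dummy(value) for key, value in tree.items()}
    # Check for array-like leaves before tuples, since ShapeDtype is a tuple.
    if hasattr(tree, "shape") and hasattr(tree, "dtype"):
        return ShapeDtype(tuple(tree.shape), tree.dtype)
    if type(tree) in (list, tuple):
        return type(tree)(make_dummy(item) for item in tree)
    return tree  # Other leaves (e.g. Python scalars) pass through unchanged.


real_input = {
    "tokens": FakeArray((2, 16), "int32"),
    "mask": [FakeArray((2, 16), "bool")],
}
dummy_input = make_dummy(real_input)
print(dummy_input["tokens"])  # ShapeDtype(shape=(2, 16), dtype='int32')
```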

Returns:

A copy of the layer or model with shape annotations added before and after every layer that was called while evaluating the model on dummy_input.
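To make the return value concrete, here is a toy model of an annotated copy, written with only the standard library. It is not penzai's implementation: ShapeDtype, Dense, Sequential, CheckShape, and toy_annotate_shapes are hypothetical names for this sketch, and CheckShape merely stands in for whatever annotation layers the real function inserts into a pz.Layer model. The sketch evaluates the model once on the dummy input, records each called layer's input and output shape, and wraps that layer between two checks, leaving the original model unmodified.

```python
import copy
from collections import namedtuple

# Hypothetical toy, not penzai's implementation.
# Stand-in for jax.ShapeDtypeStruct.
ShapeDtype = namedtuple("ShapeDtype", ["shape", "dtype"])


class Dense:
    """Toy layer that maps shape (..., d_in) to (..., out_dim)."""

    def __init__(self, out_dim):
        self.out_dim = out_dim

    def __call__(self, x):
        return ShapeDtype(tuple(x.shape[:-1]) + (self.out_dim,), x.dtype)


class Sequential:
    """Toy combinator that calls its child layers in order."""

    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class CheckShape:
    """Toy annotation layer: passes x through if its shape/dtype match."""

    def __init__(self, expected):
        self.expected = expected

    def __call__(self, x):
        got = (tuple(x.shape), x.dtype)
        want = (tuple(self.expected.shape), self.expected.dtype)
        if got != want:
            raise ValueError(f"expected {self.expected}, got {x}")
        return x


def _annotate(layer, x):
    """Return (annotated copy of layer, abstract output of layer on x)."""
    if isinstance(layer, Sequential):
        new_children, h = [], x
        for child in layer.layers:
            new_child, h = _annotate(child, h)
            new_children.append(new_child)
        inner, out = Sequential(new_children), h
    else:
        inner, out = copy.copy(layer), layer(x)
    # Sandwich each called layer between checks on its input and output.
    return Sequential([CheckShape(x), inner, CheckShape(out)]), out


def toy_annotate_shapes(root, dummy_input):
    annotated, _ = _annotate(root, dummy_input)
    return annotated


model = Sequential([Dense(8), Dense(3)])
annotated = toy_annotate_shapes(model, ShapeDtype((2, 4), "float32"))
annotated(ShapeDtype((2, 4), "float32"))  # passes every check
# annotated(ShapeDtype((2, 5), "float32")) would raise ValueError.
```

In this toy, an input of the wrong shape is rejected by the outermost check before any layer runs, and a layer whose output drifted from the recorded shape would be caught by the check immediately after it.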