check_layer

penzai.toolshed.check_layers_by_tracing.check_layer(layer: pz.Layer, argument: Any | None = None, initialize: bool = True) -> Any

Checks that a layer has been configured correctly by tracing it.

This function runs the layer under jax.eval_shape, passing in a dummy input that matches the layer's expected input structure, with arbitrary sizes chosen for any unspecified dimension variables. It then checks that the traced output matches the layer's expected output structure.
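The underlying technique can be sketched with plain JAX. This is not penzai's implementation: `toy_layer`, its parameters, and the choice of 7 for the free batch dimension are all hypothetical. It only shows how jax.eval_shape traces a function on a shape-only dummy input and how the traced output can be compared against an expected shape without running any computation.

```python
import jax
import jax.numpy as jnp

# Hypothetical "layer": a dense projection from 4 to 3 features.
# The batch dimension is a free dimension variable.
def toy_layer(params, x):
    return x @ params["w"] + params["b"]

params = {
    "w": jnp.zeros((4, 3), jnp.float32),
    "b": jnp.zeros((3,), jnp.float32),
}

# Dummy input: pick an arbitrary size (here 7) for the unspecified batch dim.
arbitrary_batch = 7
dummy_x = jax.ShapeDtypeStruct((arbitrary_batch, 4), jnp.float32)

# eval_shape traces the function abstractly; the matmul is never executed.
out = jax.eval_shape(toy_layer, params, dummy_x)

# Compare the traced output against the expected output structure.
expected = jax.ShapeDtypeStruct((arbitrary_batch, 3), jnp.float32)
assert out.shape == expected.shape and out.dtype == expected.dtype
print(out)
```

Because only shapes and dtypes flow through the trace, this kind of check stays cheap even for large layers.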

Note that not every layer can fully encode its preconditions in its input_structure. For such layers, you may need to provide a dummy argument that matches the expected input structure.
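To illustrate why an explicit argument can be necessary, here is a sketch using a hypothetical function `pairwise_sum` whose precondition (an even sequence length) is not captured by shape structure alone. An arbitrary odd size for the free dimension makes tracing fail, while an explicit jax.ShapeDtypeStruct dummy with a valid size traces cleanly. The exact exception type raised by JAX for the bad reshape is treated as an assumption, so both TypeError and ValueError are caught.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer with a precondition that a shape-only structure
# cannot easily express: the sequence length must be even, because the
# layer groups adjacent positions into pairs.
def pairwise_sum(x):
    seq, feat = x.shape
    return x.reshape(seq // 2, 2, feat).sum(axis=1)

# An arbitrary odd size for the free "seq" dimension breaks tracing.
try:
    jax.eval_shape(pairwise_sum, jax.ShapeDtypeStruct((7, 5), jnp.float32))
    odd_failed = False
except (TypeError, ValueError):
    odd_failed = True

# Supplying an explicit dummy argument with a valid (even) length succeeds.
dummy = jax.ShapeDtypeStruct((8, 5), jnp.float32)
out = jax.eval_shape(pairwise_sum, dummy)
print(odd_failed, out.shape)
```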

Parameters:
  • layer – The layer to check.

  • argument – An optional argument to pass to the layer. If not provided, a dummy argument will be created based on the layer’s input structure. Can contain jax.ShapeDtypeStruct leaves.

  • initialize – Whether to initialize any uninitialized parameters in the model.
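This page does not say exactly how initialization interacts with tracing. Assuming it happens inside the same jax.eval_shape call, initializers would only be traced, not executed. The sketch below uses a hypothetical `init_weights` function to show that behavior in plain JAX: a large random initializer returns only jax.ShapeDtypeStruct leaves and allocates no weight buffer.

```python
import jax
import jax.numpy as jnp

# Hypothetical initializer for a large weight matrix and bias.
def init_weights(key):
    return {
        "w": jax.random.normal(key, (4096, 4096), jnp.float32),
        "b": jnp.zeros((4096,), jnp.float32),
    }

# Under eval_shape the initializer is traced but never run, so no
# 4096 x 4096 buffer is materialized; only shapes and dtypes come back.
abstract_params = jax.eval_shape(init_weights, jax.random.PRNGKey(0))
print(abstract_params["w"])
```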

Returns:

The traced output structure from the call, as a PyTree of jax.ShapeDtypeStruct leaves.
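A PyTree of jax.ShapeDtypeStruct leaves can be inspected with standard JAX tree utilities. In this sketch the function `toy_stats`, which returns a dict rather than a single array, is hypothetical; it stands in for any layer with structured output.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer returning a dict (a PyTree) rather than one array.
def toy_stats(x):
    scores = x @ x.T
    return {"scores": scores, "row_max": scores.max(axis=-1)}

out = jax.eval_shape(toy_stats, jax.ShapeDtypeStruct((6, 16), jnp.float32))

# Every leaf of the returned PyTree is a jax.ShapeDtypeStruct; map each
# leaf to its shape for a compact summary.
shapes = jax.tree_util.tree_map(lambda s: s.shape, out)
print(shapes)
```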

Raises:
  • ValueError – If the layer’s input structure contains ANY but no dummy argument was provided.

Note that the layer may also raise other exceptions if it is misconfigured.
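The argument-handling rule behind the ValueError can be sketched as a small wrapper. Everything here is hypothetical: the `ANY` sentinel is a local stand-in rather than penzai's own marker, and `toy_check` is a simplified illustration, not the real check_layer. For brevity it treats a non-wildcard structure as already being a ShapeDtypeStruct dummy.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a structure wildcard; penzai's own ANY
# marker is not used here.
ANY = object()

def toy_check(fn, input_structure, argument=None):
    """Simplified sketch of the argument-handling logic described above."""
    if argument is None:
        if input_structure is ANY:
            # No way to synthesize a dummy input for an unconstrained structure.
            raise ValueError("input structure is ANY; please pass a dummy argument")
        # Simplification: assume the structure is itself a ShapeDtypeStruct dummy.
        argument = input_structure
    return jax.eval_shape(fn, argument)

try:
    toy_check(lambda x: jnp.tanh(x), ANY)
    raised = False
except ValueError:
    raised = True

# Providing an explicit dummy argument avoids the error.
out = toy_check(lambda x: jnp.tanh(x), ANY,
                argument=jax.ShapeDtypeStruct((2, 3), jnp.float32))
print(raised, out.shape)
```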