check_layers_by_tracing

Utility to check that a layer has been configured correctly by tracing it.

This module provides a utility for checking that a layer or model has been set up correctly and will execute properly when run. It assumes that the top-level model or layer implements the input_structure method. The check runs the model under jax.eval_shape and passes in a dummy input that matches the expected structure, choosing arbitrary sizes for any unspecified dimension variables.
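The tracing idea can be illustrated without this module. The following is a minimal standalone sketch, not the implementation of check_layer. It assumes a hypothetical toy layer (`ToyDense`) whose hypothetical `input_structure` method returns a plain dict of axis sizes, with string entries standing in for dimension variables. A hypothetical helper (`check_by_tracing`) fills each dimension variable with an arbitrary size, builds a `jax.ShapeDtypeStruct` dummy input, and traces the layer with `jax.eval_shape`.

```python
import jax
import jax.numpy as jnp


class ToyDense:
    """Hypothetical layer: multiplies a [batch, features] input by a weight."""

    def __init__(self, in_features: int, out_features: int):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = jnp.zeros((in_features, out_features), jnp.float32)

    def input_structure(self):
        # Hypothetical stand-in for a real input_structure method: each axis
        # is either a fixed size (int) or a named dimension variable (str).
        return {"shape": ("batch", self.in_features), "dtype": jnp.float32}

    def __call__(self, x):
        return x @ self.weight


def check_by_tracing(layer, default_dim_size: int = 2):
    """Traces `layer` on a dummy input built from its input structure."""
    spec = layer.input_structure()
    # Replace each unspecified dimension variable with an arbitrary size.
    shape = tuple(
        d if isinstance(d, int) else default_dim_size for d in spec["shape"]
    )
    dummy = jax.ShapeDtypeStruct(shape, spec["dtype"])
    # eval_shape traces the call abstractly: no input array is allocated,
    # but shape mismatches inside the layer still raise during tracing.
    return jax.eval_shape(lambda x: layer(x), dummy)
```

For a correctly configured layer, `check_by_tracing(ToyDense(3, 4))` returns a `ShapeDtypeStruct` with shape `(2, 4)`. If the weight is replaced with one of shape `(5, 4)`, the same call raises an error during tracing, before any real data is processed. Tracing abstractly is cheap and catches configuration errors such as mismatched axis sizes without running the actual computation.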

Functions

check_layer(layer[, argument, initialize])

Checks that a layer has been configured correctly by tracing it.