gradient_checkpointing


Utilities for gradient checkpointing / rematerialization.

This module provides a wrapper that rematerializes a layer: instead of storing the layer's intermediate activations for the backward pass, it recomputes them when gradients are taken, while correctly accounting for variable states inside the layer. Rematerialization can be enabled by wrapping a layer in a Checkpointed block, for example:

from penzai import pz
from penzai.toolshed import gradient_checkpointing

model = (
    pz.select(model)
    .at_instances_of(pz.nn.Attention)  # or any other layer type
    .apply(gradient_checkpointing.Checkpointed)
)
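The trade-off behind rematerialization can be seen in plain JAX, without Penzai. The sketch below does not use Checkpointed and does nothing about variable states; it only applies jax.checkpoint (also available as jax.remat) to a hypothetical toy block, mlp_block, and checks that gradients are unchanged. Only the memory/compute trade-off differs: the checkpointed version recomputes the hidden activations during the backward pass.

```python
import jax
import jax.numpy as jnp


def mlp_block(params, x):
    # Hypothetical toy block; its hidden activation h is what remat avoids storing.
    h = jnp.tanh(x @ params["w1"])
    return jnp.tanh(h @ params["w2"])


# jax.checkpoint makes the backward pass recompute intermediates of mlp_block
# instead of saving them as residuals from the forward pass.
checkpointed_block = jax.checkpoint(mlp_block)


def loss(block_fn, params, x):
    return jnp.sum(block_fn(params, x) ** 2)


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(key1, (4, 8)),
    "w2": jax.random.normal(key2, (8, 4)),
}
x = jax.random.normal(key3, (2, 4))

grads_plain = jax.grad(lambda p: loss(mlp_block, p, x))(params)
grads_remat = jax.grad(lambda p: loss(checkpointed_block, p, x))(params)

# Rematerialization changes what is stored, not what is computed.
for name in params:
    assert jnp.allclose(grads_plain[name], grads_remat[name], atol=1e-5)
```

A Penzai wrapper such as Checkpointed has extra work beyond this, since a layer's variable states must be handled correctly across the recomputation.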

Classes

Checkpointed

Wraps a layer to run with gradient checkpointing.