gradient_checkpointing
Utilities for gradient checkpointing / rematerialization.
This module provides a wrapper that rematerializes a layer during the backward pass: the layer's intermediate activations are recomputed when gradients are taken instead of being stored, and variable states inside the layer are correctly accounted for. Rematerialization can be enabled by wrapping a layer in a Checkpointed block, using something like:
from penzai import pz
from penzai.toolshed import gradient_checkpointing

model = (
    pz.select(model)
    .at_instances_of(pz.nn.Attention)  # or another block
    .apply(gradient_checkpointing.Checkpointed)
)
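The point about variable states can be illustrated in plain JAX. The sketch below is an assumption about the general idea, not Penzai's implementation: `jax.checkpoint` traces the function it wraps, so any state a layer updates is passed in and returned explicitly rather than mutated as a Python side effect. The layer `stateful_block` and its call counter are hypothetical.

```python
import jax
import jax.numpy as jnp


def stateful_block(w, state, x):
    """Hypothetical layer returning its output plus an updated state.

    The state (a call counter) is threaded through as an explicit input
    and output instead of being mutated in place.
    """
    h = jnp.tanh(x @ w)  # intermediate activation, recomputed under remat
    new_state = {"calls": state["calls"] + 1}
    return jnp.tanh(h), new_state


# Rematerialize the block: `h` is not saved for the backward pass but
# recomputed from the block's inputs when gradients are taken.
remat_block = jax.checkpoint(stateful_block)


def loss(w, state, x):
    y, new_state = remat_block(w, state, x)
    return jnp.sum(y ** 2), new_state


w = jnp.full((3, 3), 0.1)
x = jnp.ones((2, 3))
state = {"calls": jnp.array(0)}

# has_aux=True returns the updated state alongside the loss and gradient.
(loss_val, new_state), grad_w = jax.value_and_grad(loss, has_aux=True)(w, state, x)
```

Checkpointing changes only the memory/compute trade-off of the backward pass; the gradient and the returned state match an unwrapped call.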
Classes
Wraps a layer to run with gradient checkpointing.
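As a rough analogue of a wrapper class, the following sketch wraps a pure layer function and runs it under `jax.checkpoint` on every call. `RematWrapper` and `dense_tanh` are hypothetical names for illustration; this is not the Checkpointed class itself, which also handles Penzai variable states.

```python
import dataclasses
from typing import Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class RematWrapper:
    """Hypothetical wrapper that runs a pure layer under jax.checkpoint."""

    layer: Callable[[jax.Array, jax.Array], jax.Array]

    def __call__(self, params: jax.Array, x: jax.Array) -> jax.Array:
        # Activations inside `layer` are recomputed on the backward pass.
        return jax.checkpoint(self.layer)(params, x)


def dense_tanh(w, x):
    """Hypothetical single-layer function standing in for a real block."""
    return jnp.tanh(x @ w)


wrapped = RematWrapper(dense_tanh)
w = jnp.eye(2) * 0.5
x = jnp.array([[1.0, -1.0]])

y = wrapped(w, x)  # same forward output as dense_tanh(w, x)
g = jax.grad(lambda w_: jnp.sum(wrapped(w_, x)))(w)
```

Because the wrapper only changes when activations are computed, it can be swapped in or out without altering the layer's outputs or gradients.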