As brought up in #9982 (comment) by @bghira, I think we could support this directly in the `enable_gradient_checkpointing()` method we expose for the models.
Users could specify the block interval at which they want gradient checkpointing to be applied, and we take care of the rest. The code for this is simple and doesn't require any hacks; a rough sketch of the idea is included below.
Gradient checkpointing is a crucial component for training/fine-tuning larger models, and making the interval configurable would give users finer control over the speed/memory trade-off instead of always checkpointing every block.
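A minimal sketch of what this could look like, for illustration only. The `interval` parameter and the `SimpleTransformer` class here are hypothetical and not existing diffusers API; the point is just that the forward loop only needs an index check to decide which blocks get checkpointed:

```python
# Hypothetical sketch of interval-based gradient checkpointing.
# `SimpleTransformer` and the `interval` argument are illustrative only,
# not diffusers API.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SimpleTransformer(nn.Module):
    def __init__(self, num_layers: int = 12, dim: int = 64):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)
            for _ in range(num_layers)
        )
        self.checkpoint_interval = None  # checkpointing disabled by default

    def enable_gradient_checkpointing(self, interval: int = 1):
        # interval=1 checkpoints every block (the current behavior);
        # interval=2 checkpoints every other block, and so on.
        self.checkpoint_interval = interval

    def forward(self, hidden_states):
        for i, block in enumerate(self.blocks):
            if (
                self.training
                and self.checkpoint_interval is not None
                and i % self.checkpoint_interval == 0
            ):
                # Recompute this block's activations during the backward pass
                # instead of storing them, trading compute for memory.
                hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
            else:
                hidden_states = block(hidden_states)
        return hidden_states
```

With something like `model.enable_gradient_checkpointing(interval=2)`, only half the blocks are checkpointed, so recomputation cost drops while memory usage sits between full checkpointing and none.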
Cc: @a-r-r-o-w @hlky