torch.utils.checkpoint.checkpoint + torch.cuda.amp #40221
tano297 wants to merge 1 commit into pytorch:master from tano297:autocast_grad_checkpoint
Conversation
…kpointing gradients
@mcarilli I requested review from you because of the mention of amp; please let me know if that's not right.
Any update on this? Let me know how I can help.
Has this bug been solved? I have encountered the same issue.
mcarilli left a comment
Looks good to me, thanks!
I think the usual custom autograd function decorators aren't preferable here, because CheckpointFunction.backward runs a nested forward and backward. The autocast API recommends running only forward under autocast, but globally enabling autocast for all of CheckpointFunction.backward (as @custom_bwd might do) would include the nested backward as well.
Definitely needs a test though.
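For context, the recommendation mcarilli refers to looks roughly like the following (a generic amp training-loop sketch, not code from this PR; the model, optimizer, and tensor names are made up for illustration): autocast wraps only the forward pass and loss computation, while backward runs outside the autocast region.

```python
import torch
from torch.cuda.amp import autocast, GradScaler

# Illustrative module and optimizer; any would do.
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 4, device="cuda")

optimizer.zero_grad()
with autocast():                      # autocast covers the forward pass only
    out = model(x)
    loss = torch.nn.functional.mse_loss(out, target)
scaler.scale(loss).backward()         # backward runs outside the autocast region
scaler.step(optimizer)
scaler.update()
```

Globally enabling autocast for all of CheckpointFunction.backward would break this separation, because the nested torch.autograd.backward call inside it would then also run under autocast.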
PR appears orphaned, moving to #49757.
Summary: Adds a test to the orphaned original PR (pytorch#40221). Should fix pytorch#49738 and pytorch#47183.
Pull Request resolved: pytorch#49757
Reviewed By: mruberry
Differential Revision: D25689609
Pulled By: ngimel
fbshipit-source-id: 0a6adc11eb98382048ef9a9775e185dcdeff6010
Simple two-line workaround to allow gradient checkpointing to work with amp autocast.
In the same way PyTorch stores the "has_cuda" state in the context, we store a "has_autocast" flag during the first forward pass, so that autocast can be re-enabled when the forward pass runs for the second time during the backward pass.
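A minimal sketch of that idea is shown below. The class and attribute names (e.g. `had_autocast_in_fwd`) are illustrative rather than the exact diff, and the sketch ignores RNG-state restoration and non-tensor arguments that the real `torch.utils.checkpoint` handles.

```python
import torch
from torch.cuda.amp import autocast

class AutocastAwareCheckpoint(torch.autograd.Function):
    """Simplified checkpoint that preserves the autocast state across recomputation.

    Sketch only: assumes all inputs are tensors and skips RNG-state handling.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        # The added state: remember whether autocast was active in the first forward.
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        ctx.save_for_backward(*args)
        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = [x.detach().requires_grad_(x.requires_grad) for x in ctx.saved_tensors]
        # Re-run the forward with autocast re-enabled if it was on originally;
        # the nested backward below still runs outside the autocast region.
        with torch.enable_grad(), autocast(enabled=ctx.had_autocast_in_fwd):
            outputs = ctx.run_function(*inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)
        return (None,) + tuple(inp.grad for inp in inputs)
```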
For anybody hitting this before the fix is merged, a simple workaround can be found in #37730, or you can copy this version of the file into your own codebase with the two added lines.
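Either way, the end goal is that user code like the following behaves correctly, with the checkpointed segment recomputed under the same autocast state during the backward pass (module and tensor names here are made up for the example):

```python
import torch
from torch.cuda.amp import autocast
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU()).cuda()
x = torch.randn(4, 32, device="cuda", requires_grad=True)

with autocast():
    # With the fix (or the patched checkpoint.py), the recomputation inside
    # backward sees the same autocast state as this first forward pass.
    y = checkpoint(block, x)
    loss = y.sum()
loss.backward()
```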