Skip to content

Commit a9cca95

Browse files
committed
Allow the forward pass to be re-run in the same autocast mode when checkpointing gradients
1 parent 7c9e78f commit a9cca95

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

torch/utils/checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def forward(ctx, run_function, preserve_rng_state, *args):
5959
check_backward_validity(args)
6060
ctx.run_function = run_function
6161
ctx.preserve_rng_state = preserve_rng_state
62+
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
6263
if preserve_rng_state:
6364
ctx.fwd_cpu_state = torch.get_rng_state()
6465
# Don't eagerly initialize the cuda context by accident.
@@ -91,7 +92,7 @@ def backward(ctx, *args):
9192
if ctx.had_cuda_in_fwd:
9293
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
9394
detached_inputs = detach_variable(inputs)
94-
with torch.enable_grad():
95+
with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
9596
outputs = ctx.run_function(*detached_inputs)
9697

9798
if isinstance(outputs, torch.Tensor):

0 commit comments

Comments (0)