Skip to content

Commit d0ce1d1

Browse files
jjsjann123 authored and pytorchmergebot committed
Nvfuser guard patch
Fixes issue where CudaFusionGuard would return false on backward graph because `requires_grad` flag doesn't match. This is due to the fact that autodiff uses GradMode switch to turn on/off requires_grad, which is not taken into consideration by nvfuser guard. We verified the implementation under `TensorType::matchTensor`. - [x] Add python test to verify no fallback is observed Pull Request resolved: pytorch#75016 Approved by: https://github.com/eellison
1 parent dc71a8f commit d0ce1d1

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

test/test_jit_cuda_fuser.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
from torch.nn import functional
13+
from torch.profiler import profile, ProfilerActivity
1314

1415
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR # TEST_WITH_ROCM
1516
from torch.testing._internal.common_cuda import TEST_MULTIGPU
@@ -4273,6 +4274,29 @@ def reduce_scalar(temp):
42734274
reduce_scalar(res).backward()
42744275
torch._C._jit_set_nvfuser_guard_mode(old_guard)
42754276

4277+
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                 "Requires fusion optimization pass to be effective")
def test_cuda_fusion_guard_backward(self):
    # Regression test for the nvfuser guard: the backward graph must not
    # fall back to the interpreter merely because autodiff toggles
    # GradMode (which makes `requires_grad` differ from the profiled type).
    old_guard = torch._C._jit_set_nvfuser_guard_mode(True)

    x = torch.randn(10, device="cuda", requires_grad=True)
    dy = torch.randn(10, device="cuda")

    def fn(t):
        y = t.cos().cos()
        return y

    fn_scripted = torch.jit.script(fn)

    # Run several iterations so profiling/optimization kick in, while
    # recording CPU-side events to inspect afterwards.
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        for _ in range(5):
            x.grad = None
            result = fn_scripted(x)
            result.backward(dy)

    # A "fallback" entry in the profiler table would mean the fused
    # kernel was rejected by the guard — assert it never appears.
    self.assertEqual(prof.events().table().find("fallback"), -1)
    torch._C._jit_set_nvfuser_guard_mode(old_guard)
42764300

42774301
class TestPassManagerCudaFuser(JitTestCase):
42784302

torch/csrc/jit/codegen/cuda/interface.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ bool skipNode(const std::string& symbol_str, bool flip) {
114114
//! implementation is actually more relaxed.
115115
//!
116116
//! Things that we check:
117-
//! a. identical rank & scalar type
117+
//! a. identical rank & scalar type & device & requires_grad
118+
//! note that: requires_grad is tricky! because autodiff might be marking
119+
//! gradMode to overwrite it. Look at TensorType::matchTensor
120+
//! for the check condition
118121
//! b. stride check:
119122
//! b.1. identical stride order
120123
//! b.2. identical contiguity
@@ -146,7 +149,8 @@ bool complyWith(
146149
(guard_tensor_type->device().has_value() &&
147150
(guard_tensor_type->device().value() != tensor.device())) ||
148151
(guard_tensor_type->requiresGrad().has_value() &&
149-
guard_tensor_type->requiresGrad().value() != tensor.requires_grad())) {
152+
guard_tensor_type->requiresGrad().value() !=
153+
(tensor.requires_grad() && at::GradMode::is_enabled()))) {
150154
return false;
151155
}
152156

0 commit comments

Comments
 (0)