
Commit 3ce1e15

Revert "[Dynamo] Support torch.{cuda/cpu}.amp.autocast (#95416)"
This reverts commit c88aa33. Reverted #95416 on behalf of https://github.com/huydhn: sorry for reverting your PR, but the smoke-test issue appears related, as it started to fail consistently in trunk (https://hud.pytorch.org/hud/pytorch/pytorch/master/1?per_page=50&name_filter=inductor_torchbench_smoketest_perf).
1 parent 941ff10 commit 3ce1e15

Showing 8 changed files with 8 additions and 66 deletions.
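For orientation before the per-file diffs: #95416 let Dynamo trace the device-specific torch.cuda.amp.autocast and torch.cpu.amp.autocast context managers directly, and this revert withdraws that support, leaving only the generic torch.amp.autocast handled. A minimal sketch of the pattern losing support, adapted from the removed test_is_autocast_cpu_enabled below (torch._dynamo.optimize is the compile entry point of this era; with the revert applied, the nopython=True run is expected to fail with a graph break rather than compile):

import torch
import torch._dynamo

def fn(a_float32, b_float32):
    # Device-specific autocast: traced by Dynamo with #95416 applied,
    # a graph break again once this revert lands.
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        return torch.mm(a_float32, b_float32)

opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
res = opt_fn(torch.rand(8, 8), torch.rand(8, 8))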

functorch/experimental/_cond.py

Lines changed: 0 additions & 1 deletion
@@ -247,4 +247,3 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
 cond.fallthrough(DispatchKey.PythonTLSSnapshot)
 cond.fallthrough(DispatchKey.ADInplaceOrView)
 cond.fallthrough(DispatchKey.BackendSelect)
-cond.fallthrough(DispatchKey.AutocastCPU)

functorch/experimental/_map.py

Lines changed: 0 additions & 1 deletion
@@ -133,4 +133,3 @@ def map_functionalize(interpreter, f, xs, *args):
 map.fallthrough(DispatchKey.PythonTLSSnapshot)
 map.fallthrough(DispatchKey.ADInplaceOrView)
 map.fallthrough(DispatchKey.BackendSelect)
-map.fallthrough(DispatchKey.AutocastCPU)

test/dynamo/test_dynamic_shapes.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@

 ALL_DYNAMIC_XFAILS = {
     "MiscTests": [
+        "test_autocast_sdpa",
         "test_parsing_sdpa",
     ],
     "ReproTests": [

test/dynamo/test_misc.py

Lines changed: 0 additions & 41 deletions
@@ -3288,51 +3288,10 @@ def forward(self, x):
         self.assertEqual(exported.device.index, 0)
         self.assertEqual(exported.dtype, torch.bfloat16)

-    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_cuda_amp_autocast(self):
-        class MyModule(torch.nn.Module):
-            def forward(self, x):
-                a_float32 = torch.rand((8, 8), device="cuda")
-                b_float32 = torch.rand((8, 8), device="cuda")
-
-                with torch.cuda.amp.autocast(dtype=torch.torch.float64):
-                    c_float64 = torch.mm(a_float32, b_float32)
-                return c_float64
-
-        module = MyModule()
-        real = module(torch.tensor([0.5]))
-        real_device = real.device
-        real_dtype = real.dtype
-
-        graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
-        exported = graph(torch.tensor([0.5]))
-        self.assertEqual(exported.device, real_device)
-        self.assertEqual(exported.dtype, real_dtype)
-
-        self.assertEqual(exported.device.type, "cuda")
-        self.assertEqual(exported.device.index, 0)
-        self.assertEqual(exported.dtype, torch.float64)
-
-    def test_is_autocast_cpu_enabled(self):
-        def fn(a_float32, b_float32):
-            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
-                c_float16 = torch.mm(a_float32, b_float32)
-                if torch.is_autocast_cpu_enabled():
-                    c_float16 = c_float16 + 1
-            return c_float16
-
-        a = torch.rand((8, 8))
-        b = torch.rand((8, 8))
-        ref = fn(a, b)
-        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
-        res = opt_fn(a, b)
-        self.assertTrue(same(ref, res))
-
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
         "Can't run fused SDPA on this platform",
     )
-    @patch.object(torch._dynamo.config, "dynamic_shapes", False)
     def test_autocast_sdpa(self):
         class MyModule(torch.nn.Module):
             def forward(self, query, key, value):

test/test_jit_autocast.py

Lines changed: 1 addition & 3 deletions
@@ -7,13 +7,12 @@
 import unittest
 from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
+from torch.testing._internal.common_utils import run_tests
 from torch.testing import FileCheck
 from jit.test_models import MnistNet

 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()

-@skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestAutocast(JitTestCase):
     def setUp(self):
         # common input tensors
@@ -758,7 +757,6 @@ def __init__(self, bias_enabled=True):
     def forward(self, x):
         return self.bn(self.conv(x))

-@skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestJitTraceAutocast(JitTestCase):
     def setUp(self):
         super().setUp()

torch/_dynamo/allowed_functions.py

Lines changed: 2 additions & 0 deletions
@@ -96,6 +96,8 @@ def _disallowed_function_ids():
         torch.autograd.grad,
         torch.clear_autocast_cache,
         torch.cuda.current_device,
+        torch.cuda.amp.autocast_mode.autocast,
+        torch.cpu.amp.autocast_mode.autocast,
         torch.distributions.constraints.is_dependent,
         torch.distributions.normal.Normal,
         torch.inference_mode,
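Restoring these two entries matters because _disallowed_function_ids is keyed by object identity: a callable listed here is excluded from Dynamo's allowed set, so calling it mid-trace forces a graph break instead of being inlined. A hedged illustration of why listing the autocast_mode constructors also covers the public aliases (the disallowed set below is illustrative, not Dynamo's actual data structure):

import torch

# torch.cpu.amp.autocast is re-exported from torch.cpu.amp.autocast_mode,
# so both names resolve to the same object and share one id().
disallowed = {
    id(torch.cuda.amp.autocast_mode.autocast),
    id(torch.cpu.amp.autocast_mode.autocast),
}
assert id(torch.cpu.amp.autocast) in disallowed
assert id(torch.cuda.amp.autocast) in disallowed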

torch/_dynamo/variables/misc.py

Lines changed: 4 additions & 10 deletions
@@ -280,19 +280,13 @@ def __init__(self, target_values, initial_values=None, **kwargs):
         self.mode = mode

     def exit(self, tx, *args):
-        self.mode = (
-            exit_functional_autocast(self.mode[0]),
-            tx.output.create_node(
-                "call_function", exit_functional_autocast, (self.mode[1],), {}
-            ),
+        self.mode = tx.output.create_node(
+            "call_function", exit_functional_autocast, (self.mode,), {}
         )

     def enter(self, tx):
-        self.mode = (
-            enter_functional_autocast(*self.target_values),
-            tx.output.create_node(
-                "call_function", enter_functional_autocast, (*self.target_values,), {}
-            ),
+        self.mode = tx.output.create_node(
+            "call_function", enter_functional_autocast, (*self.target_values,), {}
         )

     def module_name(self):

torch/_dynamo/variables/torch.py

Lines changed: 0 additions & 10 deletions
@@ -64,9 +64,6 @@
     torch.finfo,
     torch.get_default_dtype,
     torch.iinfo,
-    torch.is_autocast_cache_enabled,
-    torch.is_autocast_cpu_enabled,
-    torch.is_autocast_enabled,
     torch.is_floating_point,
     torch.nn.functional._Reduction.get_enum,
 ]
@@ -327,13 +324,6 @@ def call_function(
            )
        elif self.value is torch.amp.autocast_mode.autocast:
            return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
-        elif self.value in [torch.cuda.amp.autocast, torch.cpu.amp.autocast]:
-            assert "device_type" not in kwargs
-            if self.value is torch.cuda.amp.autocast:
-                kwargs.update({"device_type": ConstantVariable("cuda")})
-            else:
-                kwargs.update({"device_type": ConstantVariable("cpu")})
-            return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
        elif self.value in (
            torch.profiler.profile,
            torch.profiler.record_function,
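Note that the torch.amp.autocast_mode.autocast branch survives the revert, so the device-generic spelling still compiles under Dynamo; only the cuda/cpu-specific wrappers lose support. A hedged sketch of the path that keeps working (same-era torch._dynamo.optimize API; the bfloat16 CPU dtype mirrors the removed test above):

import torch
import torch._dynamo

def fn(a, b):
    # The generic autocast entry point is still routed to
    # AutocastModeVariable.create, so no graph break here.
    with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
        return torch.mm(a, b)

opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
res = opt_fn(torch.rand(8, 8), torch.rand(8, 8))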
