Skip to content

Commit 6c0052a

Browse files
committed
Disable TF32 on DDP tests
1 parent 6514a47 commit 6c0052a

2 files changed

Lines changed: 7 additions & 0 deletions

File tree

test/distributed/test_distributed_fork.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
DistributedTest, TestDistBackend
1212
)
1313

14+
torch.backends.cuda.matmul.allow_tf32 = False
15+
1416
CPP_EXTENSIONS_WARNING = """
1517
Ninja (https://ninja-build.org) must be available to run C++ extensions tests,
1618
but it could not be found. Install ninja with `pip install ninja`
@@ -48,6 +50,7 @@ class TestDistBackendWithFork(TestDistBackend, DistributedTest._DistTestBase):
4850
def setUp(self):
4951
super().setUp()
5052
self._fork_processes()
53+
torch.backends.cudnn.flags(allow_tf32=False).__enter__()
5154

5255

5356
elif BACKEND == "mpi":

test/distributed/test_distributed_spawn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
import sys
44
import unittest
55

6+
import torch
67
import torch.distributed as dist
78
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ASAN, NO_MULTIPROCESSING_SPAWN
89
from torch.testing._internal.distributed.distributed_test import (
910
DistributedTest, TestDistBackend
1011
)
1112

13+
torch.backends.cuda.matmul.allow_tf32 = False
14+
1215
if not dist.is_available():
1316
print("Distributed not available, skipping tests", file=sys.stderr)
1417
sys.exit(0)
@@ -28,6 +31,7 @@ class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
2831
def setUp(self):
2932
super().setUp()
3033
self._spawn_processes()
34+
torch.backends.cudnn.flags(allow_tf32=False).__enter__()
3135

3236

3337
if __name__ == "__main__":

0 commit comments

Comments
 (0)