Skip to content

Commit e3c5c36

Browse files
Catherine Leepytorchmergebot
authored andcommitted
Run tests in USE_PYTEST_LIST through run_tests (pytorch#95659)
Part of my effort to move everything to pytest and decrease the number of testrunner frameworks in ci Gives xmls but they might look a weird b/c module level tests vs tests in classes. Doesn't give skip/disable test infra because those are tied to classes. (for future ref, could either put tests in classes or move the check_if_enable stuff into a pytest hook) Tested in CI and checked that the same number of tests are run Pull Request resolved: pytorch#95659 Approved by: https://github.com/huydhn
1 parent e5b9d98 commit e3c5c36

29 files changed

Lines changed: 140 additions & 29 deletions

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ addopts =
1111
testpaths =
1212
test
1313
junit_logging_reruns = all
14+
filterwarnings =
15+
ignore:Module already imported so cannot be rewritten.*hypothesis:pytest.PytestAssertRewriteWarning

test/distributed/elastic/events/lib_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import json
1111
import logging
12-
import unittest
1312
from dataclasses import asdict
1413
from unittest.mock import patch
1514

@@ -21,10 +20,10 @@
2120
_get_or_create_logger,
2221
construct_and_record_rdzv_event,
2322
)
24-
from torch.testing._internal.common_utils import run_tests
23+
from torch.testing._internal.common_utils import run_tests, TestCase
2524

2625

27-
class EventLibTest(unittest.TestCase):
26+
class EventLibTest(TestCase):
2827
def assert_event(self, actual_event, expected_event):
2928
self.assertEqual(actual_event.name, expected_event.name)
3029
self.assertEqual(actual_event.source, expected_event.source)
@@ -59,7 +58,7 @@ def test_event_deser(self):
5958
deser_event = Event.deserialize(json_event)
6059
self.assert_event(event, deser_event)
6160

62-
class RdzvEventLibTest(unittest.TestCase):
61+
class RdzvEventLibTest(TestCase):
6362
@patch("torch.distributed.elastic.events.record_rdzv_event")
6463
@patch("torch.distributed.elastic.events.get_logging_handler")
6564
def test_construct_and_record_rdzv_event(self, get_mock, record_mock):

test/distributed/pipeline/sync/skip/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch import nn
1212

1313
from torch.distributed.pipeline.sync.skip import Namespace, skippable, stash
14+
from torch.testing._internal.common_utils import run_tests
1415

1516

1617
def test_namespace_difference():
@@ -45,3 +46,7 @@ def forward(self, x):
4546
))
4647
""".strip()
4748
)
49+
50+
51+
if __name__ == "__main__":
52+
run_tests()

test/distributed/pipeline/sync/skip/test_gpipe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch.distributed.pipeline.sync.skip import pop, skippable, stash
1515
from torch.distributed.pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange
1616
from torch.distributed.pipeline.sync.utils import partition_model
17+
from torch.testing._internal.common_utils import run_tests
1718

1819

1920
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@@ -108,3 +109,7 @@ def assert_grad_fn_is_not_portal(grad_fn, visited=None):
108109

109110
output.local_value().sum().backward()
110111
assert input.grad.mean().item() == 1
112+
113+
114+
if __name__ == "__main__":
115+
run_tests()

test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from torch.distributed.pipeline.sync.skip import Namespace, pop, skippable, stash
1212
from torch.distributed.pipeline.sync.skip.layout import inspect_skip_layout
13+
from torch.testing._internal.common_utils import run_tests
1314

1415

1516
class Pass(nn.Module):
@@ -111,3 +112,7 @@ def test_namespace():
111112

112113
# p3 pops 'bar' before 'foo', but the plan is sorted by source partition index.
113114
assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]]
115+
116+
117+
if __name__ == "__main__":
118+
run_tests()

test/distributed/pipeline/sync/skip/test_leak.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch.distributed.pipeline.sync import Pipe, is_checkpointing, is_recomputing
1414
from torch.distributed.pipeline.sync.skip import pop, skippable, stash
1515
from torch.distributed.pipeline.sync.skip.tracker import current_skip_tracker
16+
from torch.testing._internal.common_utils import run_tests
1617

1718

1819
@skippable(stash=["skip"])
@@ -126,3 +127,7 @@ def deny(*args, **kwargs):
126127
model.eval()
127128
with torch.no_grad():
128129
model(input)
130+
131+
132+
if __name__ == "__main__":
133+
run_tests()

test/distributed/pipeline/sync/skip/test_portal.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch.distributed.pipeline.sync.dependency import fork, join
1313
from torch.distributed.pipeline.sync.skip.portal import Portal
1414
from torch.distributed.pipeline.sync.stream import default_stream
15+
from torch.testing._internal.common_utils import run_tests
1516

1617

1718
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@@ -155,3 +156,7 @@ def test_tensor_life_3_plus_1(self, new_portal):
155156
another_tensor = torch.rand(1, requires_grad=True)
156157
portal.put_tensor(another_tensor, tensor_life=1)
157158
portal.blue()
159+
160+
161+
if __name__ == "__main__":
162+
run_tests()

test/distributed/pipeline/sync/skip/test_stash_pop.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from torch.distributed.pipeline.sync.skip import pop, skippable, stash
1414
from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker
15+
from torch.testing._internal.common_utils import run_tests
1516

1617

1718
@pytest.fixture(autouse=True)
@@ -136,3 +137,7 @@ def forward(self, input):
136137

137138
l1 = Stash()
138139
l1(torch.tensor(42))
140+
141+
142+
if __name__ == "__main__":
143+
run_tests()

test/distributed/pipeline/sync/skip/test_tracker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch.distributed.pipeline.sync.skip import pop, skippable, stash
1919
from torch.distributed.pipeline.sync.skip.layout import SkipLayout
2020
from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, SkipTrackerThroughPotals, current_skip_tracker
21+
from torch.testing._internal.common_utils import run_tests
2122

2223

2324
def test_default_skip_tracker():
@@ -127,3 +128,7 @@ def test_tensor_life_with_checkpointing():
127128
with enable_recomputing():
128129
skip_tracker.save(batch, None, "test", tensor)
129130
assert skip_tracker.portals[(None, "test")].tensor_life == 0
131+
132+
133+
if __name__ == "__main__":
134+
run_tests()

test/distributed/pipeline/sync/skip/test_verify_skippables.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch import nn
1111

1212
from torch.distributed.pipeline.sync.skip import Namespace, skippable, verify_skippables
13+
from torch.testing._internal.common_utils import run_tests
1314

1415

1516
def test_matching():
@@ -152,3 +153,7 @@ class Layer4(nn.Module):
152153
verify_skippables(
153154
nn.Sequential(Layer1().isolate(ns1), Layer2().isolate(ns1), Layer3().isolate(ns2), Layer4().isolate(ns2),)
154155
)
156+
157+
158+
if __name__ == "__main__":
159+
run_tests()

0 commit comments

Comments
 (0)