Skip to content

Commit 806010b

Browse files
Rong Rong (AI Infra)facebook-github-bot
authored andcommitted
[BE] move more unittest.main() to run_tests() (#50923)
Summary: Relate to #50483. Everything except ONNX, detectron and release notes tests are moved to use common_utils.run_tests() to ensure CI reports XML correctly. Pull Request resolved: #50923 Reviewed By: samestep Differential Revision: D26027621 Pulled By: walterddr fbshipit-source-id: b04c03f10d1fe96181b720c4c3868e86e4c6281a
1 parent 8690819 commit 806010b

5 files changed

Lines changed: 16 additions & 13 deletions

File tree

test/custom_backend/test_custom_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import os
22
import tempfile
33
import torch
4-
import unittest
54

65
from backend import Model, to_custom_backend, get_custom_backend_library_path
6+
from torch.testing._internal.common_utils import TestCase, run_tests
77

88

9-
class TestCustomBackend(unittest.TestCase):
9+
class TestCustomBackend(TestCase):
1010
def setUp(self):
1111
# Load the library containing the custom backend.
1212
self.library_path = get_custom_backend_library_path()
@@ -51,4 +51,4 @@ def test_save_load(self):
5151

5252

5353
if __name__ == "__main__":
54-
unittest.main()
54+
run_tests()

test/custom_operator/test_custom_classes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
import glob
66
import os
77

8+
from torch.testing._internal.common_utils import TestCase, run_tests
9+
10+
811
def get_custom_class_library_path():
912
library_filename = glob.glob("build/*custom_class*")
1013
assert (len(library_filename) == 1)
@@ -18,7 +21,7 @@ def test_equality(f, cmp_key):
1821
obj2 = jit.script(f)()
1922
return (cmp_key(obj1), cmp_key(obj2))
2023

21-
class TestCustomOperators(unittest.TestCase):
24+
class TestCustomOperators(TestCase):
2225
def setUp(self):
2326
ops.load_library(get_custom_class_library_path())
2427

@@ -77,4 +80,4 @@ def f():
7780

7881

7982
if __name__ == "__main__":
80-
unittest.main()
83+
run_tests()

test/custom_operator/test_custom_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import os.path
22
import tempfile
3-
import unittest
43

54
import torch
65
from torch import ops
76

87
from model import Model, get_custom_op_library_path
8+
from torch.testing._internal.common_utils import TestCase, run_tests
99

1010

11-
class TestCustomOperators(unittest.TestCase):
11+
class TestCustomOperators(TestCase):
1212
def setUp(self):
1313
self.library_path = get_custom_op_library_path()
1414
ops.load_library(self.library_path)
@@ -90,4 +90,4 @@ def test_saving_and_loading_script_module_with_custom_op(self):
9090

9191

9292
if __name__ == "__main__":
93-
unittest.main()
93+
run_tests()

test/mobile/test_lite_script_module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import unittest
21
import torch
32
import torch.utils.bundled_inputs
43
from torch.utils.mobile_optimizer import *
@@ -7,8 +6,9 @@
76
from collections import namedtuple
87

98
from torch.jit.mobile import _load_for_lite_interpreter
9+
from torch.testing._internal.common_utils import TestCase, run_tests
1010

11-
class TestLiteScriptModule(unittest.TestCase):
11+
class TestLiteScriptModule(TestCase):
1212

1313
def test_load_mobile_module(self):
1414
class MyTestModule(torch.nn.Module):
@@ -287,4 +287,4 @@ def forward(self):
287287
script_module._save_to_buffer_for_lite_interpreter()
288288

289289
if __name__ == '__main__':
290-
unittest.main()
290+
run_tests()

test/test_tensorexpr_pybind.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
2-
import unittest
32

3+
from torch.testing._internal.common_utils import run_tests
44
from torch.testing._internal.jit_utils import JitTestCase
55

66
class kernel_arena_scope(object):
@@ -37,4 +37,4 @@ def compute(i):
3737
torch.testing.assert_allclose(tA + tB, tC)
3838

3939
if __name__ == '__main__':
40-
unittest.main()
40+
run_tests()

0 commit comments

Comments
 (0)