|
12 | 12 | from torch.testing import FileCheck |
13 | 13 | from torch.testing._internal.common_utils import ( |
14 | 14 | instantiate_parametrized_tests, |
| 15 | + patch_test_members, |
| 16 | + is_navi3_arch, |
15 | 17 | parametrize, |
16 | 18 | TEST_XPU, |
17 | 19 | ) |
@@ -71,31 +73,46 @@ def forward( |
71 | 73 | ) |
72 | 74 | @instantiate_parametrized_tests |
73 | 75 | class TestDecomposeMemMM(TestCase): |
74 | | - def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): |
| 76 | + def __init__(self, method_name='runTest', methodName='runTest'): |
| 77 | + super().__init__(method_name, methodName) |
| 78 | + self.atol = 1e-3 |
| 79 | + self.rtol = 1e-3 |
| 80 | + |
| 81 | + def setup_tolerance(self, rtol=None, atol=None): |
| 82 | + if rtol is None: |
| 83 | + rtol = self.rtol |
| 84 | + if atol is None: |
| 85 | + atol = self.rtol |
| 86 | + |
| 87 | + def compare_dict_tensors(self, ref_dict, res_dict, rtol=None, atol=None): |
| 88 | + self.setup_tolerance(rtol, atol) |
75 | 89 | if len(set(ref_dict.keys())) != len(set(res_dict.keys())): |
76 | 90 | return False |
77 | 91 | for key1 in ref_dict.keys(): |
78 | 92 | key2 = "_orig_mod." + key1 |
79 | 93 | assert key2 in res_dict, f"{key1} does not exist in traced module" |
80 | | - if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): |
| 94 | + if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=self.rtol, atol=self.atol): |
81 | 95 | return False |
82 | 96 | return True |
83 | 97 |
|
84 | | - def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3): |
| 98 | + def compare_pred(self, module, traced, input, rtol=None, atol=None): |
| 99 | + self.setup_tolerance(rtol, atol) |
85 | 100 | ref = module(*input) |
86 | 101 | res = traced(*input) |
87 | | - self.assertEqual(ref, res, rtol=rtol, atol=atol) |
| 102 | + self.assertEqual(ref, res, rtol=self.rtol, atol=self.atol) |
88 | 103 |
|
89 | | - def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3): |
| 104 | + def compare_parameters(self, module, traced, rtol=None, atol=None): |
| 105 | + self.setup_tolerance(rtol, atol) |
90 | 106 | ref_params = dict(module.named_parameters()) |
91 | 107 | res_params = dict(traced.named_parameters()) |
92 | | - self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol)) |
| 108 | + self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol=self.rtol, atol=self.atol)) |
93 | 109 |
|
94 | | - def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): |
| 110 | + def compare_gradients(self, module, traced, rtol=None, atol=None): |
| 111 | + self.setup_tolerance(rtol, atol) |
95 | 112 | ref_grad = {key: param.grad for key, param in module.named_parameters()} |
96 | 113 | res_grad = {key: param.grad for key, param in traced.named_parameters()} |
97 | 114 | self.assertTrue( |
98 | | - self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) |
| 115 | + self.compare_dict_tensors(ref_grad, res_grad, rtol=self.rtol, atol=self.atol) |
99 | 116 | ) |
100 | 117 |
|
101 | 118 | @parametrize( |
@@ -202,6 +219,12 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): |
202 | 219 | ) |
203 | 220 | counters.clear() |
204 | 221 |
|
| 222 | + # We have to increase tolerance for navi3 because all fp16, bf16 |
| 223 | + # GEMMs operations have an accuracy issue caused by hardware limitation |
| 224 | + @patch_test_members({ |
| 225 | + "atol": 2e-3 if is_navi3_arch() else 1e-3, |
| 226 | + "rtol": 2e-3 if is_navi3_arch() else 1e-3 |
| 227 | + }) |
205 | 228 | @parametrize( |
206 | 229 | "m,k,n, should_decompose", |
207 | 230 | [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)], |
@@ -310,6 +333,12 @@ def test_decompose_mm_cpu(self, m, n, k, should_decompose): |
310 | 333 | ) |
311 | 334 | counters.clear() |
312 | 335 |
|
| 336 | + # We have to increase tolerance for navi3 because all fp16, bf16 |
| 337 | + # GEMMs operations have an accuracy issue caused by hardware limitation |
| 338 | + @patch_test_members({ |
| 339 | + "atol": 3e-3 if is_navi3_arch() else 1e-3, |
| 340 | + "rtol": 4e-3 if is_navi3_arch() else 1e-3 |
| 341 | + }) |
313 | 342 | @parametrize( |
314 | 343 | "m,k,n, should_decompose", |
315 | 344 | [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)], |
|
0 commit comments