Skip to content

Commit e7f2289

Browse files
iupaikov-amdpytorchmergebot
authored andcommitted
[AMD][gfx1100] test_decompose_mem_bound_mm.py tolerance increase for navi3x(gfx11x)
(cherry picked from commit 03c7da0) Signed-off-by: Artem Kuzmitckii <artem.kuzmitckii@amd.com>
1 parent 1009790 commit e7f2289

2 files changed

Lines changed: 68 additions & 9 deletions

File tree

test/inductor/test_decompose_mem_bound_mm.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from torch.testing import FileCheck
1313
from torch.testing._internal.common_utils import (
1414
instantiate_parametrized_tests,
15+
patch_test_members,
16+
is_navi3_arch,
1517
parametrize,
1618
TEST_XPU,
1719
)
@@ -71,31 +73,46 @@ def forward(
7173
)
7274
@instantiate_parametrized_tests
7375
class TestDecomposeMemMM(TestCase):
74-
def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
76+
def __init__(self, method_name='runTest', methodName='runTest'):
77+
super().__init__(method_name, methodName)
78+
self.atol = 1e-3
79+
self.rtol = 1e-3
80+
81+
def setup_tolerance(self, rtol=None, atol=None):
82+
if rtol is None:
83+
rtol = self.rtol
84+
if atol is None:
85+
atol = self.rtol
86+
87+
def compare_dict_tensors(self, ref_dict, res_dict, rtol=None, atol=None):
88+
self.setup_tolerance(rtol, atol)
7589
if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
7690
return False
7791
for key1 in ref_dict.keys():
7892
key2 = "_orig_mod." + key1
7993
assert key2 in res_dict, f"{key1} does not exist in traced module"
80-
if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol):
94+
if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=self.rtol, atol=self.atol):
8195
return False
8296
return True
8397

84-
def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
98+
def compare_pred(self, module, traced, input, rtol=None, atol=None):
99+
self.setup_tolerance(rtol, atol)
85100
ref = module(*input)
86101
res = traced(*input)
87-
self.assertEqual(ref, res, rtol=rtol, atol=atol)
102+
self.assertEqual(ref, res, rtol=self.rtol, atol=self.atol)
88103

89-
def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
104+
def compare_parameters(self, module, traced, rtol=None, atol=None):
105+
self.setup_tolerance(rtol, atol)
90106
ref_params = dict(module.named_parameters())
91107
res_params = dict(traced.named_parameters())
92-
self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))
108+
self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol=self.rtol, atol=self.atol))
93109

94-
def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
110+
def compare_gradients(self, module, traced, rtol=None, atol=None):
111+
self.setup_tolerance(rtol, atol)
95112
ref_grad = {key: param.grad for key, param in module.named_parameters()}
96113
res_grad = {key: param.grad for key, param in traced.named_parameters()}
97114
self.assertTrue(
98-
self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
115+
self.compare_dict_tensors(ref_grad, res_grad, rtol=self.rtol, atol=self.atol)
99116
)
100117

101118
@parametrize(
@@ -202,6 +219,12 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose):
202219
)
203220
counters.clear()
204221

222+
# We have to increase tolerance for navi3 because all fp16, bf16
223+
# GEMMs operations have an accuracy issue caused by hardware limitation
224+
@patch_test_members({
225+
"atol": 2e-3 if is_navi3_arch() else 1e-3,
226+
"rtol": 2e-3 if is_navi3_arch() else 1e-3
227+
})
205228
@parametrize(
206229
"m,k,n, should_decompose",
207230
[(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
@@ -310,6 +333,12 @@ def test_decompose_mm_cpu(self, m, n, k, should_decompose):
310333
)
311334
counters.clear()
312335

336+
# We have to increase tolerance for navi3 because all fp16, bf16
337+
# GEMMs operations have an accuracy issue caused by hardware limitation
338+
@patch_test_members({
339+
"atol": 3e-3 if is_navi3_arch() else 1e-3,
340+
"rtol": 4e-3 if is_navi3_arch() else 1e-3
341+
})
313342
@parametrize(
314343
"m,k,n, should_decompose",
315344
[(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],

torch/testing/_internal/common_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@
100100
except ImportError:
101101
has_pytest = False
102102

103-
104103
SEED = 1234
105104
MI350_ARCH = ("gfx950",)
106105
MI300_ARCH = ("gfx942",)
@@ -134,6 +133,14 @@ class ProfilingMode(Enum):
134133
UNITTEST_ARGS : list[str] = []
135134
USE_PYTEST = False
136135

136+
def is_navi3_arch():
137+
if torch.cuda.is_available():
138+
prop = torch.cuda.get_device_properties(0)
139+
gfx_arch = prop.gcnArchName.split(":")[0]
140+
if gfx_arch in NAVI3_ARCH:
141+
return True
142+
return False
143+
137144
def freeze_rng_state(*args, **kwargs):
138145
return torch.testing._utils.freeze_rng_state(*args, **kwargs)
139146

@@ -5848,3 +5855,26 @@ def wrap_fn(self, *args, **kwargs):
58485855
raise unittest.SkipTest("Python version mismatch")
58495856
return wrap_fn
58505857
return dec_fn
5858+
5859+
# Decorator to patch multiple test class members for the duration of the subtest
5860+
def patch_test_members(updates: Dict[str, Any]):
5861+
def decorator(test_func):
5862+
@wraps(test_func)
5863+
def wrapper(self, *args, **kwargs):
5864+
# Store the original values of the specified members
5865+
original_values = {member: getattr(self, member) for member in updates}
5866+
5867+
# Update the members before running the subtest
5868+
for member, value in updates.items():
5869+
setattr(self, member, value)
5870+
5871+
# Run the test function, allowing subtests to run
5872+
try:
5873+
return test_func(self, *args, **kwargs)
5874+
finally:
5875+
# Restore the original values of the specified members after the subtest finishes
5876+
for member, original_value in original_values.items():
5877+
setattr(self, member, original_value)
5878+
5879+
return wrapper
5880+
return decorator

0 commit comments

Comments
 (0)