Skip to content

Commit 37c138f

Browse files
committed
lint
1 parent 861ca48 commit 37c138f

1 file changed

Lines changed: 25 additions & 13 deletions

File tree

test/inductor/test_decompose_mem_bound_mm.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from torch.testing import FileCheck
1313
from torch.testing._internal.common_utils import (
1414
instantiate_parametrized_tests,
15-
patch_test_members,
1615
is_navi3_arch,
1716
parametrize,
17+
patch_test_members,
1818
TEST_XPU,
1919
)
2020
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
@@ -73,7 +73,7 @@ def forward(
7373
)
7474
@instantiate_parametrized_tests
7575
class TestDecomposeMemMM(TestCase):
76-
def __init__(self, method_name='runTest', methodName='runTest'):
76+
def __init__(self, method_name="runTest", methodName="runTest"):
7777
super().__init__(method_name, methodName)
7878
self.atol = 1e-3
7979
self.rtol = 1e-3
@@ -93,7 +93,9 @@ def compare_dict_tensors(self, ref_dict, res_dict, rtol=None, atol=None):
9393
for key1 in ref_dict.keys():
9494
key2 = "_orig_mod." + key1
9595
assert key2 in res_dict, f"{key1} does not exist in traced module"
96-
if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=self.rtol, atol=self.atol):
96+
if not torch.allclose(
97+
ref_dict[key1], res_dict[key2], rtol=self.rtol, atol=self.atol
98+
):
9799
return False
98100
return True
99101

@@ -107,14 +109,20 @@ def compare_parameters(self, module, traced, rtol=None, atol=None):
107109
self.setup_tolerance(rtol, atol)
108110
ref_params = dict(module.named_parameters())
109111
res_params = dict(traced.named_parameters())
110-
self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol=self.rtol, atol=self.atol))
112+
self.assertTrue(
113+
self.compare_dict_tensors(
114+
ref_params, res_params, rtol=self.rtol, atol=self.atol
115+
)
116+
)
111117

112118
def compare_gradients(self, module, traced, rtol=None, atol=None):
113119
self.setup_tolerance(rtol, atol)
114120
ref_grad = {key: param.grad for key, param in module.named_parameters()}
115121
res_grad = {key: param.grad for key, param in traced.named_parameters()}
116122
self.assertTrue(
117-
self.compare_dict_tensors(ref_grad, res_grad, rtol=self.rtol, atol=self.atol)
123+
self.compare_dict_tensors(
124+
ref_grad, res_grad, rtol=self.rtol, atol=self.atol
125+
)
118126
)
119127

120128
@parametrize(
@@ -223,10 +231,12 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose):
223231

224232
# We have to increase tolerance for navi3 because all fp16, bf16
225233
# GEMMs operations have an accuracy issue caused by hardware limitation
226-
@patch_test_members({
227-
"atol": 2e-3 if is_navi3_arch() else 1e-3,
228-
"rtol": 2e-3 if is_navi3_arch() else 1e-3
229-
})
234+
@patch_test_members(
235+
{
236+
"atol": 2e-3 if is_navi3_arch() else 1e-3,
237+
"rtol": 2e-3 if is_navi3_arch() else 1e-3,
238+
}
239+
)
230240
@parametrize(
231241
"m,k,n, should_decompose",
232242
[(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
@@ -337,10 +347,12 @@ def test_decompose_mm_cpu(self, m, n, k, should_decompose):
337347

338348
# We have to increase tolerance for navi3 because all fp16, bf16
339349
# GEMMs operations have an accuracy issue caused by hardware limitation
340-
@patch_test_members({
341-
"atol": 3e-3 if is_navi3_arch() else 1e-3,
342-
"rtol": 4e-3 if is_navi3_arch() else 1e-3
343-
})
350+
@patch_test_members(
351+
{
352+
"atol": 3e-3 if is_navi3_arch() else 1e-3,
353+
"rtol": 4e-3 if is_navi3_arch() else 1e-3,
354+
}
355+
)
344356
@parametrize(
345357
"m,k,n, should_decompose",
346358
[(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],

0 commit comments

Comments
 (0)