1212from torch .testing import FileCheck
1313from torch .testing ._internal .common_utils import (
1414 instantiate_parametrized_tests ,
15- patch_test_members ,
1615 is_navi3_arch ,
1716 parametrize ,
17+ patch_test_members ,
1818 TEST_XPU ,
1919)
2020from torch .testing ._internal .inductor_utils import GPU_TYPE , HAS_CUDA_AND_TRITON
@@ -73,7 +73,7 @@ def forward(
7373)
7474@instantiate_parametrized_tests
7575class TestDecomposeMemMM (TestCase ):
76- def __init__ (self , method_name = ' runTest' , methodName = ' runTest' ):
76+ def __init__ (self , method_name = " runTest" , methodName = " runTest" ):
7777 super ().__init__ (method_name , methodName )
7878 self .atol = 1e-3
7979 self .rtol = 1e-3
@@ -93,7 +93,9 @@ def compare_dict_tensors(self, ref_dict, res_dict, rtol=None, atol=None):
9393 for key1 in ref_dict .keys ():
9494 key2 = "_orig_mod." + key1
9595 assert key2 in res_dict , f"{ key1 } does not exist in traced module"
96- if not torch .allclose (ref_dict [key1 ], res_dict [key2 ], rtol = self .rtol , atol = self .atol ):
96+ if not torch .allclose (
97+ ref_dict [key1 ], res_dict [key2 ], rtol = self .rtol , atol = self .atol
98+ ):
9799 return False
98100 return True
99101
@@ -107,14 +109,20 @@ def compare_parameters(self, module, traced, rtol=None, atol=None):
107109 self .setup_tolerance (rtol , atol )
108110 ref_params = dict (module .named_parameters ())
109111 res_params = dict (traced .named_parameters ())
110- self .assertTrue (self .compare_dict_tensors (ref_params , res_params , rtol = self .rtol , atol = self .atol ))
112+ self .assertTrue (
113+ self .compare_dict_tensors (
114+ ref_params , res_params , rtol = self .rtol , atol = self .atol
115+ )
116+ )
111117
112118 def compare_gradients (self , module , traced , rtol = None , atol = None ):
113119 self .setup_tolerance (rtol , atol )
114120 ref_grad = {key : param .grad for key , param in module .named_parameters ()}
115121 res_grad = {key : param .grad for key , param in traced .named_parameters ()}
116122 self .assertTrue (
117- self .compare_dict_tensors (ref_grad , res_grad , rtol = self .rtol , atol = self .atol )
123+ self .compare_dict_tensors (
124+ ref_grad , res_grad , rtol = self .rtol , atol = self .atol
125+ )
118126 )
119127
120128 @parametrize (
@@ -223,10 +231,12 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose):
223231
224232 # We have to increase tolerance for navi3 because all fp16, bf16
225233 # GEMMs operations have an accuracy issue caused by hardware limitation
226- @patch_test_members ({
227- "atol" : 2e-3 if is_navi3_arch () else 1e-3 ,
228- "rtol" : 2e-3 if is_navi3_arch () else 1e-3
229- })
234+ @patch_test_members (
235+ {
236+ "atol" : 2e-3 if is_navi3_arch () else 1e-3 ,
237+ "rtol" : 2e-3 if is_navi3_arch () else 1e-3 ,
238+ }
239+ )
230240 @parametrize (
231241 "m,k,n, should_decompose" ,
232242 [(20480 , 5 , 2 , True ), (20480 , 32 , 2 , False ), (2048 , 2 , 2 , False )],
@@ -337,10 +347,12 @@ def test_decompose_mm_cpu(self, m, n, k, should_decompose):
337347
338348 # We have to increase tolerance for navi3 because all fp16, bf16
339349 # GEMMs operations have an accuracy issue caused by hardware limitation
340- @patch_test_members ({
341- "atol" : 3e-3 if is_navi3_arch () else 1e-3 ,
342- "rtol" : 4e-3 if is_navi3_arch () else 1e-3
343- })
350+ @patch_test_members (
351+ {
352+ "atol" : 3e-3 if is_navi3_arch () else 1e-3 ,
353+ "rtol" : 4e-3 if is_navi3_arch () else 1e-3 ,
354+ }
355+ )
344356 @parametrize (
345357 "m,k,n, should_decompose" ,
346358 [(20480 , 5 , 2 , True ), (20480 , 32 , 2 , False ), (2048 , 2 , 2 , False )],
0 commit comments