@@ -1011,59 +1011,49 @@ class CompiledOptimizerBitwiseTests(TestCase):
     def _test_optimizer_bitwise(
         test_case,
         optim_cls,
+        kernel_count=None,
         num_steps=10,
         **optim_kwargs,
     ):
         """Helper to test optimizer bitwise equality."""
         torch._dynamo.reset()
+        torch._inductor.metrics.reset()
         torch.manual_seed(42)
 
-        params_eager = [
-            torch.randn(64, 64, device=GPU_TYPE, dtype=torch.float32),
-            torch.randn(32, 32, device=GPU_TYPE, dtype=torch.float32),
-        ]
-        params_compiled = [p.clone() for p in params_eager]
-
-        opt_eager = optim_cls(
-            params_eager,
-            **optim_kwargs,
-        )
-        opt_compiled = optim_cls(
-            params_compiled,
-            **optim_kwargs,
+        input = torch.ones([10, 10], device=GPU_TYPE)
+        model_eager = torch.nn.Sequential(
+            *[torch.nn.Linear(10, 10, device=GPU_TYPE) for _ in range(2)]
         )
+        model_eager(input).sum().backward()
 
-        @torch.compile
-        def compiled_step():
-            opt_compiled.step()
+        model_compiled = deepcopy(model_eager)
+        model_compiled(input).sum().backward()
 
-        for step in range(num_steps):
-            # Generate gradients with consistent seed
-            torch.manual_seed(1000 + step)
-            grads = [torch.randn_like(p) for p in params_eager]
-
-            for p, g in zip(params_eager, grads):
-                p.grad = g.clone()
-            for p, g in zip(params_compiled, grads):
-                p.grad = g.clone()
+        opt_eager = optim_cls(model_eager.parameters(), **optim_kwargs)
+        opt_compiled = optim_cls(model_compiled.parameters(), **optim_kwargs)
+        compiled_step = compile_opt(opt_compiled)
 
-            opt_eager.step()
-            compiled_step()
+        with torch.set_grad_enabled(False):
+            for step in range(num_steps):
+                compiled_step()
+                opt_eager.step()
 
-            # Check bitwise equality
-            for i, (p_eager, p_compiled) in enumerate(
-                zip(params_eager, params_compiled)
-            ):
-                test_case.assertEqual(
-                    p_eager,
-                    p_compiled,
-                    atol=0,
-                    rtol=0,
-                    msg=f"Step {step + 1}, param {i}: params differ",
-                )
+                # Check bitwise equality
+                for i, (p_eager, p_compiled) in enumerate(
+                    zip(model_eager.parameters(), model_compiled.parameters())
+                ):
+                    test_case.assertEqual(
+                        p_eager,
+                        p_compiled,
+                        atol=0,
+                        rtol=0,
+                        msg=f"Step {step + 1}, param {i}: params differ",
+                    )
 
         # Also check optimizer state
-        for p_eager, p_compiled in zip(params_eager, params_compiled):
+        for p_eager, p_compiled in zip(
+            model_eager.parameters(), model_compiled.parameters()
+        ):
             for key in opt_eager.state[p_eager]:
                 eager_val = opt_eager.state[p_eager][key]
                 compiled_val = opt_compiled.state[p_compiled][key]
@@ -1076,6 +1066,14 @@ def compiled_step():
                     msg=f"State '{key}' differs",
                 )
 
+        if kernel_count is not None and test_case.check_kernel_count:
+            if isinstance(kernel_count, types.LambdaType):
+                kernel_count(str(torch._inductor.metrics.generated_kernel_count))
+            else:
+                test_case.assertEqual(
+                    torch._inductor.metrics.generated_kernel_count, kernel_count
+                )
+
 
 for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB:
     setattr(
@@ -1085,19 +1083,20 @@ def compiled_step():
     )
 
 
-def _make_bitwise_test(optim_cls, **optim_kwargs):
+def _make_bitwise_test(optim_cls, kernel_count=None, **optim_kwargs):
     @skipIfRocm(msg="ROCm may have different numerical behavior")
     @requires_cuda_and_triton
     @config.patch(
         {
+            "score_fusion_memory_threshold": 1,
             "eager_numerics.division_rounding": True,
             "eager_numerics.use_pytorch_libdevice": True,
             "emulate_precision_casts": True,
         }
     )
     def test_fn(self):
         CompiledOptimizerBitwiseTests._test_optimizer_bitwise(
-            self, optim_cls, **optim_kwargs
+            self, optim_cls, kernel_count=kernel_count, **optim_kwargs
         )
 
     return test_fn
@@ -1117,6 +1116,7 @@ def test_fn(self):
 # SGD doesn't support capturable but has no item() calls
 # so it compiles without graph breaks and can be tested bitwise.
 _BITWISE_NON_CAPTURABLE_OPTIMS = (SGD,)
+
 for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB:
     if (
         kwargs.get("device") == GPU_TYPE
@@ -1130,13 +1130,24 @@ def test_fn(self):
             or optim_cls in _BITWISE_NON_CAPTURABLE_OPTIMS
         )
     ):
+        bitwise_name = name.replace("test_", "test_bitwise_")
+        # Use the same kernel count as the non-bitwise test, including
+        # any overrides for specific test configurations.
+        if name in KERNEL_COUNT_OVERRIDES:
+            kernel_count = KERNEL_COUNT_OVERRIDES[name]
+        else:
+            kernel_count = (
+                KERNEL_COUNTS[optim_cls].multitensor
+                if kwargs.get("foreach", False)
+                else KERNEL_COUNTS[optim_cls].singletensor
+            )
         optim_kwargs = {
             k: v for k, v in kwargs.items() if k not in ("device", "kernel_count")
         }
         setattr(
             CompiledOptimizerTests,
-            name.replace("test_", "test_bitwise_"),
-            _make_bitwise_test(optim_cls, **optim_kwargs),
+            bitwise_name,
+            _make_bitwise_test(optim_cls, kernel_count=kernel_count, **optim_kwargs),
         )
 
 
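Note on the pattern above: the new helper steps two copies of the same two-layer model, one with the optimizer run eagerly and one with the step compiled, and compares with atol=0/rtol=0, which turns assertEqual into a bitwise check. Below is a minimal standalone sketch of that pattern; the compile_opt here is a hypothetical stand-in for the helper defined elsewhere in this file, and without the diff's eager_numerics/emulate_precision_casts config patches, exact equality is not guaranteed to hold.

from copy import deepcopy

import torch


def compile_opt(opt):
    # Hypothetical stand-in: compile a bare optimizer step.
    @torch.compile
    def step():
        opt.step()

    return step


def check_bitwise(optim_cls, num_steps=3, **optim_kwargs):
    torch.manual_seed(42)
    inp = torch.ones(10, 10, device="cuda")
    model_eager = torch.nn.Sequential(
        *[torch.nn.Linear(10, 10, device="cuda") for _ in range(2)]
    )
    model_eager(inp).sum().backward()

    # Parameter.__deepcopy__ clones weights but drops .grad, so the second
    # backward recreates gradients identical to the eager model's.
    model_compiled = deepcopy(model_eager)
    model_compiled(inp).sum().backward()

    opt_eager = optim_cls(model_eager.parameters(), **optim_kwargs)
    opt_compiled = optim_cls(model_compiled.parameters(), **optim_kwargs)
    compiled_step = compile_opt(opt_compiled)

    # Grad stays disabled: the loop runs optimizer math only, so any
    # divergence is attributable to the compiled step itself.
    with torch.set_grad_enabled(False):
        for step in range(num_steps):
            compiled_step()
            opt_eager.step()
            for i, (p_e, p_c) in enumerate(
                zip(model_eager.parameters(), model_compiled.parameters())
            ):
                # rtol=0 / atol=0 makes the comparison exact (bitwise).
                torch.testing.assert_close(
                    p_e, p_c, rtol=0, atol=0,
                    msg=f"step {step + 1}, param {i} diverged",
                )


check_bitwise(torch.optim.SGD, lr=0.01)

Generating the gradients once before the loop is the design choice worth noting: every iteration reuses the same .grad tensors, so the comparison isolates the optimizer's arithmetic from any forward/backward nondeterminism.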
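The kernel_count plumbing in the later hunks reuses Inductor's generated-kernel counter. The sketch below shows the mechanism, assuming a CUDA device with Triton available and that the compiled step is the only compilation in the process; the expected counts in the generated tests come from KERNEL_COUNTS / KERNEL_COUNT_OVERRIDES, which are defined elsewhere in this file.

import torch
import torch._inductor.metrics as metrics

torch._dynamo.reset()
metrics.reset()  # zero generated_kernel_count before compiling

p = torch.nn.Parameter(torch.randn(8, 8, device="cuda"))
p.grad = torch.randn(8, 8, device="cuda")
opt = torch.optim.SGD([p], lr=0.1)


@torch.compile
def step():
    opt.step()


with torch.no_grad():
    step()  # first call compiles and bumps the kernel counter

# The generated tests assert this equals a per-optimizer expected count,
# or pass it (as a string) to a lambda override, per the diff above.
print(metrics.generated_kernel_count)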