
Commit f249065

mlazos authored and pytorchmergebot committed
[inductor] Add kernel count verification to bitwise optimizer tests (#177071)
Add kernel count checking to CompiledOptimizerBitwiseTests, matching the pattern used in the non-bitwise CompiledOptimizerTests. Kernel counts are computed from KERNEL_COUNTS directly rather than reusing kwargs["kernel_count"] from COMPILED_OPT_KWARG_DB, since those may be assert_expected_inline lambdas tied to specific source locations.

Authored with Claude.

Pull Request resolved: #177071
Approved by: https://github.com/tianrengao, https://github.com/karthickai
1 parent: c3ec2b2

1 file changed: test/inductor/test_compiled_optimizers.py (54 additions, 43 deletions)
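The verification the commit adds comes down to resetting Inductor's kernel metrics before the optimizer runs and reading torch._inductor.metrics.generated_kernel_count afterwards. A minimal sketch of that pattern in isolation, before the diff itself: the two-layer linear model, Adam, and the CUDA device below are illustrative assumptions rather than the test's exact setup, and in the real tests the expected count comes from KERNEL_COUNTS / KERNEL_COUNT_OVERRIDES rather than being hard-coded.

# Sketch only: assumes a CUDA device and a Triton-enabled build, like the tests do.
import torch
import torch._dynamo
import torch._inductor.metrics

torch._dynamo.reset()
torch._inductor.metrics.reset()  # start the generated-kernel counter from zero

# Build a small model and produce gradients so the optimizer has work to do.
model = torch.nn.Sequential(
    *[torch.nn.Linear(10, 10, device="cuda") for _ in range(2)]
)
model(torch.ones(10, 10, device="cuda")).sum().backward()
opt = torch.optim.Adam(model.parameters(), capturable=True)  # illustrative optimizer

@torch.compile
def step():
    opt.step()

with torch.no_grad():
    step()

# The bitwise tests compare this metric against the expected kernel count;
# here we just print it instead of asserting a made-up number.
print(torch._inductor.metrics.generated_kernel_count)

The diff below wires this check into the bitwise test helper and its test-registration loop.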
@@ -1011,59 +1011,49 @@ class CompiledOptimizerBitwiseTests(TestCase):
     def _test_optimizer_bitwise(
         test_case,
         optim_cls,
+        kernel_count=None,
         num_steps=10,
         **optim_kwargs,
     ):
         """Helper to test optimizer bitwise equality."""
         torch._dynamo.reset()
+        torch._inductor.metrics.reset()
         torch.manual_seed(42)

-        params_eager = [
-            torch.randn(64, 64, device=GPU_TYPE, dtype=torch.float32),
-            torch.randn(32, 32, device=GPU_TYPE, dtype=torch.float32),
-        ]
-        params_compiled = [p.clone() for p in params_eager]
-
-        opt_eager = optim_cls(
-            params_eager,
-            **optim_kwargs,
-        )
-        opt_compiled = optim_cls(
-            params_compiled,
-            **optim_kwargs,
+        input = torch.ones([10, 10], device=GPU_TYPE)
+        model_eager = torch.nn.Sequential(
+            *[torch.nn.Linear(10, 10, device=GPU_TYPE) for _ in range(2)]
         )
+        model_eager(input).sum().backward()

-        @torch.compile
-        def compiled_step():
-            opt_compiled.step()
+        model_compiled = deepcopy(model_eager)
+        model_compiled(input).sum().backward()

-        for step in range(num_steps):
-            # Generate gradients with consistent seed
-            torch.manual_seed(1000 + step)
-            grads = [torch.randn_like(p) for p in params_eager]
-
-            for p, g in zip(params_eager, grads):
-                p.grad = g.clone()
-            for p, g in zip(params_compiled, grads):
-                p.grad = g.clone()
+        opt_eager = optim_cls(model_eager.parameters(), **optim_kwargs)
+        opt_compiled = optim_cls(model_compiled.parameters(), **optim_kwargs)
+        compiled_step = compile_opt(opt_compiled)

-            opt_eager.step()
-            compiled_step()
+        with torch.set_grad_enabled(False):
+            for step in range(num_steps):
+                compiled_step()
+                opt_eager.step()

-            # Check bitwise equality
-            for i, (p_eager, p_compiled) in enumerate(
-                zip(params_eager, params_compiled)
-            ):
-                test_case.assertEqual(
-                    p_eager,
-                    p_compiled,
-                    atol=0,
-                    rtol=0,
-                    msg=f"Step {step + 1}, param {i}: params differ",
-                )
+                # Check bitwise equality
+                for i, (p_eager, p_compiled) in enumerate(
+                    zip(model_eager.parameters(), model_compiled.parameters())
+                ):
+                    test_case.assertEqual(
+                        p_eager,
+                        p_compiled,
+                        atol=0,
+                        rtol=0,
+                        msg=f"Step {step + 1}, param {i}: params differ",
+                    )

         # Also check optimizer state
-        for p_eager, p_compiled in zip(params_eager, params_compiled):
+        for p_eager, p_compiled in zip(
+            model_eager.parameters(), model_compiled.parameters()
+        ):
             for key in opt_eager.state[p_eager]:
                 eager_val = opt_eager.state[p_eager][key]
                 compiled_val = opt_compiled.state[p_compiled][key]
@@ -1076,6 +1066,14 @@ def compiled_step():
                     msg=f"State '{key}' differs",
                 )

+        if kernel_count is not None and test_case.check_kernel_count:
+            if isinstance(kernel_count, types.LambdaType):
+                kernel_count(str(torch._inductor.metrics.generated_kernel_count))
+            else:
+                test_case.assertEqual(
+                    torch._inductor.metrics.generated_kernel_count, kernel_count
+                )
+

 for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB:
     setattr(
@@ -1085,19 +1083,20 @@ def compiled_step():
     )


-def _make_bitwise_test(optim_cls, **optim_kwargs):
+def _make_bitwise_test(optim_cls, kernel_count=None, **optim_kwargs):
     @skipIfRocm(msg="ROCm may have different numerical behavior")
     @requires_cuda_and_triton
     @config.patch(
         {
+            "score_fusion_memory_threshold": 1,
             "eager_numerics.division_rounding": True,
             "eager_numerics.use_pytorch_libdevice": True,
             "emulate_precision_casts": True,
         }
     )
     def test_fn(self):
         CompiledOptimizerBitwiseTests._test_optimizer_bitwise(
-            self, optim_cls, **optim_kwargs
+            self, optim_cls, kernel_count=kernel_count, **optim_kwargs
         )

     return test_fn
@@ -1117,6 +1116,7 @@ def test_fn(self):
 # SGD doesn't support capturable but has no item() calls
 # so it compiles without graph breaks and can be tested bitwise.
 _BITWISE_NON_CAPTURABLE_OPTIMS = (SGD,)
+
 for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB:
     if (
         kwargs.get("device") == GPU_TYPE
@@ -1130,13 +1130,24 @@ def test_fn(self):
             or optim_cls in _BITWISE_NON_CAPTURABLE_OPTIMS
         )
     ):
+        bitwise_name = name.replace("test_", "test_bitwise_")
+        # Use the same kernel count as the non-bitwise test, including
+        # any overrides for specific test configurations.
+        if name in KERNEL_COUNT_OVERRIDES:
+            kernel_count = KERNEL_COUNT_OVERRIDES[name]
+        else:
+            kernel_count = (
+                KERNEL_COUNTS[optim_cls].multitensor
+                if kwargs.get("foreach", False)
+                else KERNEL_COUNTS[optim_cls].singletensor
+            )
         optim_kwargs = {
             k: v for k, v in kwargs.items() if k not in ("device", "kernel_count")
         }
         setattr(
             CompiledOptimizerTests,
-            name.replace("test_", "test_bitwise_"),
-            _make_bitwise_test(optim_cls, **optim_kwargs),
+            bitwise_name,
+            _make_bitwise_test(optim_cls, kernel_count=kernel_count, **optim_kwargs),
         )

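For reference, a hedged sketch of the lookup structures the kernel-count selection above relies on. The field names mirror the attribute accesses in the diff (KERNEL_COUNTS[optim_cls].multitensor / .singletensor and KERNEL_COUNT_OVERRIDES[name]); the concrete optimizer keys, test names, and counts are placeholders, not the values in test_compiled_optimizers.py.

from collections import namedtuple

import torch

# Placeholder shapes: per-optimizer counts for the foreach (multi-tensor) and
# single-tensor code paths, plus per-test-name overrides.
KernelCounts = namedtuple("KernelCounts", ["multitensor", "singletensor"])
KERNEL_COUNTS = {torch.optim.Adam: KernelCounts(multitensor=2, singletensor=8)}  # placeholder counts
KERNEL_COUNT_OVERRIDES = {"test_adam_foo_cuda": 3}  # placeholder name and count

def expected_kernel_count(optim_cls, name, kwargs):
    # Same selection as the registration loop in the diff: a per-test override
    # wins, otherwise pick multitensor vs. singletensor based on "foreach".
    if name in KERNEL_COUNT_OVERRIDES:
        return KERNEL_COUNT_OVERRIDES[name]
    counts = KERNEL_COUNTS[optim_cls]
    return counts.multitensor if kwargs.get("foreach", False) else counts.singletensor

The generated tests keep the existing naming scheme with a test_bitwise_ prefix, so a subset can be selected in the usual way, e.g. pytest test/inductor/test_compiled_optimizers.py -k test_bitwise_.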