Skip to content

Commit 78c5c55

Browse files
authored
Merge 8b570d6 into 5f038ad
2 parents 5f038ad + 8b570d6 commit 78c5c55

2 files changed

Lines changed: 99 additions & 9 deletions

File tree

test/inductor/test_efficient_conv_bn_eval.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,69 @@ def forward(self, x):
9595

9696

9797
class EfficientConvBNEvalTemplate(TestCase):
98+
@tf32_on_and_off(0.003)
99+
@inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
100+
def test_functional_batch_norm_defaults(self):
101+
class Model(torch.nn.Module):
102+
def forward(self, x, mean, var):
103+
return torch.nn.functional.batch_norm(x, mean, var)
104+
105+
mod = Model().eval()
106+
x = torch.randn(1, 3, 4, 4)
107+
mean = torch.randn(3)
108+
var = torch.abs(torch.randn(3))
109+
110+
device = getattr(self, "device", "cpu")
111+
mod.to(device)
112+
x = x.to(device)
113+
mean = mean.to(device)
114+
var = var.to(device)
115+
116+
opt = torch.compile(mod, backend="inductor")
117+
opt(x, mean, var)
118+
119+
@tf32_on_and_off(0.003)
120+
@inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
121+
def test_fx_graph_batch_norm_defaults(self):
122+
"""Regression test for issue #169011.
123+
124+
Tests that torch.compile handles FX graphs containing F.batch_norm
125+
with only 3 positional arguments (input, running_mean, running_var).
126+
The original bug was an AssertionError: assert len(bn_node.args) == 8.
127+
"""
128+
from torch.fx import Graph, GraphModule
129+
130+
graph = Graph()
131+
132+
# Create input placeholders
133+
inp = graph.placeholder("input")
134+
mean = graph.placeholder("mean")
135+
var = graph.placeholder("var")
136+
137+
# Create F.batch_norm call with only 3 args (the original bug case)
138+
z = graph.call_function(torch.nn.functional.batch_norm, args=(inp, mean, var))
139+
140+
# Create output node
141+
graph.output(z)
142+
143+
# Wrap in a GraphModule
144+
gm = GraphModule({}, graph)
145+
146+
device = getattr(self, "device", "cpu")
147+
gm.to(device)
148+
gm_compiled = torch.compile(gm, backend="inductor")
149+
150+
inp_tensor = torch.randn(4, 4, device=device)
151+
mean_tensor = torch.randn(4, device=device)
152+
var_tensor = torch.abs(torch.randn(4, device=device)) # Must be positive
153+
154+
# This should not raise AssertionError
155+
out = gm_compiled(inp_tensor, mean_tensor, var_tensor)
156+
157+
# Verify result matches eager evaluation
158+
expected = gm(inp_tensor, mean_tensor, var_tensor)
159+
self.assertEqual(out, expected)
160+
98161
@tf32_on_and_off(0.003)
99162
@inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
100163
@functorch_config.patch({"enable_autograd_cache": False})

torch/_inductor/fx_passes/efficient_conv_bn_eval.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# mypy: allow-untyped-defs
2+
import inspect
3+
24
import torch
35
import torch.nn as nn
46
from torch._dynamo.utils import counters
@@ -14,6 +16,11 @@
1416
from .pre_grad import efficient_conv_bn_eval_pass
1517

1618

19+
# Cache the signature of F.batch_norm at module load time to avoid repeated
20+
# introspection during graph transformation (fixes performance regression).
21+
_BATCH_NORM_SIGNATURE = inspect.signature(torch.nn.functional.batch_norm)
22+
23+
1724
def efficient_conv_bn_eval(
1825
bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor
1926
):
@@ -146,17 +153,37 @@ def efficient_conv_bn_eval_decomposed(
146153
and inductor_config.efficient_conv_bn_eval_fx_passes,
147154
)
148155
def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs):
156+
"""
157+
Graph transformation pass for fusing F.batch_norm with preceding conv operations.
158+
159+
This pass handles F.batch_norm calls with default arguments by normalizing
160+
the args tuple using inspect.signature. It fuses batch normalization with
161+
the preceding convolution for more efficient evaluation.
162+
"""
149163
bn_node = match.nodes[0]
150164
graph = match.graph
151-
assert len(bn_node.args) == 8
165+
166+
# Normalize arguments by binding to cached signature and applying defaults.
167+
# This handles cases where F.batch_norm is called with fewer than 8 args.
168+
bound_args = _BATCH_NORM_SIGNATURE.bind(*bn_node.args, **bn_node.kwargs)
169+
bound_args.apply_defaults()
170+
# Use bound_args.args instead of mutating bn_node.args
171+
normalized_args = bound_args.args
152172

153173
# We can only use efficient conv-bn for eval mode with track_running_stats
154-
# bn_node.args is `training`
155-
if bn_node.args[-3]:
174+
# normalized_args[5] is the "training" argument
175+
training_arg = normalized_args[5]
176+
177+
# Safety check: if 'training' is a symbolic Node (from tracing/export),
178+
# we cannot optimize since we don't know the value at compile time.
179+
if isinstance(training_arg, torch.fx.Node):
180+
return
181+
182+
if training_arg:
156183
return
157184

158185
# Check if the input is Conv
159-
input_node = bn_node.args[0]
186+
input_node = normalized_args[0]
160187

161188
if input_node.op != "call_function": # type: ignore[union-attr]
162189
return
@@ -184,11 +211,11 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
184211

185212
with graph.inserting_before(bn_node):
186213
# prepare args for the fused function
187-
bn_running_mean = bn_node.args[1]
188-
bn_running_var = bn_node.args[2]
189-
bn_weight = bn_node.args[3]
190-
bn_bias = bn_node.args[4]
191-
bn_eps = bn_node.args[7]
214+
bn_running_mean = normalized_args[1]
215+
bn_running_var = normalized_args[2]
216+
bn_weight = normalized_args[3]
217+
bn_bias = normalized_args[4]
218+
bn_eps = normalized_args[7]
192219
assert len(conv_node.args) >= 2 # type: ignore[union-attr]
193220
conv_input = conv_node.args[0] # type: ignore[union-attr]
194221
conv_weight = conv_node.args[1] # type: ignore[union-attr]

0 commit comments

Comments
 (0)