|
1 | 1 | # mypy: allow-untyped-defs |
| 2 | +import inspect |
| 3 | + |
2 | 4 | import torch |
3 | 5 | import torch.nn as nn |
4 | 6 | from torch._dynamo.utils import counters |
|
14 | 16 | from .pre_grad import efficient_conv_bn_eval_pass |
15 | 17 |
|
16 | 18 |
|
| 19 | +# Compute the signature of F.batch_norm once at import time so that the graph |
| 20 | +# transform below can bind each matched node's args/kwargs without re-running inspect.signature. |
| 21 | +_BATCH_NORM_SIGNATURE = inspect.signature(torch.nn.functional.batch_norm) |
| 22 | + |
| 23 | + |
17 | 24 | def efficient_conv_bn_eval( |
18 | 25 | bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor |
19 | 26 | ): |
@@ -146,17 +153,37 @@ def efficient_conv_bn_eval_decomposed( |
146 | 153 | and inductor_config.efficient_conv_bn_eval_fx_passes, |
147 | 154 | ) |
148 | 155 | def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs): |
| 156 | + """ |
| 157 | + Graph transformation pass for fusing F.batch_norm with preceding conv operations. |
| 158 | +
|
| 159 | + This pass handles F.batch_norm calls with default arguments by normalizing |
| 160 | + the args tuple using inspect.signature. It fuses batch normalization with |
| 161 | + the preceding convolution for more efficient evaluation. |
| 162 | + """ |
149 | 163 | bn_node = match.nodes[0] |
150 | 164 | graph = match.graph |
151 | | - assert len(bn_node.args) == 8 |
| 165 | + |
| 166 | + # Normalize arguments by binding to cached signature and applying defaults. |
| 167 | + # This handles cases where F.batch_norm is called with fewer than 8 args. |
| 168 | + bound_args = _BATCH_NORM_SIGNATURE.bind(*bn_node.args, **bn_node.kwargs) |
| 169 | + bound_args.apply_defaults() |
| 170 | + # Use bound_args.args instead of mutating bn_node.args |
| 171 | + normalized_args = bound_args.args |
152 | 172 |
|
153 | 173 | # We can only use efficient conv-bn for eval mode with track_running_stats |
154 | | - # bn_node.args is `training` |
155 | | - if bn_node.args[-3]: |
| 174 | + # normalized_args[5] is the "training" argument |
| 175 | + training_arg = normalized_args[5] |
| 176 | + |
| 177 | + # Safety check: if 'training' is a symbolic Node (from tracing/export), |
| 178 | + # we cannot optimize since we don't know the value at compile time. |
| 179 | + if isinstance(training_arg, torch.fx.Node): |
| 180 | + return |
| 181 | + |
| 182 | + if training_arg: |
156 | 183 | return |
157 | 184 |
|
158 | 185 | # Check if the input is Conv |
159 | | - input_node = bn_node.args[0] |
| 186 | + input_node = normalized_args[0] |
160 | 187 |
|
161 | 188 | if input_node.op != "call_function": # type: ignore[union-attr] |
162 | 189 | return |
@@ -184,11 +211,11 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs |
184 | 211 |
|
185 | 212 | with graph.inserting_before(bn_node): |
186 | 213 | # prepare args for the fused function |
187 | | - bn_running_mean = bn_node.args[1] |
188 | | - bn_running_var = bn_node.args[2] |
189 | | - bn_weight = bn_node.args[3] |
190 | | - bn_bias = bn_node.args[4] |
191 | | - bn_eps = bn_node.args[7] |
| 214 | + bn_running_mean = normalized_args[1] |
| 215 | + bn_running_var = normalized_args[2] |
| 216 | + bn_weight = normalized_args[3] |
| 217 | + bn_bias = normalized_args[4] |
| 218 | + bn_eps = normalized_args[7] |
192 | 219 | assert len(conv_node.args) >= 2 # type: ignore[union-attr] |
193 | 220 | conv_input = conv_node.args[0] # type: ignore[union-attr] |
194 | 221 | conv_weight = conv_node.args[1] # type: ignore[union-attr] |
|
0 commit comments