
Commit 27a6498

nandesuka authored and pytorchmergebot committed
Fix FxConverter mutation tracking and scatter_reduce kwargs for FXIR backend (#175860)
Summary: Three fixes for scatter_reduce opinfo e2e tests through the FXIR codegen path: 1. **Fix sym_sum decomposition (export.py)**: `_decompose_sym_sum` assumed `torch.sym_sum` args are in `n.kwargs["args"]`, but the FX tracer records them as positional args in `n.args[0]`. This caused `KeyError: 'args'` during `symbolic_shape_decompose`. 2. **Propagate include_self kwarg (wrapper_fxir.py)**: The FxConverter's `_generate_scatter_fallback` only extracted the `reduce` kwarg, silently dropping `include_self`. This caused `scatter_reduce` with `include_self=False` to default to `True`, producing wrong results. 3. **Track in-place mutations in FxConverter (wrapper_fxir.py)**: For mutation ops like `scatter_reduce_` and `index_put_`, the FxConverter created a new output buffer but never updated `buffer_to_node` for the mutated input. Downstream references to the mutated buffer still pointed to the pre-mutation copy, so the mutation result was completely ignored. Fixed by updating `buffer_to_node` for all mutated names after creating the fallback call node. Test Plan: Validated 13 scatter_reduce opinfo e2e tests across all reduce modes (amax, amin, mean, prod), multiple dtypes (f16, bf16, f32), both include_self values, and both dim values: ``` OPINFO_START_INDEX=8268 OPINFO_END_INDEX=8269 OPINFO_LOWERING_MODE=AFG_INDUCTOR_COMPILE OPINFO_AUTO_DS=1 buck2 run fbcode//mtia/compiler/graph_compiler/test_library/afg_opinfo_e2e_tests:test_afg_opinfo_e2e ``` Pull Request resolved: #175860 Approved by: https://github.com/sidt-meta
1 parent 1f9e1e5 commit 27a6498

1 file changed

Lines changed: 15 additions & 0 deletions

File tree

torch/_inductor/codegen/wrapper_fxir.py

```diff
@@ -32,6 +32,7 @@
     DivideByKey,
 )
 from torch.utils import _pytree as pytree
+from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import FloorDiv
 from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
 from torch.utils._sympy.reference import OptimizedPythonReferenceAnalysis
@@ -880,6 +881,11 @@ def _generate_fallback_call(
         )
         result_buffer = ir_node.codegen_reference()
         self.buffer_to_node[result_buffer] = fx_node
+        # For in-place mutation ops (e.g., scatter_reduce_, index_put_),
+        # update the buffer mapping for mutated inputs so downstream
+        # references to the mutated buffer see the post-mutation node.
+        for mutated_name in ir_node.get_mutation_names():
+            self.buffer_to_node[mutated_name] = fx_node

     def _generate_index_put_fallback(self, line: WrapperLine) -> None:
         assert isinstance(line, IndexPutFallbackLine)
@@ -914,6 +920,15 @@ def _generate_scatter_fallback(self, line: WrapperLine) -> None:
         kwargs = {}
         if reduce := ir_node.kwargs.get("reduce"):
             kwargs["reduce"] = reduce
+        # Only pass kwargs that the op's schema actually accepts, since
+        # ScatterFallback stores both reduce and include_self for all
+        # scatter variants, but not all ops support them (e.g.,
+        # scatter_.value has no kwargs, scatter_reduce_.two has both).
+        assert isinstance(ir_node.op_overload, torch._ops.OpOverload)
+        schema_arg_names = OrderedSet(
+            [a.name for a in ir_node.op_overload._schema.arguments]
+        )
+        kwargs = {k: v for k, v in ir_node.kwargs.items() if k in schema_arg_names}

         self._generate_fallback_call(ir_node, args, kwargs)
```
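The effect of the `buffer_to_node` update in `_generate_fallback_call` can be modeled with a plain dictionary. This is a minimal sketch with hypothetical buffer/node names, not the real FxConverter classes:

```python
# Hypothetical stand-in for FxConverter's buffer_to_node map; names
# like "buf0" and "scatter_node" are illustrative, not Inductor's API.
buffer_to_node = {"buf0": "node_that_allocated_buf0"}

def generate_fallback_call(result_buffer, mutated_names, fx_node):
    # The result buffer always maps to the new call node.
    buffer_to_node[result_buffer] = fx_node
    # The fix: remap every mutated input name to the new call node so
    # downstream references see the post-mutation value.
    for name in mutated_names:
        buffer_to_node[name] = fx_node

# An in-place op like scatter_reduce_ mutates buf0:
generate_fallback_call("buf1", mutated_names=["buf0"], fx_node="scatter_node")

# Without the remapping loop, this lookup would still return the
# pre-mutation node and the scatter result would be silently dropped.
assert buffer_to_node["buf0"] == "scatter_node"
```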

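The schema-based kwarg filtering in `_generate_scatter_fallback` follows the same pattern as this self-contained sketch; `Arg`, `Schema`, and `filter_kwargs_by_schema` are mock names standing in for `torch._ops.OpOverload._schema` and its argument list:

```python
from dataclasses import dataclass, field

# Mock schema objects modeling OpOverload._schema.arguments;
# these are illustrative, not the real torch._ops types.
@dataclass
class Arg:
    name: str

@dataclass
class Schema:
    arguments: list = field(default_factory=list)

def filter_kwargs_by_schema(kwargs, schema):
    # Keep only kwargs the overload's schema actually declares.
    allowed = {a.name for a in schema.arguments}
    return {k: v for k, v in kwargs.items() if k in allowed}

# ScatterFallback stores both kwargs for every scatter variant:
stored = {"reduce": "amax", "include_self": False}

scatter_reduce_two = Schema([Arg("self"), Arg("dim"), Arg("index"),
                             Arg("src"), Arg("reduce"), Arg("include_self")])
scatter_value = Schema([Arg("self"), Arg("dim"), Arg("index"), Arg("value")])

# scatter_reduce_.two accepts both; scatter_.value accepts neither.
assert filter_kwargs_by_schema(stored, scatter_reduce_two) == stored
assert filter_kwargs_by_schema(stored, scatter_value) == {}
```

Filtering against the schema, rather than hard-coding `reduce`, is what lets `include_self=False` reach ops that declare it while keeping it away from overloads that would reject it.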