
[inductor] [dynamo] index_reduce_ raised AssertionError in assert_functional_graph #144846

@zhejiangxiaomai

Description

🐛 Describe the bug

index_reduce_ raises an AssertionError under torch.compile (inductor backend) when the input tensor is a view.
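A view is a tensor that shares storage with another tensor. In the reproducer below, ifm is produced by slicing dim 1 of a contiguous [4, 34, 64] tensor, so it keeps the parent's strides and starts at a nonzero storage offset, which makes it non-contiguous. The following sketch reproduces that stride arithmetic in plain Python. The helper names (contiguous_strides, slice_dim, is_contiguous) are hypothetical; they only model PyTorch's strided layout and do not call into torch.

```python
# Hypothetical pure-Python model of the stride arithmetic behind
# ifm = ifm_t[:, 2:, :] in the reproducer. It does not call torch.

def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    acc = 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))


def slice_dim(shape, strides, offset, dim, start):
    """Shape, strides and storage offset after slicing start: (step 1) along dim.

    Basic slicing keeps the parent's strides and only moves the storage offset.
    """
    new_shape = list(shape)
    new_shape[dim] -= start
    return tuple(new_shape), strides, offset + start * strides[dim]


def is_contiguous(shape, strides):
    """True if strides match row-major layout; size-1 dims are ignored."""
    expected = contiguous_strides(shape)
    return all(s == e for n, s, e in zip(shape, strides, expected) if n != 1)


base_shape = (4, 34, 64)
base_strides = contiguous_strides(base_shape)
view_shape, view_strides, view_offset = slice_dim(base_shape, base_strides, 0, dim=1, start=2)

print(base_strides)                             # (2176, 64, 1)
print(view_shape, view_strides, view_offset)    # (4, 32, 64) (2176, 64, 1) 128
print(is_contiguous(view_shape, view_strides))  # False
```

The first stride stays 2176 (34 × 64) instead of the 2048 (32 × 64) that a freshly allocated [4, 32, 64] tensor would have, which is why the view is not contiguous.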

Minimal reproducer:

import torch


class OpWrapperModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, ifm, op_inputs_dict):
        result = ifm.index_reduce_(**op_inputs_dict)
        return result


torch.manual_seed(8450)
ifm_t = torch.randn([4, 34, 64])
ifm = ifm_t[:, 2:, :]  # view of ifm_t (shares its storage), shape [4, 32, 64]
index_tensor = torch.randint(low=0, high=34, size=[64])
source_tensor = torch.randn([4, 32, 64])
params = {
    "index": index_tensor,
    "source": source_tensor,
    "dim": 2,
    "reduce": "mean",
    "include_self": False,
}
model = OpWrapperModule()
model_compiled = torch.compile(model, backend="inductor")
result = model_compiled(ifm, params)
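For reference, with reduce="mean" and include_self=False, index_reduce_ replaces each slot of self named in index (along dim) with the mean of the source slices mapped to it, ignoring the slot's original value; slots not named in index keep their values. The decomposed graph in the error log follows the same steps: zero the indexed slots (%index_put), count hits per slot (%index_put_1), clamp zero counts to 1 (%lt, %where), accumulate source (%index_put_2), and divide (%div). Below is a minimal 1-D pure-Python sketch of those steps. It assumes a single dimension, the helper name index_reduce_mean_1d is hypothetical, and it is not PyTorch's implementation.

```python
# Hypothetical 1-D reference for index_reduce with reduce="mean".
# Mirrors the decomposed graph: zero indexed slots, accumulate sums and
# counts, clamp counts to 1, divide. Not PyTorch's implementation.

def index_reduce_mean_1d(self_vals, index, source, include_self=True):
    out = [float(v) for v in self_vals]
    # include_self=True counts the original value once; False starts at zero.
    counts = [1 if include_self else 0] * len(out)
    if not include_self:
        for i in index:
            out[i] = 0.0  # drop the original value at every indexed slot
    for k, i in enumerate(index):
        out[i] += source[k]  # accumulate source[k] into slot index[k]
        counts[i] += 1
    # Slots never indexed keep count 0 when include_self=False; clamping the
    # count to 1 leaves their original value unchanged (like %where).
    return [v / max(c, 1) for v, c in zip(out, counts)]


print(index_reduce_mean_1d([10.0, 20.0, 30.0], [0, 0, 2], [1.0, 3.0, 5.0],
                           include_self=False))  # [2.0, 20.0, 5.0]
```

Slot 0 receives source values 1.0 and 3.0, so its mean is 2.0; slot 1 is never indexed and keeps 20.0.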

Error log and traceback:

Traceback (most recent call last):
  File "/home/zhenzhao/qnpu/sw_214852/src/rep.py", line 27, in <module>
    result = model_compiled(ifm, params)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1742, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1753, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 573, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1742, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1753, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
    result = self._inner_convert(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
    self._return(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1101, in compile_subgraph
    self.compile_and_call_fx_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1382, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1432, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1483, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2314, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 1863, in compile_fx
    return aot_autograd(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/backends/common.py", line 83, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py", line 1155, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 153, in aot_dispatch_base
    fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph(  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 184, in aot_dispatch_base_graph
    copy_count = assert_functional_graph(fw_module.graph)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/functional_utils.py", line 461, in assert_functional_graph
    n.args[0] in placeholders
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: n=copy_, n.args[0]=permute, placeholders={arg2_1, arg0_1, arg1_1}, graph=graph():
    %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=3] = placeholder[target=arg2_1]
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 32, 64], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%scalar_tensor, [4, 32, 64]), kwargs = {})
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%arg0_1, [None, None, %arg2_1], %expand), kwargs = {})
    %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([4, 32, 64],), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty, [0, 1, 2]), kwargs = {})
    %copy_ : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%permute, %index_put), kwargs = {})
    %full_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 32, 64], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    %index_put_1 : [num_users=2] = call_function[target=torch.ops.aten.index_put.default](args = (%full_1, [None, None, %arg2_1], %full, True), kwargs = {})
    %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%index_put_1, 1), kwargs = {})
    %scalar_tensor_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (1.0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt, %scalar_tensor_1, %index_put_1), kwargs = {})
    %index_put_2 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%copy_, [None, None, %arg2_1], %arg1_1, True), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%index_put_2, %where), kwargs = {})
    %copy__1 : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %div), kwargs = {})
    return (copy__1,)
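The assertion that fires (assert_functional_graph, functional_utils.py line 461) checks n.args[0] in placeholders: the destination of each aten.copy_ node must be a graph input. In the dump above, %copy__1 writes into %arg0_1, a placeholder, so that input mutation is accepted; %copy_ writes into %permute, an identity permute ([0, 1, 2]) of the freshly allocated %empty, which fails the check. The helper below is a hypothetical triage aid that scans a printed graph dump for that pattern. It parses text with a regular expression and only restates the condition shown in the traceback; it is not the real assert_functional_graph.

```python
import re

# Hypothetical helper: find aten.copy_ nodes whose destination (first arg)
# is not a placeholder, by parsing a printed FX graph dump.
NODE_RE = re.compile(r"%(\w+) : \[num_users=\d+\] = (\w+)\[target=([^\]]+)\](.*)")
DEST_RE = re.compile(r"args = \(%(\w+)")


def find_bad_copies(graph_dump):
    placeholders = set()
    bad = []
    for line in graph_dump.splitlines():
        m = NODE_RE.search(line)
        if not m:
            continue
        name, op, target, rest = m.groups()
        if op == "placeholder":
            placeholders.add(name)
        elif target == "torch.ops.aten.copy_.default":
            d = DEST_RE.search(rest)
            dest = d.group(1) if d else None
            if dest not in placeholders:
                bad.append((name, dest))
    return bad


# Excerpt of the graph from the error above.
dump = """
    %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=3] = placeholder[target=arg2_1]
    %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([4, 32, 64],), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty, [0, 1, 2]), kwargs = {})
    %copy_ : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%permute, %index_put), kwargs = {})
    %copy__1 : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %div), kwargs = {})
"""

print(find_bad_copies(dump))  # [('copy_', 'permute')]
```

Parsing the printed dump keeps the sketch dependency-free; with a live torch.fx.Graph one would instead iterate graph.nodes and compare node.op, node.target and node.args directly.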

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

PyTorch version: 2.6.0a0+git30ac7fd
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.5 (ssh://git@github.com/habana-internal/tpc_llvm10 150d2d7c6a8ff8abf0d8ce194d3fac3986b078e6)
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-127-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 2
Stepping: 0
BogoMIPS: 5187.81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 12 MiB (12 instances)
L3 cache: 38.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX flush not necessary, SMT disabled
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

cc @bdhirsh @ezyang @chauhang @penguinwu @zou3519 @yf225

Metadata

Labels

module: aotdispatch
module: functionalization
module: pt2-dispatcher
oncall: pt2
triaged

Development

No branches or pull requests
