torch.compile crash - Aborted exit code 134 #125804

@atalman

Description

🐛 Describe the bug

Minirepro:

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28, 1)

    def forward(self, x):
        output = self.fc1(x)
        return output


x = torch.rand(28, 28, device="cuda")
model = Net().to(device="cuda")
x_pt2 = torch.compile(model, mode="max-autotune")(x)
try:
    torch._assert_async(torch.tensor(0, device="cuda"))
except:
    print("ignoring exception")
    
# check for `Aborted (core dumped)` on process exit
(segfault) [16:07:08] ~/builder/test/smoke_test (main) > python minirepro.py
/home/xmfan/.conda/envs/segfault/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:133: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
../aten/src/ATen/native/cuda/TensorCompare.cu:106: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `input[0] != 0` failed.
Aborted (core dumped)
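For context (this note is not part of the original report): exit code 134 is the shell's encoding of death by SIGABRT, i.e. 128 + 6. A PyTorch-free sketch, using only the standard library, that reproduces just the exit-status arithmetic:

```python
import signal
import subprocess
import sys

# Spawn a child that calls abort(); std::terminate in the trace below
# ends the process the same way.
proc = subprocess.run([sys.executable, "-c", "import os; os.abort()"])

# subprocess reports death-by-signal as a negative return code;
# a shell reports the same death as 128 + signum.
assert proc.returncode == -signal.SIGABRT
print(128 + signal.SIGABRT)  # 134, matching "Aborted exit code 134" in the title
```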

stack trace: https://gist.github.com/xmfan/d2dcddda2f042df35832992753e3df34

#0  0x00007ffff7c8b94c in __pthread_kill_implementation () from /lib64/libc.so.6
#1  0x00007ffff7c3e646 in raise () from /lib64/libc.so.6
#2  0x00007ffff7c287f3 in abort () from /lib64/libc.so.6
#3  0x00007ffff66b135a in __cxxabiv1::__terminate (handler=<optimized out>) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_terminate.cc:48
#4  0x00007ffff66b03b9 in __cxa_call_terminate (ue_header=0x14225d90) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_call.cc:54
#5  0x00007ffff66b0ae7 in __cxxabiv1::__gxx_personality_v0 (version=<optimized out>, actions=6, exception_class=5138137972254386944, ue_header=0x14225d90, context=<optimized out>) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_personality.cc:685
#6  0x00007ffff74f51e4 in _Unwind_RaiseException_Phase2 (exc=0x14225d90, context=0x7fffffffbdb0, frames_p=0x7fffffffbcb8) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libgcc/unwind.inc:64
#7  0x00007ffff74f5c1e in _Unwind_Resume (exc=0x14225d90) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libgcc/unwind.inc:241
#8  0x00007fffa90015fb in at::CUDAGeneratorState::unregister_graph(at::cuda::CUDAGraph*) [clone .cold] () from /home/xmfan/.conda/envs/segfault/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
#9  0x00007fffa91e1e3c in at::cuda::CUDAGraph::~CUDAGraph() () from /home/xmfan/.conda/envs/segfault/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
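A note on why the `try`/`except` in the minirepro cannot prevent the crash (my reading of the trace above, not an official statement): the `abort()` happens in C++ during stack unwinding in `~CUDAGraph`, below the Python level, so no Python exception handler gets a chance to run. A small standard-library-only sketch of the same limitation:

```python
import signal
import subprocess
import sys

# abort() terminates the process at the C level, so a Python try/except
# wrapped around it never reaches its handler.
child_code = (
    "import os\n"
    "try:\n"
    "    os.abort()\n"
    "except BaseException:\n"
    "    print('never reached')\n"
)
proc = subprocess.run(
    [sys.executable, "-c", child_code], capture_output=True, text=True
)
assert proc.returncode == -signal.SIGABRT  # the child died of SIGABRT anyway
assert "never reached" not in proc.stdout
print("except clause never ran")
```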

Original description:

Repro:
Install torch with CUDA 11.8 or 12.1 for Python 3.8-3.11.
Set the following environment variables:

MATRIX_GPU_ARCH_VERSION=12.1
MATRIX_GPU_ARCH_TYPE=cuda

Run the following Python script: https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py

python smoke_test.py --package torchonly

Failure:

True
Testing smoke_test_compile with mode 'max-autotune' for torch.float32
torch cuda: 12.1
torch cudnn: 8902
cuDNN enabled? True
torch nccl version: (2, 20, 5)
Testing test_cuda_runtime_errors_captured
../aten/src/ATen/native/cuda/TensorCompare.cu:106: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `input[0] != 0` failed.
Caught CUDA exception with success: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Aborted (core dumped) 

Failure workflow:
https://github.com/pytorch/builder/actions/runs/9007921255/job/24748864568#step:11:4337

If I comment out this line, no failure is observed:
x_pt2 = torch.compile(model, mode="max-autotune")(x)
https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py#L265C5-L265C57

Started happening on:
2.4.0.dev20240327

This nightly commit:
384cbf2

Versions

nightly

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @mcarilli @eellison @peterbell10 @bdhirsh @anijain2305 @chauhang @jansel

Metadata

Labels

high priority; module: cuda graphs (ability to capture and then replay streams of CUDA kernels); oncall: pt2; triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
