
torch::autograd::python::PythonEngine::thread_init crash at exit #50893

@zasdfgbnm

Description

🐛 Bug

import torch
torch.randn(5, requires_grad=True).sum().backward()
print("done")

crashes at exit with the following error on Python 3.9.1:

done
terminate called recursively
terminate called after throwing an instance of 'std::runtime_error'
Aborted (core dumped)
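
The "Catchpoint 1" in the trace below indicates gdb trapped the C++ throw before the abort; a session along these lines would reproduce that setup (the exact commands are an assumption, not part of the original report):

$ gdb --args python crash.py
(gdb) catch throw
Catchpoint 1 (throw)
(gdb) run
...
(gdb) bt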

The stack trace shown by gdb is:

Starting program: /usr/bin/python crash.py
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/libthread_db.so.1".
[New Thread 0x7fff326cb640 (LWP 370656)]
[New Thread 0x7fff31eca640 (LWP 370657)]
[New Thread 0x7fff316c9640 (LWP 370658)]
[Switching to Thread 0x7fff316c9640 (LWP 370658)]

Thread 4 "python" hit Catchpoint 1 (exception thrown), 0x00007fffd756b072 in __cxxabiv1::__cxa_throw (obj=0x7fff10005bf0, tinfo=0x7fffd7697280 <typeinfo for std::runtime_error>, dest=0x7fffd7580fe0 <std::runtime_error::~runtime_error()>)
    at /build/gcc/src/gcc/libstdc++-v3/libsupc++/eh_throw.cc:78
78      /build/gcc/src/gcc/libstdc++-v3/libsupc++/eh_throw.cc: No such file or directory.
(gdb) bt
#0  0x00007fffd756b072 in __cxxabiv1::__cxa_throw (obj=0x7fff10005bf0, tinfo=0x7fffd7697280 <typeinfo for std::runtime_error>, dest=0x7fffd7580fe0 <std::runtime_error::~runtime_error()>) at /build/gcc/src/gcc/libstdc++-v3/libsupc++/eh_throw.cc:78
#1  0x00007fffca3a6501 in pybind11::pybind11_fail (reason=0x7fffcb377c38 "scoped_acquire::dec_ref(): thread state must be current!") at ../cmake/../third_party/pybind11/include/pybind11/detail/common.h:725
#2  0x00007fffca3b0ee4 in pybind11::gil_scoped_acquire::dec_ref (this=0x7fff316c8cb0) at ../cmake/../third_party/pybind11/include/pybind11/pybind11.h:2125
#3  0x00007fffca3b0f8a in pybind11::gil_scoped_acquire::~gil_scoped_acquire (this=0x7fff316c8cb0, __in_chrg=<optimized out>) at ../cmake/../third_party/pybind11/include/pybind11/pybind11.h:2143
#4  0x00007fffca8dd201 in torch::autograd::python::PythonEngine::thread_init (this=0x7fffcbcfdfe0 <torch::autograd::python::PythonEngine::get_python_engine()::engine>, device=1, 
    ready_queue=std::shared_ptr<torch::autograd::ReadyQueue> (use count 2, weak count 0) = {...}, should_increment=true) at ../torch/csrc/autograd/python_engine.cpp:59
#5  0x00007fffbec1bb82 in std::__invoke_impl<void, void (torch::autograd::Engine::*)(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool), torch::autograd::Engine*, int, std::shared_ptr<torch::autograd::ReadyQueue>, bool> (
    __f=@0x555557e91e90: &virtual torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool), __t=@0x555557e91e88: 0x7fffcbcfdfe0 <torch::autograd::python::PythonEngine::get_python_engine()::engine>)
    at /usr/include/c++/10.2.0/bits/invoke.h:73
#6  0x00007fffbec1b9c6 in std::__invoke<void (torch::autograd::Engine::*)(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool), torch::autograd::Engine*, int, std::shared_ptr<torch::autograd::ReadyQueue>, bool> (
    __fn=@0x555557e91e90: &virtual torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)) at /usr/include/c++/10.2.0/bits/invoke.h:95
#7  0x00007fffbec1b82f in std::thread::_Invoker<std::tuple<void (torch::autograd::Engine::*)(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool), torch::autograd::Engine*, int, std::shared_ptr<torch::autograd::ReadyQueue>, bool> >::_M_invoke<0ul, 1ul, 2ul, 3ul, 4ul> (this=0x555557e91e68) at /usr/include/c++/10.2.0/thread:264
#8  0x00007fffbec1b5a2 in std::thread::_Invoker<std::tuple<void (torch::autograd::Engine::*)(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool), torch::autograd::Engine*, int, std::shared_ptr<torch::autograd::ReadyQueue>, bool> >::operator() (
    this=0x555557e91e68) at /usr/include/c++/10.2.0/thread:271
#9  0x00007fffbec1b46a in std::thread::_State_impl<std::thread::_Invoker<std::tuple<void (torch::autograd::Engine::*)(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool), torch::autograd::Engine*, int, std::shared_ptr<torch::autograd::ReadyQueue>, bool> > >::_M_run (this=0x555557e91e60) at /usr/include/c++/10.2.0/thread:215
#10 0x00007fffd7597c24 in std::execute_native_thread_routine (__p=0x555557e91e60) at /build/gcc/src/gcc/libstdc++-v3/src/c++11/thread.cc:80
#11 0x00007ffff79d93e9 in start_thread () from /usr/lib/libpthread.so.0
#12 0x00007ffff7af2293 in clone () from /usr/lib/libc.so.6
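
Frames #3 and #4 show the autograd worker thread's pybind11::gil_scoped_acquire destructor running after the Python interpreter has already begun shutting down on the main thread, so dec_ref() finds a stale thread state and throws. Below is a minimal standalone sketch of that pattern; it is a hypothetical program, not PyTorch's code, it assumes pybind11 with embed support, and whether it aborts identically is timing- and Python-version-dependent:

// Hypothetical sketch of the failure mode in frames #3-#4 above, NOT PyTorch's code.
// A detached worker thread holds a pybind11::gil_scoped_acquire for its whole
// lifetime; if the main thread finalizes the interpreter first, the worker's
// destructors run against a dead thread state, matching the abort above.
#include <pybind11/embed.h>
#include <chrono>
#include <thread>

namespace py = pybind11;

int main() {
    py::initialize_interpreter();
    PyThreadState* main_state = PyEval_SaveThread();  // drop the GIL so the worker can take it

    std::thread worker([] {
        py::gil_scoped_acquire gil;          // registers a thread state, as thread_init does
        {
            py::gil_scoped_release no_gil;   // do "work" with the GIL released
            std::this_thread::sleep_for(std::chrono::milliseconds(200));
        }   // ~gil_scoped_release re-acquires the GIL; ~gil_scoped_acquire then
            // calls dec_ref(), which throws if the interpreter is already gone
    });
    worker.detach();  // nobody joins this thread at exit

    // give the worker time to register its thread state (timing-dependent demo)
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    PyEval_RestoreThread(main_state);
    py::finalize_interpreter();  // mirrors the interpreter shutting down at script exit
    return 0;
}

In the real report the same ordering presumably happens implicitly: "done" is printed, Python begins finalizing at process exit, and the still-detached autograd worker threads unwind afterwards.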

Environment

Collecting environment information...
PyTorch version: 1.8.0a0+5e24d88
Is debug build: True
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 10.2.0
Clang version: 11.0.1
CMake version: version 3.18.4

Python version: 3.9 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: GeForce RTX 3090
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.0.5
/usr/lib/libcudnn_adv_infer.so.8.0.5
/usr/lib/libcudnn_adv_train.so.8.0.5
/usr/lib/libcudnn_cnn_infer.so.8.0.5
/usr/lib/libcudnn_cnn_train.so.8.0.5
/usr/lib/libcudnn_ops_infer.so.8.0.5
/usr/lib/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.8.0a0+unknown
[pip3] torchvision==0.8.2
[conda] Could not collect

Additional context

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer

Labels

module: autograd (Related to torch.autograd, and the autograd engine in general)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
