
Running backward on CPU before importing torch_xla causes a later backward on XLA to crash #4174

@shunting314

Description


🐛 Bug

I'm integrating torchdynamo with torch_xla for training. One weird crash I encountered is:

RuntimeError: 0 <= device.index() && device.index() < static_cast<c10::DeviceIndex>(device_ready_queues_.size()) INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/engine.cpp":1342, please report a bug to PyTorch.

when doing backward on an XLA model.

It turns out that this only happens if dynamo runs a backward pass on CPU before importing torch_xla.

This is definitely not a dynamo issue, but I'm not sure whether it belongs to pytorch or torch_xla. Assigning it to torch_xla for now.

cc @JackCaoG @wconstab @jansel

To Reproduce

import torch
from torch import nn
import os

os.environ["GPU_NUM_DEVICES"] = "1"

class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 3)

    def forward(self, inp):
        return self.linear(inp)

    def get_example_inputs(self):
        return (torch.rand(10, 2),)

model = LinearModel()
inputs = model.get_example_inputs()
# import torch_xla  # uncomment this line to work around the crash
model(*inputs).sum().backward()

import torch_xla.core.xla_model as xm
xla_device = xm.xla_device()
model = model.to(device=xla_device)
inputs = map(lambda x: x.to(device=xla_device), inputs)
model(*inputs).sum().backward()
print("bye")

The above script crashes with the following error:

Traceback (most recent call last):
  File "myscripts/xla_model_backward_crash.py", line 27, in <module>
    model(*inputs).sum().backward()
  File "/pytorch/torch/_tensor.py", line 451, in backward
    self, gradient, retain_graph, create_graph, inputs=inputs
  File "/pytorch/torch/autograd/__init__.py", line 199, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
RuntimeError: 0 <= device.index() && device.index() < static_cast<c10::DeviceIndex>(device_ready_queues_.size()) INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/engine.cpp":1342, please report a bug to PyTorch.

Uncommenting this line in the repro works around the issue:

# import torch_xla  # uncomment this line to work around the crash
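
For reference, a minimal sketch of the reordered script (same steps as the repro, using nn.Linear directly for brevity); with torch_xla imported before the CPU backward, both backward calls should complete, per the workaround above:

import os
os.environ["GPU_NUM_DEVICES"] = "1"

import torch
from torch import nn
import torch_xla                        # imported before any backward runs
import torch_xla.core.xla_model as xm

model = nn.Linear(2, 3)
inputs = (torch.rand(10, 2),)

# CPU backward, now after the torch_xla import
model(*inputs).sum().backward()

# XLA backward should no longer hit the device-index assert
xla_device = xm.xla_device()
model = model.to(device=xla_device)
xla_inputs = tuple(x.to(device=xla_device) for x in inputs)
model(*xla_inputs).sum().backward()
print("bye")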

Environment

  • Reproducible on XLA backend: GPU
  • torch_xla version: revision bde1bc6


Labels

dynamo, nostale (Do not consider for staleness)
