🐛 Bug
AMP autocast fails when XLA_USE_SPMD=1 is set
To Reproduce
Consider the following example:
```python
import os

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

# Enable SPMD mode before the XLA device is created.
os.environ["XLA_USE_SPMD"] = "1"

device = xm.xla_device()
t1 = torch.ones([2, 3], device=device, dtype=torch.float32)
t2 = torch.ones([3, 2], device=device, dtype=torch.float32)

with autocast(device, dtype=torch.bfloat16):
    t3 = torch.matmul(t1, t2)
    # Expected: torch.bfloat16 under autocast; with XLA_USE_SPMD=1
    # this actually prints torch.float32.
    print(t3.dtype)
xm.mark_step()
```
The expected output is `torch.bfloat16`, but when `XLA_USE_SPMD=1` is set, the actual output is `torch.float32`.
I found that the root cause is that with `XLA_USE_SPMD=1`, all devices on the C++ side become the virtual device `SPMD:0`, and PyTorch autocast does not recognize the `SPMD:0` device type, so ops are left at their original precision. cc @JackCaoG
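As a quick way to confirm the device string the backend reports, here is a minimal diagnostic sketch. Note that `_xla_get_default_device` is a private `_XLAC` binding, so this may differ across torch_xla versions:

```python
import os

os.environ["XLA_USE_SPMD"] = "1"

import torch_xla

# With SPMD enabled, the backend collapses all devices into a single
# virtual device, so this prints "SPMD:0" instead of e.g. "GPU:0".
# _xla_get_default_device is a private binding, used here only as a
# diagnostic, not a stable API.
print(torch_xla._XLAC._xla_get_default_device())
```

Whatever device string is reported here is what autocast's device-type check sees, which is consistent with the behavior above.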
Environment
- Reproducible on XLA backend [GPU/TPU]: GPU
- torch_xla version: master