[SPMD] AMP autocast failed when set XLA_USE_SPMD=1 #5497

@baoleai

🐛 Bug

AMP autocast fails when XLA_USE_SPMD=1 is set.

To Reproduce

Consider the following example:

import os
import torch

import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast


os.environ["XLA_USE_SPMD"] = "1"

device = xm.xla_device()
t1 = torch.ones([2,3], device=device, dtype=torch.float32)
t2 = torch.ones([3,2], device=device, dtype=torch.float32)

with autocast(device, dtype=torch.bfloat16):
  t3 = torch.matmul(t1, t2)

print(t3.dtype)
xm.mark_step()

The expected result is torch.bfloat16, but when os.environ["XLA_USE_SPMD"] = "1" is set, the actual output is torch.float32.
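For comparison, the same autocast pattern on plain CPU (no torch_xla involved) does cast matmul to bfloat16, which is the behavior expected from the XLA device above:

```python
import torch

# Plain CPU autocast, no torch_xla: matmul is on the CPU autocast bfloat16
# cast list, so the float32 inputs are downcast and the result is bfloat16.
a = torch.ones(2, 3, dtype=torch.float32)
b = torch.ones(3, 2, dtype=torch.float32)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = torch.matmul(a, b)

print(c.dtype)  # torch.bfloat16
```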

I found the cause: when XLA_USE_SPMD is set to 1, all devices on the C++ side become SPMD:0, and PyTorch autocast does not recognize the SPMD:0 device type, so the cast policy is never applied. cc @JackCaoG
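To make the failure mode concrete, here is a minimal pure-Python sketch of the dispatch logic described above. The set name `AUTOCAST_DEVICE_TYPES` and the helper `autocast_applies` are hypothetical illustrations, not PyTorch internals; the point is only that autocast keys off the device *type* prefix of the device string, and "SPMD" is not among the recognized types:

```python
# Hypothetical sketch (not the real PyTorch implementation): autocast keys
# off the device-type prefix of the device string. Under SPMD the XLA
# tensors report device "SPMD:0", whose type "spmd" is not autocast-aware,
# so ops fall through at their original dtype (float32 here).
AUTOCAST_DEVICE_TYPES = {"cuda", "cpu", "xla"}

def autocast_applies(device_str: str) -> bool:
    # "xla:0" -> "xla", "SPMD:0" -> "spmd"
    device_type = device_str.split(":")[0].lower()
    return device_type in AUTOCAST_DEVICE_TYPES

print(autocast_applies("xla:0"))   # True  -> matmul cast to bfloat16
print(autocast_applies("SPMD:0"))  # False -> matmul stays float32
```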

Environment

  • Reproducible on XLA backend [GPU/TPU]: GPU
  • torch_xla version: master
