🐛 Bug
When training with AMP under SPMD, the following error is raised:
```
Non-OK-status: status.status() status: INVALID_ARGUMENT: during context [Unknown]: Invalid sharding for instruction: %add.388 = f32[] add(f32[] %constant.382, f32[] %convert.387), sharding={devices=[1,4]0,1,2,3}: Number of tile assignment dimensions (excluding subgroups) is different than the input rank. sharding={devices=[1,4]0,1,2,3}, input_shape=f32[]
```
To Reproduce
Here is a simple example that reproduces the bug (`test_spmd_amp.py`):
```python
import numpy as np
import torch
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch import nn
from torch_xla.amp import autocast, GradScaler, syncfree
from torch_xla.experimental.xla_sharding import Mesh


class SimpleLinear(nn.Module):

    def __init__(self):
        super(SimpleLinear, self).__init__()
        self.fc0 = nn.Linear(1024, 1024)
        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 1024)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.relu(self.fc0(x))
        y = self.relu(self.fc1(h) + x)
        z = self.fc3(y)
        return z


def run():
    device = xm.xla_device()
    model = SimpleLinear().to(device)

    num_devices = 4
    devices_ids = np.arange(num_devices)
    row_mesh = Mesh(devices_ids, (num_devices, 1), ('x', 'y'))
    col_mesh = Mesh(devices_ids, (1, num_devices), ('x', 'y'))
    xs.mark_sharding(model.fc1.weight, col_mesh, (0, 1))
    xs.mark_sharding(model.fc2.weight, row_mesh, (0, 1))

    optimizer = syncfree.AdamW(model.parameters(), lr=0.0, betas=(0.9, 0.999), eps=1e-8)
    scaler = GradScaler()

    for step in range(5):
        dummy_data = torch.zeros(64, 1024, device=device)
        x = dummy_data
        with autocast():
            x = model(x)
            dummy_loss = x.sum()
        scaler.scale(dummy_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        xm.mark_step()
        print(step)


if __name__ == "__main__":
    run()
```
Steps to reproduce the behavior:
```bash
XLA_USE_SPMD=1 torchrun --nproc_per_node=4 examples/test_spmd_amp.py
```
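For what it's worth, one way to see where the bad annotation ends up is to dump the pending HLO just before `xm.mark_step()` and grep for rank-0 (`f32[]`) instructions that carry a `sharding=` attribute. The snippet below is only a debugging sketch and relies on internals (`torch_xla._XLAC._get_xla_tensors_hlo` and the scaler's `_scale` attribute), so treat the exact names as assumptions; the instruction numbers will also differ from run to run.

```python
# Debugging sketch (assumes the reproducer above): place this inside run(),
# right before xm.mark_step(). It dumps the pending HLO for the loss and the
# scaler's scale tensor and prints scalar (f32[]) ops that carry a sharding
# annotation, like the add shown in the error message. Internal APIs only.
import torch_xla

hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([dummy_loss, scaler._scale])
for hlo_line in hlo_text.splitlines():
    if 'f32[]' in hlo_line and 'sharding=' in hlo_line:
        print(hlo_line)
```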
Environment
- Reproducible on XLA backend [CPU/TPU]: GPU
- torch_xla version: master
The error is caused by the code at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_lower_util.cpp#L972. The root cause is that a sharding annotation ends up being applied to a scalar (rank-0) value.
So one possible solution is to avoid sharding the scalar value. #5249 @JackCaoG @yeounoh
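For context (my own reading of the error message, not taken from the XLA source): a tile sharding is only valid when it has one tile-assignment dimension per tensor dimension, so a `{devices=[1,4]}` annotation can never be attached to a rank-0 `f32[]` value. A toy check illustrating the rule the partitioner enforces:

```python
# Toy illustration of the validity rule behind the INVALID_ARGUMENT above:
# the number of tile-assignment dimensions (excluding subgroups) must equal
# the rank of the tensor being sharded. This is not torch_xla/XLA code.
def tile_sharding_is_valid(tile_assignment_dims, tensor_rank):
    return len(tile_assignment_dims) == tensor_rank

print(tile_sharding_is_valid([1, 4], 2))  # True:  fits a rank-2 weight such as fc1.weight
print(tile_sharding_is_valid([1, 4], 0))  # False: an f32[] scalar has rank 0, so it is rejected
```

This is why skipping (or replicating) the sharding for scalar values, as suggested above, would sidestep the check.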