[SPMD][AMP] Invalid sharding for instruction when using AMP in SPMD. #5248

@baoleai

🐛 Bug

When training with AMP under SPMD, the following error is raised:

Non-OK-status: status.status() status: INVALID_ARGUMENT: during context [Unknown]: Invalid sharding for instruction: %add.388 = f32[] add(f32[] %constant.382, f32[] %convert.387), sharding={devices=[1,4]0,1,2,3}: Number of tile assignment dimensions (excluding subgroups) is different than the input rank. sharding={devices=[1,4]0,1,2,3}, input_shape=f32[]
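
For context, {devices=[1,4]0,1,2,3} describes a 1x4 tile assignment, so it can only be applied to a tensor whose rank matches the two tile dimensions, while the annotated instruction here is a rank-0 scalar (f32[]). The same rank constraint is visible from the Python API; the snippet below is a minimal sketch of that constraint (run under XLA_USE_SPMD=1), not the failing lowering path itself:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

mesh = Mesh(np.arange(4), (1, 4), ('x', 'y'))
scalar = torch.tensor(1.0, device=xm.xla_device())  # rank-0 tensor, shape f32[]
# A (1, 4) tile assignment has two dimensions but the tensor has rank 0,
# so there is no valid partition spec; this call should fail with a
# rank-mismatch error, analogous to the INVALID_ARGUMENT above.
xs.mark_sharding(scalar, mesh, (0, 1))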

To Reproduce

Here is a simple example that reproduces the bug:

test_spmd_amp.py

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch import nn
from torch_xla.amp import autocast, GradScaler, syncfree
from torch_xla.experimental.xla_sharding import Mesh


class SimpleLinear(nn.Module):

  def __init__(self):
    super(SimpleLinear, self).__init__()
    self.fc0 = nn.Linear(1024, 1024)
    self.fc1 = nn.Linear(1024, 1024)
    self.fc2 = nn.Linear(1024, 1024)
    self.fc3 = nn.Linear(1024, 1024)
    self.relu = nn.ReLU()

  def forward(self, x):
    h = self.relu(self.fc0(x))
    y = self.relu(self.fc1(h) + x)
    z = self.fc3(y)
    return z


def run():
  device = xm.xla_device()
  model = SimpleLinear().to(device)

  # Define a (4, 1) "row" mesh and a (1, 4) "column" mesh over the 4 devices.
  num_devices = 4
  device_ids = np.arange(num_devices)
  row_mesh = Mesh(device_ids, (num_devices, 1), ('x', 'y'))
  col_mesh = Mesh(device_ids, (1, num_devices), ('x', 'y'))

  # Shard fc1 column-wise and fc2 row-wise across the devices.
  xs.mark_sharding(model.fc1.weight, col_mesh, (0, 1))
  xs.mark_sharding(model.fc2.weight, row_mesh, (0, 1))

  optimizer = syncfree.AdamW(
      model.parameters(), lr=0.0, betas=(0.9, 0.999), eps=1e-8)
  scaler = GradScaler()

  for step in range(5):
    x = torch.zeros(64, 1024, device=device)
    with autocast():
      x = model(x)
      dummy_loss = x.sum()
      scaler.scale(dummy_loss).backward()
      scaler.step(optimizer)
      scaler.update()
      optimizer.zero_grad()
      xm.mark_step()
    print(step)


if __name__ == "__main__":
  run()

Steps to reproduce the behavior:

XLA_USE_SPMD=1 torchrun --nproc_per_node=4 examples/test_spmd_amp.py

Environment

  • Reproducible on XLA backend [CPU/TPU]: GPU
  • torch_xla version: master

The error is triggered by the code at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_lower_util.cpp#L972. The root cause is that a tiled sharding annotation is attached to a scalar value (the f32[] add in the error above), and a {devices=[1,4]} tile assignment cannot apply to a rank-0 shape. One possible solution is to avoid sharding scalar values; see #5249. cc @JackCaoG @yeounoh
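
At the Python level, the idea amounts to a guard like the hypothetical helper below (safe_mark_sharding is an illustrative name, not an existing API; the real fix belongs in the lowering code linked above):

def safe_mark_sharding(tensor, mesh, partition_spec):
  # Scalars (rank 0) have no dimension to tile; leave them replicated
  # instead of attaching a tile-assignment sharding.
  if tensor.dim() == 0:
    return tensor
  return xs.mark_sharding(tensor, mesh, partition_spec)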
