requires_grad not set for the all-reduce output #6319

@aws-rhsoln

🐛 Bug

With torch-xla 2.1, the output of all-reduce does not retain the requires_grad field from its input. In the 2.0 release and earlier, the all-reduce output took its requires_grad value from the input tensor here: https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/csrc/init_python_bindings.cpp#L983. In 2.1, however, this piece of code has been removed.
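To make the expected flag propagation concrete, here is a minimal sketch of a user-side helper that copies requires_grad from the input onto the output. The name restore_requires_grad is hypothetical and is not part of torch or torch_xla. The example runs on CPU and uses detach().clone() as a stand-in for a collective output that lost the flag; it is not the removed C++ code, only an illustration of the behavior described above.

```python
import torch


def restore_requires_grad(inp: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: copy the requires_grad flag from inp onto out."""
    # The flag can only be toggled on leaf tensors. An output that lost the
    # flag has no grad_fn, so it is a leaf and requires_grad_ is allowed.
    if inp.requires_grad and not out.requires_grad:
        out.requires_grad_(True)
    return out


# CPU stand-in for the collective: detach() drops requires_grad, which
# mimics the all-reduce output reported in this issue.
x = torch.ones((1, 1), requires_grad=True, dtype=torch.float32)
y = x.detach().clone()
assert not y.requires_grad

y = restore_requires_grad(x, y)
print(f"requires_grad after restore: {y.requires_grad}")
```

On an XLA device the same helper could presumably wrap the call as restore_requires_grad(tensor, xm.all_reduce(xm.REDUCE_SUM, tensor)); that usage is an assumption and has not been verified here. Note that this only restores the flag on the output. It does not connect the output to the input in the autograd graph.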

To Reproduce

Steps to reproduce the behavior:

  1. Install torch-xla 2.1
  2. Run the following piece of code:
import torch
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.utils import is_pjrt_device


if is_pjrt_device():
    import torch_xla.experimental.pjrt_backend
    torch.distributed.init_process_group("xla", init_method="pjrt://")
else:
    torch.distributed.init_process_group("xla")

rank = torch.distributed.get_rank()

if rank == 0:
    tensor = torch.ones(
        (1,1),
        requires_grad=True,
        device=xm.xla_device(),
        dtype=torch.float32,
    )
    _ = xm.all_reduce(xm.REDUCE_SUM, tensor)
else:
    tensor_recv_next = torch.zeros(
        (1,1),
        requires_grad=True,
        device=xm.xla_device(),
        dtype=torch.float32,
    )
    print(f"rank 1 tensor_recv_next before recv requires_grad {tensor_recv_next.requires_grad}")
    tensor_recv_next = xm.all_reduce(xm.REDUCE_SUM, tensor_recv_next)
    print(f"rank 1 tensor_recv_next after recv requires_grad {tensor_recv_next.requires_grad}")
xm.mark_step()
  3. You should see the following output:
rank 1 tensor_recv_next before recv requires_grad True
rank 1 tensor_recv_next after recv requires_grad False

Expected behavior

If you run the above code with torch-xla 1.13, the output retains requires_grad and you should see the following:

rank 1 tensor_recv_next before recv requires_grad True
rank 1 tensor_recv_next after recv requires_grad True

Environment

  • Reproducible on XLA backend [CPU/TPU]: Neuron
  • torch_xla version: 2.1

Labels

bug (Something isn't working)
