🐛 Bug
With torch-xla 2.1, the output of all-reduce does not retain the requires_grad field of its input. In pt-2.0 and earlier, the all-reduce output took the requires_grad value from the input tensor here: https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/csrc/init_python_bindings.cpp#L983 — with pt-2.1 this piece of code was removed.
To Reproduce
Steps to reproduce the behavior:
- Install torch-xla==2.1
- Run the following piece of code:
```python
import torch
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.utils import is_pjrt_device

if is_pjrt_device():
    import torch_xla.experimental.pjrt_backend
    torch.distributed.init_process_group("xla", init_method="pjrt://")
else:
    torch.distributed.init_process_group("xla")

rank = torch.distributed.get_rank()
if rank == 0:
    tensor = torch.ones(
        (1, 1),
        requires_grad=True,
        device=xm.xla_device(),
        dtype=torch.float32,
    )
    _ = xm.all_reduce(xm.REDUCE_SUM, tensor)
else:
    tensor_recv_next = torch.zeros(
        (1, 1),
        requires_grad=True,
        device=xm.xla_device(),
        dtype=torch.float32,
    )
    print(f"rank 1 tensor_recv_next before recv requires_grad {tensor_recv_next.requires_grad}")
    tensor_recv_next = xm.all_reduce(xm.REDUCE_SUM, tensor_recv_next)
    print(f"rank 1 tensor_recv_next after recv requires_grad {tensor_recv_next.requires_grad}")

xm.mark_step()
```
- You should see the following output:
```
rank 1 tensor_recv_next before recv requires_grad True
rank 1 tensor_recv_next after recv requires_grad False
```
Expected behavior
If you run the above code with 1.13, you should see the following output:
```
rank 1 tensor_recv_next before recv requires_grad True
rank 1 tensor_recv_next after recv requires_grad True
```
Environment
- Reproducible on XLA backend [CPU/TPU]: Neuron
- torch_xla version: 2.1
Additional context
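As a possible interim workaround (a sketch, not a fix for the underlying torch-xla behavior), the all-reduce call can be wrapped so that the output re-inherits requires_grad from the input. The wrapper name `all_reduce_keep_grad` and the stand-in `fake_all_reduce` below are hypothetical; the stand-in just returns a detached copy to mimic the reported 2.1 behavior on a single CPU host, so the pattern can be shown without a Neuron/XLA device.

```python
import torch


def all_reduce_keep_grad(all_reduce_fn, tensor):
    """Hypothetical wrapper: call an all-reduce-like function and restore
    requires_grad on the output if it was dropped (the bug reported above)."""
    out = all_reduce_fn(tensor)
    if tensor.requires_grad and not out.requires_grad:
        out.requires_grad_(True)
    return out


# Stand-in for xm.all_reduce on a single host: returns a detached copy,
# mimicking the reported behavior where requires_grad is lost.
def fake_all_reduce(t):
    return t.detach().clone()


x = torch.ones((1, 1), requires_grad=True)
y = all_reduce_keep_grad(fake_all_reduce, x)
print(y.requires_grad)
```

Note that `requires_grad_(True)` only marks the output as a new leaf; it does not reconnect the autograd graph through the collective, so gradients will not flow back to the original input through this wrapper.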