Commit 61e89db

Avoid nested CommTensor wrapping
[ghstack-poisoned]
1 parent 36d7914 commit 61e89db

4 files changed: 17 additions & 1 deletion

test/distributed/test_c10d_common.py

Lines changed: 7 additions & 0 deletions
@@ -1472,6 +1472,13 @@ def comm_fn(tensor, group=None):
 
         self._test_work_wait(tensor, comm_fn=comm_fn)
 
+    def _test_nested_comm_tensor_wrapping(self, tensor):
+        def comm_fn(tensor, group=None):
+            work = dist.all_reduce(CommTensor(tensor), group=group, async_op=True)
+            return work, tensor
+
+        self._test_work_wait(tensor, comm_fn=comm_fn)
+
 
 if __name__ == "__main__":
     assert (

test/distributed/test_c10d_gloo.py

Lines changed: 3 additions & 0 deletions
@@ -2396,6 +2396,9 @@ def test_scatter_work_wait_gpu(self):
             torch.ones(2, 2, device=self.rank) * self.rank
         )
 
+    def test_nested_comm_tensor_wrapping(self):
+        self._test_nested_comm_tensor_wrapping(torch.ones(2, 2) * self.rank)
+
 
 if __name__ == "__main__":
     assert (

test/distributed/test_c10d_nccl.py

Lines changed: 5 additions & 0 deletions
@@ -2853,6 +2853,11 @@ def test_scatter_work_wait_gpu(self):
             torch.ones(2, 2, device=self.rank) * self.rank
         )
 
+    @skip_if_lt_x_gpu(2)
+    def test_nested_comm_tensor_wrapping(self):
+        self._test_nested_comm_tensor_wrapping(
+            torch.ones(2, 2, device=self.rank) * self.rank
+        )
 
 if __name__ == "__main__":
     assert (

torch/distributed/_spmd/comm_tensor.py

Lines changed: 2 additions & 1 deletion
@@ -93,7 +93,8 @@ class CommTensor(torch.Tensor):
     def __new__(cls, tensor: torch.Tensor):
         r = torch.Tensor._make_subclass(
             cls,
-            tensor,
+            # avoid nested CommTensor wrapping
+            tensor._tensor if isinstance(tensor, CommTensor) else tensor,
             require_grad=tensor.requires_grad,
         )
         # The tensor object wrapped by this CommTensor

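For context, here is a minimal, self-contained sketch of the unwrapping behavior this hunk introduces. The stripped-down CommTensor below is illustrative only: it reproduces just the `_tensor` bookkeeping and the new unwrap-on-construction logic, while the real class in torch/distributed/_spmd/comm_tensor.py carries additional tracing machinery.

import torch

class CommTensor(torch.Tensor):
    # Illustrative stand-in for torch.distributed._spmd.comm_tensor.CommTensor;
    # only the construction/unwrapping logic touched by this commit is shown.
    def __new__(cls, tensor: torch.Tensor):
        # Avoid nested CommTensor wrapping: if the input is already a
        # CommTensor, subclass the plain tensor it wraps instead.
        inner = tensor._tensor if isinstance(tensor, CommTensor) else tensor
        r = torch.Tensor._make_subclass(cls, inner, require_grad=tensor.requires_grad)
        # The tensor object wrapped by this CommTensor
        r._tensor = inner
        return r

t = torch.ones(2, 2)
once = CommTensor(t)
twice = CommTensor(once)  # re-wrapping no longer nests
assert not isinstance(twice._tensor, CommTensor)
assert twice._tensor is t

With this change, wrapping an input that is already a CommTensor (as the new `_test_nested_comm_tensor_wrapping` tests do via `dist.all_reduce(CommTensor(tensor), ...)`) reuses the underlying plain tensor rather than producing a CommTensor of a CommTensor.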