Skip to content

implement send and recv using collective_permute#9373

Merged
bfolie merged 7 commits intomasterfrom
bfolie/send-recv-collectives
Aug 25, 2025
Merged

implement send and recv using collective_permute#9373
bfolie merged 7 commits intomasterfrom
bfolie/send-recv-collectives

Conversation

@bfolie
Copy link
Copy Markdown
Collaborator

@bfolie bfolie commented Jun 17, 2025

Comment thread torch_xla/core/xla_model.py
Comment thread test/pjrt/test_collective_ops_tpu.py
Comment thread torch_xla/distributed/xla_backend.py
Comment thread torch_xla/distributed/xla_backend.py
Comment thread torch_xla/distributed/xla_backend.py
Comment thread test/pjrt/test_collective_ops_tpu.py
Comment thread test/pjrt/test_collective_ops_tpu.py Outdated
Comment thread torch_xla/distributed/xla_backend.py
Comment thread torch_xla/distributed/xla_backend.py Outdated
Comment thread torch_xla/distributed/xla_backend.py
@bfolie
Copy link
Copy Markdown
Collaborator Author

bfolie commented Jun 27, 2025

The approach implemented here works for a "pipeline" type operation but does not work for a "permutation" type operation. The way this is commonly done in native pytorch in order to avoid deadlocks is that half of the devices send and the other receive, then they switch roles. What this means is that the sending and receiving tensors must be different, and one half of the devices end up having a different IR than the other half, resulting in a deadlock. I'm still searching for a way around this.

@bfolie
Copy link
Copy Markdown
Collaborator Author

bfolie commented Jul 2, 2025

The only way I was able to make a "permutation" type op (every device sends and every device receives) work is by inserting a sync after each set of send/recv. This is not ideal. It's better than the status quo for TPU, which is that send/recv don't work at all. But since Neuron does have something working I'll defer to you @rpsilva-aws . We can put this on ice until the Send/Recv XLA ops can be called directly.

@rpsilva-aws
Copy link
Copy Markdown
Collaborator

rpsilva-aws commented Jul 2, 2025

Hm, that does complicate things... I have it working on TRN, though I deviated a bit with multi-operands to capture tokens. I'll end up creating a PR for this one, which would build upon the work you had in the prior commits. Actually, TRN has the same limitation for send/recv, requiring a graph break.

Do you think we can merge this PR without the sync since it's working for existing devices (e.g. TRN), and revisit as we figure out the underlying issues with TPU? If you want to defer until the new ops, or we re-raise the need as we bring in our work, both are ok with me.

@bfolie
Copy link
Copy Markdown
Collaborator Author

bfolie commented Jul 2, 2025

Do you think we can merge this PR without the sync since it's working for existing devices (e.g. TRN), and revisit as we figure out the underlying issues with TPU?

There are two tests in the PR, test_send_recv_pipeline and test_send_recv_permute. Without a sync, the former works on TPU (that's what I initially committed) but the latter does not. Does test_send_recv_permute work on TRN without a sync? I would expect it to have the same non-uniform-IR problems, so that would be a surprise and interesting to me if it did work.

I have it working on TRN, though I deviated a bit with multi-operands to capture tokens. I'll end up creating a PR for this one, which would build upon the work you had in the prior commits.

I'd be interested in seeing that

@kvshbg-aws
Copy link
Copy Markdown

kvshbg-aws commented Aug 22, 2025

@bfolie i tried your changes (on top of ToT) on neuron device and also wrote a small test specific to neuron device which uses 2 ranks and does pipeline_cp_ops and permute_cp_ops as they are in your test files, and i even added an additional test to verify if the computation on the receiving rank happens as expected after receiving the tensor. Writing the test functions over here if you want to take a look (the communication part is exactly same as your tests, just changed the way we do asserts on neuron device using the test file since we use torchrun for some neuron specific tests).

I also tried reproducing the cpu xla_op test failures from this PR, and was not able to reproduce those locally

    def test_pipeline_send_recv(self):
        """Test basic pipeline pattern: first half sends to second half"""
        cutoff = self.world_size // 2
        index = xr.global_ordinal()
        
        tensor = torch.tensor([index], dtype=torch.float, device=self.device)
        
        # Pipeline communication: first half sends to second half
        if index < cutoff:
            dist.send(tensor, index + cutoff)
        else:
            dist.recv(tensor, index - cutoff)
        
        # Verify communication correctness
        if index < cutoff:
            expected = torch.tensor([index], dtype=torch.float)
            self.assertTrue(torch.allclose(tensor.cpu(), expected))
        else:
            expected = torch.tensor([index - cutoff], dtype=torch.float)
            self.assertTrue(torch.allclose(tensor.cpu(), expected))

    def test_permute_communication(self):
        """Test permutation pattern: each device sends to next and receives from previous"""
        index = xr.global_ordinal()
        sending_tensor = torch.tensor([index], dtype=torch.float, device=self.device)
        receiving_tensor = torch.tensor([-1.0], dtype=torch.float, device=self.device)
        
        # Ring communication pattern with deadlock avoidance
        if index % 2 == 0:
            dist.send(sending_tensor, (index + 1) % self.world_size)
            dist.recv(receiving_tensor, (index - 1) % self.world_size)
        else:
            dist.recv(receiving_tensor, (index - 1) % self.world_size)
            dist.send(sending_tensor, (index + 1) % self.world_size)
        
        # Verify ring communication correctness
        expected = torch.tensor([(index - 1) % self.world_size], dtype=torch.float)
        self.assertTrue(torch.allclose(receiving_tensor.cpu(), expected))

    def test_pipeline_send_recv_with_computation(self):
        """Test pipeline pattern with computation on received data"""
        cutoff = self.world_size // 2
        index = xr.global_ordinal()
        
        tensor = torch.tensor([index], dtype=torch.float, device=self.device)
        
        # Pipeline with computation on receiver side
        if index < cutoff:
            dist.send(tensor, index + cutoff)
        else:
            dist.recv(tensor, index - cutoff)
            res = tensor + 3.0  # Compute on received data
        
        # Verify results
        if index < cutoff:
            expected = torch.tensor([index], dtype=torch.float)
            self.assertTrue(torch.allclose(tensor.cpu(), expected))
        else:
            expected = torch.tensor([3.0], dtype=torch.float)
            self.assertTrue(torch.allclose(res.cpu(), expected))

cc: @rpsilva-aws

@bfolie
Copy link
Copy Markdown
Collaborator Author

bfolie commented Aug 25, 2025

@bfolie i tried your changes (on top of ToT) on neuron device and also wrote a small test specific to neuron device which uses 2 ranks and does pipeline_cp_ops and permute_cp_ops as they are in your test files, and i even added an additional test to verify if the computation on the receiving rank happens as expected after receiving the tensor.

Good to know. I would expect that this PR as it currently is, with the syncs on both send and receive, would work for all devices. It's not the most efficient but if this works for you then we can go ahead and merge it.

@bfolie bfolie requested a review from pgmoka August 25, 2025 18:45
Copy link
Copy Markdown
Collaborator

@rpsilva-aws rpsilva-aws left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need tests under neuron/ too, but we can do that separately. LGTM.

@bfolie bfolie merged commit 5522c69 into master Aug 25, 2025
24 checks passed
Copy link
Copy Markdown
Collaborator

@pgmoka pgmoka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants