
[Refactor]Rename NCCL-related items to comm_backend#51061

Merged
edoakes merged 7 commits intoray-project:masterfrom
noemotiovon:nccl-refactory
Jul 8, 2025

Conversation

@noemotiovon
Contributor

@noemotiovon noemotiovon commented Mar 4, 2025

Why are these changes needed?

Background

This PR is a follow-up to #51032, which introduced multi-device support in the Compiled Graph by leveraging CUDA's NCCL backend for efficient out-of-band tensor communication.

While the current implementation is tightly coupled with NCCL and CUDA, the Compiled Graph runtime is now ready to support a broader spectrum of device types and collective communication backends (e.g., HCCL, RCCL).

What This PR Does

To enable extensibility and backend-agnostic design, this PR introduces the following core changes:

Refactored NCCL-specific naming and APIs
NCCL-related modules, classes, and function names have been generalized to eliminate hardcoded CUDA/NCCL assumptions.

Introduced a pluggable communication backend interface
A unified abstraction layer is added to decouple collective communication logic from any specific implementation. This makes it easier to support alternative collective libraries and device types in the future.

This refactor does not alter the existing behavior of NCCL-based Compiled Graph execution. All current workflows using CUDA+NCCL continue to function as before.
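The pluggable communication backend interface described above can be sketched as a small abstract base class plus a concrete implementation that registers against it. This is an illustrative sketch only: the names `Communicator` and `FakeCommunicator` and the method signatures are assumptions for exposition, not Ray's actual API.

```python
from abc import ABC, abstractmethod
from typing import Any, Tuple


class Communicator(ABC):
    """Backend-agnostic tensor-communication interface (illustrative)."""

    @abstractmethod
    def send(self, tensor: Any, peer_rank: int) -> None:
        """Send a tensor to the peer rank."""

    @abstractmethod
    def recv(self, shape: Tuple[int, ...], dtype: str, peer_rank: int) -> Any:
        """Receive a tensor of the given shape/dtype from the peer rank."""


class FakeCommunicator(Communicator):
    """In-process stand-in showing how a concrete backend
    (NCCL, HCCL, RCCL, ...) would plug into the interface."""

    def __init__(self) -> None:
        self._mailbox = {}

    def send(self, tensor, peer_rank):
        # A real backend would issue an out-of-band device-to-device send.
        self._mailbox[peer_rank] = tensor

    def recv(self, shape, dtype, peer_rank):
        # A real backend would allocate a buffer and receive into it.
        return self._mailbox[peer_rank]


comm = FakeCommunicator()
comm.send([1.0, 2.0], peer_rank=1)
assert comm.recv((2,), "float32", peer_rank=1) == [1.0, 2.0]
```

The point of the abstraction is that Compiled Graph code only ever depends on the `Communicator` surface, so swapping NCCL for HCCL or RCCL becomes a registration detail rather than a code change.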

Related issue number

#51574

Checks

  • [x] I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • [x] I've run scripts/format.sh to lint the changes in this PR.
  • [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
    • [ ] I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • [ ] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • [x] Unit tests
    • [ ] Release tests
    • [ ] This PR is not tested :(

@noemotiovon noemotiovon marked this pull request as ready for review March 4, 2025 11:53
@jcotant1 jcotant1 added the core label Mar 4, 2025
@noemotiovon
Contributor Author

@ruisearch42,
Hi, it’s possible that in the future the communication backend will not be limited to NCCL, as shown in this PR with HCCL. Would it be possible to adjust the NCCL-related naming?

@noemotiovon noemotiovon marked this pull request as draft March 21, 2025 02:20
@noemotiovon
Contributor Author

This PR needs to follow #51574, so I'm marking it as a draft.

@stale

stale bot commented May 6, 2025

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.

  • If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@stale stale bot added the stale label May 6, 2025
@stale stale bot removed the stale label May 22, 2025
stephanie-wang pushed a commit that referenced this pull request May 22, 2025

This PR improves multi-device support in Compiled Graph, which
significantly reduces tensor transmission latency by utilizing
out-of-band communication. Currently, this feature only supports CUDA's
NCCL. Since Ray already supports multiple accelerators, Compiled Graph
needs to be extended to support multiple devices as well.

This PR mainly introduces two key changes:
1. Removed the dependency on cupy.cuda.ExternalStream: since this library
only supports CUDA devices, we replaced it with a more general stream
context manager to accommodate various accelerators. The new
implementation uses torch.{device}.StreamContext.
2. Replaced hardcoded torch.cuda.xxx calls with AcceleratorRuntime:
this allows automatic detection of the accelerator type and invokes the
appropriate device-specific functions.
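The AcceleratorRuntime dispatch described in point 2 can be sketched as a registry mapping a device-type string to a runtime object, so no call site hardcodes `torch.cuda.*`. All names below (`register_runtime`, `get_runtime`, `CudaRuntime`, `NpuRuntime`) are hypothetical, and the real torch calls are indicated only in comments.

```python
# Registry of runtime implementations, keyed by device-type string.
_RUNTIMES: dict = {}


def register_runtime(device_type: str, runtime) -> None:
    """Register a runtime implementation for a device type."""
    _RUNTIMES[device_type] = runtime


def get_runtime(device_type: str):
    """Look up the runtime for a device type; raises KeyError if absent."""
    return _RUNTIMES[device_type]


class CudaRuntime:
    def set_device(self, index: int) -> str:
        # A real implementation would call torch.cuda.set_device(index).
        return f"cuda:{index}"


class NpuRuntime:
    def set_device(self, index: int) -> str:
        # A real implementation would call torch.npu.set_device(index).
        return f"npu:{index}"


register_runtime("cuda", CudaRuntime())
register_runtime("npu", NpuRuntime())

assert get_runtime("npu").set_device(0) == "npu:0"
```

With this shape, adding a new accelerator means registering one runtime object rather than editing every device-specific call site.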

### How to add a new backend for Compiled Graph? Here's an example for Ascend NPU:
```python
import ray
import torch
import torch_npu
from ray.dag import InputNode
# import a custom Communicator implementation
from ray.experimental.channel.hccl_group import _HcclGroup
from ray.experimental.channel.accelerator_context import register_accelerator_context

@ray.remote
class TorchTensorWorker:
    def __init__(self):
        self.device = torch.device('npu:0')
        torch.npu.set_device(self.device)

    def send(self, shape, dtype, value: int):
        return torch.ones(shape, dtype=dtype, device=self.device) * value

    def recv(self, tensor):
        return (tensor[0].item(), tensor.shape, tensor.dtype)

# global register accelerator context
register_accelerator_context('npu', _HcclGroup)

actor_cls = TorchTensorWorker.options(num_cpus=0, resources={'NPU': 1})

sender = actor_cls.remote()
receiver = actor_cls.remote()

with InputNode() as inp:
    dag = sender.send.bind(inp.shape, inp.dtype, inp[0])
    dag = dag.with_tensor_transport(transport='nccl')
    dag = receiver.recv.bind(dag)

shape = (10,)
dtype = torch.float16

compiled_dag = dag.experimental_compile()
for i in range(3):
    ref = compiled_dag.execute(i, shape=shape, dtype=dtype)
    assert ray.get(ref) == (i, shape, dtype)

print("Success")
```

This PR is the main part of Task 2 in #51574 
It would be better to make the function names more general, such as changing
requires_nccl to require_communicator. This is implemented in #51061.

Signed-off-by: noemotiovon <757486878@qq.com>
Co-authored-by: noemotiovon <757486878@qq.com>
hipudding added a commit to hipudding/ray that referenced this pull request May 30, 2025
Signed-off-by: hipudding <huafengchun@gmail.com>
Co-authored-by: noemotiovon <757486878@qq.com>
@github-actions

github-actions bot commented Jun 6, 2025

This pull request has been automatically marked as stale because it has not had
any activity for 14 days. It will be closed in another 14 days if no further activity occurs.
Thank you for your contributions.

You can always ask for help on our discussion forum or Ray's public slack channel.

If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@github-actions github-actions bot added the stale label Jun 6, 2025
@noemotiovon noemotiovon force-pushed the nccl-refactory branch 2 times, most recently from d8b677a to f95998a, June 16, 2025 08:48
@noemotiovon
Contributor Author

This PR is a follow-up to #51032, which introduced multi-device support in the Compiled Graph by leveraging CUDA's NCCL backend for efficient out-of-band tensor communication.

While the current implementation is tightly coupled with NCCL and CUDA, the Compiled Graph runtime is now ready to support a broader spectrum of device types and collective communication backends (e.g., HCCL, RCCL).

@noemotiovon noemotiovon marked this pull request as ready for review June 23, 2025 02:51
@noemotiovon
Contributor Author

Hi @ruisearch42, @hipudding, this PR aims to generalize the communication backend interface and decouple NCCL-specific logic, as a follow-up to #51032. Happy to hear any feedback or suggestions! 😊

Contributor

@ruisearch42 ruisearch42 left a comment


Overall LGTM, initial reviews

@noemotiovon
Contributor Author

@ruisearch42, Thank you so much for the timely and careful review!
Apologies for the leftover debugging code — that sneaky print wasn’t meant to stick around. 😂
I’ve updated the code based on your comments. Looking forward to your further review!
Have a great day! ☀️

@noemotiovon
Contributor Author

Hi @ruisearch42
Thanks again for your earlier review! I’ve updated the PR according to the feedback. When you have time, could you please take another look? Appreciate it!

@ruisearch42
Contributor

Oh apologies! Somehow I lost track of this, will review it again tomorrow.

Contributor

@ruisearch42 ruisearch42 left a comment


Thanks for the PR. The raised issues are almost all nitpicks.

@noemotiovon
Contributor Author

Hi @ruisearch42
Thanks for the review, and apologies for the ambiguity caused by my oversight. I've addressed all the comments and corrected the issues accordingly. Looking forward to your continued review!

@ruisearch42 ruisearch42 added the go label Jul 2, 2025
@ruisearch42
Contributor

Triggered tests. We can merge after they all pass.
cc @stephanie-wang

@ruisearch42
Contributor

The multi-GPU test failed. Please take a look. @noemotiovon

@noemotiovon
Contributor Author

Hi @ruisearch42 ,
Very sorry! I missed updating the assertions in the test file, which caused the CI failure. I've just fixed it—could you please help me re-trigger the CI? Thanks a lot!
Also, I saw you were working quite late—please don’t overwork yourself and make sure to get some rest. Really appreciate your help! ❤️

This commit is a follow-up to ray-project#51032, which introduced multi-device support
in the Compiled Graph by leveraging CUDA's NCCL backend for efficient out-of-band
tensor communication. While the current implementation is NCCL-specific, the
Compiled Graph runtime is now ready to support a broader range of device types
and collective communication libraries.

To prepare for this generalization, this commit introduces the following changes:

1. Refactored NCCL-specific naming and interfaces
2. Established a pluggable communication backend interface

This refactor does not change the behavior for existing NCCL-based Compiled Graph
execution, but lays the foundation for enabling collective communication across
diverse hardware accelerators and runtime environments.

Signed-off-by: noemotiovon <757486878@qq.com>
@ruisearch42
Contributor

Thanks for fixing the issues. Triggered the GPU test again.

@ruisearch42
Contributor

@jjyao @edoakes could you help merge? thanks

@edoakes edoakes merged commit 0e93947 into ray-project:master Jul 8, 2025
5 checks passed
landscapepainter pushed a commit to landscapepainter/ray that referenced this pull request Jul 9, 2025
Signed-off-by: noemotiovon <757486878@qq.com>
Signed-off-by: doyoung <doyoung@anyscale.com>
ccmao1130 pushed a commit to ccmao1130/ray that referenced this pull request Jul 29, 2025
Signed-off-by: noemotiovon <757486878@qq.com>
Signed-off-by: ChanChan Mao <chanchanmao1130@gmail.com>
dstrodtman pushed a commit to dstrodtman/ray that referenced this pull request Oct 6, 2025
Signed-off-by: noemotiovon <757486878@qq.com>
Signed-off-by: Douglas Strodtman <douglas@anyscale.com>

Labels

  • community-contribution: Contributed by the community
  • core: Issues that should be addressed in Ray Core
  • go: add ONLY when ready to merge, run all tests
