
(WIP) [core][compiled graphs] Unify code paths for NCCL P2P and collectives scheduling#48649

Closed
AndyUB wants to merge 151 commits into ray-project:master from AndyUB:union-dev-1105

Conversation

AndyUB (Contributor) commented Nov 8, 2024

Why are these changes needed?

This PR unifies the code paths for NCCL P2P and collectives. Before, scheduling for NCCL operations was done by splitting each node into three operations: READ, COMPUTE, and WRITE. This PR simplifies the logic by keeping only the COMPUTE operation. To ensure scheduling still works, NCCL operations are converted into special types of system-created compute nodes.

This PR also allows overlapping NCCL collectives with computation.

NCCL P2P Refactoring

with InputNode() as inp:
  dag = actor1.foo.bind(inp)
  dag = dag.with_tensor_transport("nccl")
  dag = actor2.bar.bind(dag)

Before this PR, compiling this DAG results in a TorchTensorNcclChannel from foo to bar.

This PR adds a NcclSendNode after foo and a NcclRecvNode before bar. The TorchTensorNcclChannel now connects the two added nodes. Since foo and the send node are on the same actor, the channel from foo to the send node is an IntraProcessChannel. The same applies on the recv side.
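The node-splicing pass described above can be sketched with a toy model. All class and field names below are illustrative, not Ray's actual internals; the point is only the shape of the rewrite: for each edge annotated with "nccl" transport, a send node is spliced in on the producer's actor and a recv node on the consumer's actor.

```python
# Hypothetical sketch of the rewrite: splice nccl_send/nccl_recv nodes
# into every edge that requires NCCL tensor transport.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    name: str
    actor: str
    upstream: List["Node"] = field(default_factory=list)
    transport: Optional[str] = None  # e.g. "nccl" set on the producer

def insert_p2p_nodes(consumer: Node) -> None:
    """Rewrite consumer's upstream edges that require NCCL transport."""
    new_upstream = []
    for producer in consumer.upstream:
        if producer.transport == "nccl":
            # Send node lives on the producer's actor, so the
            # producer -> send edge is an intra-process channel.
            send = Node(f"nccl_send({producer.name})", producer.actor, [producer])
            # Recv node lives on the consumer's actor, so the
            # recv -> consumer edge is an intra-process channel.
            recv = Node(f"nccl_recv({producer.name})", consumer.actor, [send])
            new_upstream.append(recv)
        else:
            new_upstream.append(producer)
    consumer.upstream = new_upstream

foo = Node("foo", "actor1", transport="nccl")
bar = Node("bar", "actor2", upstream=[foo])
insert_p2p_nodes(bar)
print([n.name for n in bar.upstream])  # ['nccl_recv(foo)']
```

After the rewrite, the NCCL channel connects only the two spliced nodes, which is what lets the rest of the system treat foo and bar as ordinary compute nodes.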

Multiple Receivers
with InputNode() as inp:
  dag = actor1.foo.bind(inp)
  dag = dag.with_tensor_transport("nccl")
  dag = MultiOutputNode([actor2.bar.bind(dag), actor3.baz.bind(dag)])

In this case, the sender sends to two different receivers.
Only one NcclSendNode is created, and one NcclRecvNode is created per receiver. As before, there is only one TorchTensorNcclChannel.
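The single-send/multi-recv structure can be sketched the same way (again with illustrative names, not Ray's internals): the send node is created once per producer and shared, while each receiver gets its own recv node.

```python
# Illustrative sketch: one shared send node per NCCL producer, one recv
# node per consumer. Plain dicts stand in for the compiled-graph structures.

def splice_p2p(producers_by_consumer):
    """producers_by_consumer: {consumer: [producer, ...]} where producers
    needing NCCL transport are prefixed with 'nccl:'. Returns the rewritten
    edges plus the mapping of created (shared) send nodes."""
    sends = {}  # producer name -> shared send node
    edges = {}
    for consumer, producers in producers_by_consumer.items():
        new_inputs = []
        for p in producers:
            if p.startswith("nccl:"):
                name = p[len("nccl:"):]
                # Create the send node only once per producer.
                sends.setdefault(name, f"send({name})")
                # One recv node per receiving consumer.
                new_inputs.append(f"recv({name})@{consumer}")
            else:
                new_inputs.append(p)
        edges[consumer] = new_inputs
    return edges, sends

edges, sends = splice_p2p({"bar": ["nccl:foo"], "baz": ["nccl:foo"]})
print(sends)  # {'foo': 'send(foo)'}
print(edges)  # {'bar': ['recv(foo)@bar'], 'baz': ['recv(foo)@baz']}
```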

Multiple Senders
with InputNode() as inp:
  branch1 = actor1.foo.bind(inp)
  branch1 = branch1.with_tensor_transport("nccl")
  branch2 = actor2.bar.bind(inp)
  branch2 = branch2.with_tensor_transport("nccl")
  dag = actor3.baz.bind(branch1, branch2)

The receiver receives from two senders.
One NcclSendNode is created per sender, and one NcclRecvNode is created per argument of the receiver. There are two different TorchTensorNcclChannels.

Overlap NCCL Collectives

This is done by prioritizing NCCL operations over non-NCCL operations during scheduling: if both NCCL and non-NCCL operations are ready to be added to the actors' execution schedules, the NCCL operations are always added first.
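A minimal sketch of that priority rule (illustrative, not Ray's actual scheduler): among all ready operations, NCCL ones are popped first, so communication kernels are issued as early as possible and can overlap with subsequent computation.

```python
# Toy priority queue: NCCL operations sort before non-NCCL operations,
# ties broken by execution index.
import heapq

def schedule(ops):
    """ops: list of (exec_index, is_nccl) tuples that are all ready.
    Returns the indices in scheduling order: NCCL first, then by index."""
    heap = [((0 if is_nccl else 1), idx) for idx, is_nccl in ops]
    heapq.heapify(heap)
    order = []
    while heap:
        _, idx = heapq.heappop(heap)
        order.append(idx)
    return order

ready = [(0, False), (1, True), (2, False), (3, True)]
print(schedule(ready))  # [1, 3, 0, 2]
```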

Checks

  • I've signed off every commit (using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Weixin Deng and others added 10 commits October 27, 2024 10:36
Signed-off-by: Weixin Deng <weixin@cs.washington.edu>
Signed-off-by: Yuhan Ruan <andyubryh@gmail.com>

dengwxn commented Nov 8, 2024

Looks great. Some more TODOs before an initial review as we discussed offline:

  1. Address all the [CL] and [TODO] markers in the code. They are mainly missing comments, unused code blocks, branches to be merged, and variables and functions to be renamed.
  2. Introduce a special op node for NCCL_Collective, similar to the current NCCL_READ and NCCL_WRITE, so that the COMPUTE node does not require NCCL.

cc @dengwxn


dengwxn commented Nov 8, 2024

@anyscalesam Could you help add a go badge to run more CI tests? Thanks!

@AndyUB AndyUB marked this pull request as ready for review November 8, 2024 18:49
This reverts commit 941cb73.

dengwxn commented Nov 9, 2024

> Introduce a special op node for NCCL_Collective similar to the current NCCL_READ and NCCL_WRITE, such that the COMPUTE node does not require NCCL.

After your attempt and a second thought, I think introducing another NCCL_Collective op might not be the best way to separate NCCL and non-NCCL ops. We can skip this and see what others think.


dengwxn commented Nov 9, 2024

As we discussed offline, we should remove all the NCCL_* op nodes; instead, we should create system-level DAG nodes that do NCCL read/write. We will refactor based on this.

dengwxn left a comment

First pass. Structure seems right. Will look into details later.

stephanie-wang (Contributor) left a comment

I think this can be made simpler. Try to think about how you can achieve the following:

  • _NCCLSendNode/_NCCLRecvNode should have the same interface as _CollectiveOperation.
  • If the above is done properly, I believe we can get rid of most of the parts that need to differentiate between send/recv/collective. I.e., there should be only one requires_nccl flag instead of three, and there should be only one kind of DAG op node, a COMPUTE node.

@rkooo567 rkooo567 self-assigned this Nov 12, 2024
@stephanie-wang stephanie-wang self-assigned this Nov 12, 2024
@jcotant1 jcotant1 added the core Issues that should be addressed in Ray Core label Nov 15, 2024
AndyUB added 5 commits April 27, 2025 20:32
…dule_gpu
jjyao (Contributor) commented Apr 29, 2025

@stephanie-wang @AndyUB do you want to continue working on this PR?

stephanie-wang (Contributor) commented:

> @stephanie-wang @AndyUB do you want to continue working on this PR?

Yes, we're still working on this.

stephanie-wang (Contributor) left a comment

Sorry for the delay, I think this is looking close to mergeable.

I'm a bit confused about a few things, though:

  • There are several different collective/p2p operation/node types added. Can you explain how each one is used, i.e., how they reference each other, and do we need all of them?
  • Is there any change in scheduling behavior compared to before?
  • Are there any unit tests that we can add? I.e., tests that don't need to create a full DAG and test the e2e execution.


def __init__(
    self,
    method_args: Tuple[_P2PSendNode],
A Contributor commented:

Why not use the same structure as CollectiveOutputNode, where we create one actual _P2PNode and the send and recv nodes depend on the _P2PNode, via other_args_to_resolve?

Comment on lines +629 to +637
# Convert the abstract P2P operation from scheduling to the executable P2P
# send/recv operation.
if self.requires_nccl_read:
    assert self.nccl_ch is not None
    self.nccl_op = _P2PRecvOperation(self.nccl_ch)
elif self.requires_nccl_write:
    assert self.nccl_ch is not None
    self.nccl_ch.ensure_registered_as_writer()
    self.nccl_op = _P2PSendOperation(self.nccl_ch)
A Contributor commented:

Why do we only need to do this conversion from abstract to executable operation for P2P operations and not for collective operations?

Comment on lines +737 to +739
if input_exc is not None and self.requires_nccl_write:
    input_values = [input_exc]
    input_exc = None
A Contributor commented:

This code can be squashed into the following block.

method_args=(node,),
other_args_to_resolve={
    PARENT_CLASS_NODE_KEY: send_actor_handle,
    P2P_OPERATION_KEY: _P2POperation(),
A Contributor commented:

Where does this get used?

    (3, _DAGNodeOperationType.COMPUTE),
    (3, _DAGNodeOperationType.WRITE),
]
w1_expected_schedule = [0, 1, 2, 5, 3, 4, 7, 6, 8]
A Contributor commented:

Please add a comment explaining what the expected schedule is.

Also, I assume there was no behavior change in this test?


@pytest.mark.skipif(not USE_GPU, reason="Skipping GPU Test")
@pytest.mark.parametrize("overlap_gpu_communication", [False, True])
def test_torch_tensor_nccl_overlap_collective(
A Contributor commented:

Please add comments explaining what each test does.


@pytest.mark.skipif(not USE_GPU, reason="Skipping GPU Test")
@pytest.mark.parametrize("overlap_gpu_communication", [False, True])
def test_torch_tensor_nccl_overlap_send_future_across_actors(
A Contributor commented:

This test seems a bit complicated / unrelated compared to the stated goal? Is there a simpler test that can be run? Or a unit test?


@pytest.mark.skipif(not USE_GPU, reason="Skipping GPU Test")
@pytest.mark.parametrize("overlap_gpu_communication", [False, True])
def test_torch_tensor_nccl_overlap_same_future_multiple_waits(
A Contributor commented:

This test seems a bit complicated / unrelated compared to the stated goal? Is there a simpler test that can be run? Or a unit test?

stephanie-wang pushed a commit that referenced this pull request May 20, 2025
…tions (#53007)

Given an input DAG of SPMD training strategies such as DDP, after DAG
compilation, the first actor generates a different execution schedule
than the others. This is due to the current scheduling policy: when there
are multiple ready operation nodes such as `actor1.compute` (non-NCCL)
and `actor4.collective` (NCCL; for actor1-4, only one collective
operation node eventually becomes ready), the policy does not know that
actor1 has both the non-NCCL `actor1.compute` and the NCCL
`actor4.collective`. This leads to actor1 scheduling `actor1.compute`
first, and actor1-4 scheduling the `collective` next.

We update the policy to push all the collective operation nodes into
the candidates when the last of them becomes ready. In the previous
example, actor1 will have both `actor1.compute` and `actor1.collective`
as candidates. In a DAG of SPMD strategies, all the actors pop either
the `compute` or the `collective` together.

We also update the policy to simply prioritize NCCL operation nodes
over non-NCCL ones, so that NCCL operations are scheduled as soon as
possible. This is safe under the current settings of CUDA streams in
the system, because each NCCL read/write/collective stream only allows
one outstanding NCCL kernel at a time.

We add a test `test_collective_dag.py::test_exec_schedules_ddp` to
verify that the generated schedules are identical across workers for the
DDP strategy. Other tests are updated to reflect the change of
prioritizing NCCL operation nodes over non-NCCL ones.
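The admission rule above can be sketched as follows (names and structures are illustrative, not Ray's internals): a collective's nodes are pushed into every participant's candidate set only once the last member of the group is ready, so all actors reach the same decision point together.

```python
# Toy sketch of "push all collective nodes into candidates when the last
# of them is ready". Plain dicts/sets stand in for the scheduler state.

def admit_collective(group, ready, candidates_by_actor, actor_of):
    """Push every node of a collective `group` into its actor's candidate
    list, but only once ALL members of the group are ready.

    group: set of node ids forming one collective operation.
    ready: set of node ids whose dependencies are satisfied.
    candidates_by_actor: dict mapping actor name -> list of candidate ids.
    actor_of: dict mapping node id -> actor name."""
    if group <= ready:  # last member just became ready
        for node_id in sorted(group):
            candidates_by_actor.setdefault(actor_of[node_id], []).append(node_id)

actor_of = {0: "a1", 1: "a2", 2: "a3"}
cands = {}
admit_collective({0, 1, 2}, {0, 1}, cands, actor_of)
print(cands)  # {} -- not all members ready yet, nothing is admitted
admit_collective({0, 1, 2}, {0, 1, 2}, cands, actor_of)
print(cands)  # {'a1': [0], 'a2': [1], 'a3': [2]}
```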

## Related issue number


This PR is part of #48649 planning to be merged incrementally.

---------

Signed-off-by: Weixin Deng <weixin@cs.washington.edu>
@stephanie-wang stephanie-wang changed the title (WIP) [core][compiled graphs] Unify code paths for NCCL P2P and collectives scheduling [core][compiled graphs] Unify code paths for NCCL P2P and collectives scheduling May 28, 2025
@stephanie-wang stephanie-wang changed the title [core][compiled graphs] Unify code paths for NCCL P2P and collectives scheduling (WIP) [core][compiled graphs] Unify code paths for NCCL P2P and collectives scheduling May 28, 2025
stephanie-wang pushed a commit that referenced this pull request May 29, 2025
…3111)

This PR unifies the scheduling implementation for the NCCL P2P and
collective operation nodes. The logic remains the same: (1) P2P case:
when a NCCL send node is selected, its downstream NCCL recv nodes are
also selected; (2) collective case: when a NCCL collective node is
selected, its corresponding NCCL collective nodes are also selected.
Previously, the NCCL P2P case was implemented by selecting the recv
nodes when a send node is detected, and the NCCL collective case was
implemented by maintaining a set of pending collective nodes.

We unify the implementation for both cases. Concretely, both maintain
a set of (pending) synchronous nodes, named `sync_idxs` and
`pending_sync_idxs`. The synchronous nodes denote the P2P send/recv
nodes or the collective nodes. The NCCL P2P/collective operation is
ready when `sync_idxs == pending_sync_idxs`.

Test cases are updated to reflect the use of synchronous nodes for both
NCCL P2P and collective nodes.
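A hedged sketch of the unified readiness check (the names `sync_idxs` and `pending_sync_idxs` come from the PR description; the surrounding class is illustrative): a group of synchronous nodes, whether a P2P send/recv pair or all participants of a collective, becomes ready only when every member has been marked pending.

```python
# Toy model of the unified sync-node readiness check.

class SyncGroup:
    def __init__(self, sync_idxs):
        self.sync_idxs = set(sync_idxs)   # all synchronous nodes in the group
        self.pending_sync_idxs = set()    # nodes marked pending so far

    def mark_pending(self, idx):
        assert idx in self.sync_idxs
        self.pending_sync_idxs.add(idx)

    @property
    def ready(self):
        # The NCCL P2P/collective operation is ready exactly when the
        # pending set equals the full set of synchronous nodes.
        return self.sync_idxs == self.pending_sync_idxs

group = SyncGroup([10, 11, 12])  # e.g. one allreduce across three actors
group.mark_pending(10)
group.mark_pending(11)
print(group.ready)  # False -- node 12 has not been seen yet
group.mark_pending(12)
print(group.ready)  # True -- all members pending, operation is ready
```

The same check covers the P2P case with a two-element group (one send node, one recv node), which is what removes the need for separate send/recv and collective bookkeeping.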

This PR is a follow-up to #53007. Both are parts of #48649, which is
planned to be merged incrementally.
---------

Signed-off-by: Weixin Deng <weixin@cs.washington.edu>
github-actions bot commented

This pull request has been automatically marked as stale because it has not had
any activity for 14 days. It will be closed in another 14 days if no further activity occurs.
Thank you for your contributions.

You can always ask for help on our discussion forum or Ray's public slack channel.

If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@github-actions github-actions bot added the stale The issue is stale. It will be closed within 7 days unless there are further conversation label Jun 12, 2025
github-actions bot commented

This pull request has been automatically closed because there has been no more activity in the 14 days
since being marked stale.

Please feel free to reopen or open a new pull request if you'd still like this to be addressed.

Again, you can always ask for help on our discussion forum or Ray's public slack channel.

Thanks again for your contribution!

@github-actions github-actions bot closed this Jun 26, 2025

Labels

  • community-contribution: Contributed by the community
  • core: Issues that should be addressed in Ray Core
  • go: add ONLY when ready to merge, run all tests
  • stale: The issue is stale. It will be closed within 7 days unless there is further conversation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants