
[C10D] Avoid lazily creating P2P communicators #129147

Closed
wconstab wants to merge 10 commits into gh/wconstab/309/base from gh/wconstab/309/head

Conversation

wconstab (Contributor) commented Jun 20, 2024

Stack from ghstack (oldest at bottom):

Users who opt into eager initialization (enabled by passing device_id
to init_process_group) can now reuse the process group's existing
communicator for send/recv ops, rather than creating a new 2-rank
communicator for every pair of ranks performing send/recv.

Existing users who do not pass device_id to init_process_group will now
see a warning suggesting they do so, but they keep the behavior they
have today: automatic creation of pair-wise communicators.
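The reuse-with-fallback behavior described above can be sketched as follows. This is a minimal Python sketch, not the actual C10D implementation; `comm_cache`, `eager_init`, and `get_comm_for_send_recv` are hypothetical names chosen for illustration.

```python
import warnings

# Hypothetical cache: key -> communicator (a string stands in for a NCCL comm).
comm_cache = {}

def eager_init(dev_key):
    # Eager init (device_id passed to init_process_group) pre-creates the
    # process group's communicator up front, keyed by device.
    comm_cache[dev_key] = f"comm({dev_key})"

def get_comm_for_send_recv(dev_key, rank, peer):
    # Prefer the already-initialized PG-wide communicator, if present.
    if dev_key in comm_cache:
        return comm_cache[dev_key]
    # Otherwise fall back to today's behavior: warn, then lazily create a
    # dedicated 2-rank communicator keyed by the rank pair.
    warnings.warn("consider passing device_id to init_process_group")
    p2p_key = f"{min(rank, peer)}:{max(rank, peer)}"
    return comm_cache.setdefault(p2p_key, f"comm({p2p_key})")
```

With eager init, every send/recv pair resolves to the same per-device communicator; on the lazy path each pair gets its own communicator, plus the warning.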

Fixes #129140

Test plan

I didn't figure out a good way to unit test this change (specifically, to verify that we avoid creating extra communicators on the eager-init path).

In the meantime, I've locally verified that a script issuing a send/recv gets the WARNING about the fallback path, and that if I modify the script either to pass device_id=torch.device(f"cuda:{local_rank}") to init_process_group or to issue an allreduce before the send/recv, the warning does not appear in either case.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @penguinwu @tianyu-l @yf225 @chauhang

Differential Revision: D58842474

[ghstack-poisoned]
pytorch-bot (Bot) commented Jun 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129147

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures, 3 Unrelated Failures

As of commit 0958986 (merge base could not be retrieved; please contact dev infra):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Jun 20, 2024
@wconstab wconstab requested review from H-Huang, chipturner, eqy, fegin, kwen2501 and pavanbalaji and removed request for kwen2501 June 20, 2024 17:43
pavanbalaji (Contributor) left a comment:

LGTM!

@wconstab wconstab added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 20, 2024
wconstab (Contributor, Author):

@wconstab has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@wconstab wconstab requested a review from kwen2501 June 21, 2024 00:48
Comment thread on torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp (outdated)
@wconstab wconstab requested a review from shuqiangzhang June 21, 2024 23:36
wconstab (Contributor, Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:
Merge failed

Reason: Approvers from one of the following sets are needed:

  • Distributed (mrshenli, pritamdamania87, zhaojuanmao, rohan-varma, wanchaol, ...)
  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet)
Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

nvcastet commented Jun 24, 2024

> Is the goal for P2P ops to overlap with other communication ops in the same PG?

Correct.

For example the megatron-lm interleaved pipeline schedule will overlap send/receive ops targeting different peers using the same PG.

pavanbalaji (Contributor) commented:

> Is the goal for P2P ops to overlap with other communication ops in the same PG?
>
> Correct.
>
> For example the megatron-lm interleaved pipeline schedule will overlap send/receive ops targeting different peers using the same PG.

You can use the same NCCL communicator (PyTorch PG) but issue different P2P operations on different streams; those won't be serialized.

wconstab (Contributor, Author) commented Jun 25, 2024

I think the issue is that c10d manages the stream used for p2p ops, and it's bundled 1:1 with the NCCL communicator today.

I amended my RFC to account for this: #129140

@nvcastet do you think this amendment would solve your issue? I can try to make a PR to do this if so.

Edit: I updated this PR to attempt to decouple the NCCL comm from the NCCL stream. It might be fairly straightforward to do this, but I need to re-examine it with fresh eyes; I assume I may have missed something.

wconstab added a commit that referenced this pull request Jun 25, 2024
Users that opt-into eager initialization (enabled by passing device_id
to init_process_group) will now be able to take advantage of reusing
the existing communicator for the processgroup for send/recv ops rather
than creating new 2-rank communicators for every pair of ranks
performing send/recv.

Existing users not passing device_id to init_process_group will now get
a warning suggesting they do so, but they will still get the
functionality they have today, automatic creation of pair-wise
communicators.

When reusing an existing communicator, a dedicated nccl stream will
still be used for each pair of P2P ranks so that pair-wise comm ops can
overlap with each other rather than being serialized on a single stream
per PG.

Fixes #129140

ghstack-source-id: 3db38c6
Pull Request resolved: #129147
nvcastet commented:
@pavanbalaji @wconstab
Unfortunately, to overlap 2 NCCL comm ops, you need at least these 2 conditions:

  • Use different NCCL communicators
  • Place ops on different CUDA streams

A NCCL communicator will serialize the ops even if they are put on different streams, because they compete for the communicator's internal resources (internal staging buffers, etc.).
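The two overlap conditions above can be captured in a tiny illustrative model. This is a hypothetical Python sketch of the stated semantics, not NCCL's API; `NcclOp` and `can_overlap` are invented names.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class NcclOp:
    comm: str    # NCCL communicator the op is issued on
    stream: str  # CUDA stream the op is enqueued on

def can_overlap(a: NcclOp, b: NcclOp) -> bool:
    # Per the discussion: ops sharing a communicator are serialized by the
    # communicator's internal resources, and ops on the same stream are
    # serialized by CUDA stream ordering. Overlap needs both to differ.
    return a.comm != b.comm and a.stream != b.stream
```

In this model, putting two p2p ops on different streams but the same communicator (the approach this PR initially took) still does not yield overlap.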

nvcastet commented:
So to preserve overlap behavior, we would still need to create those p2p communicators in the PG.

The only other option I see (besides the obvious one of shelving this RFE/PR for now) to avoid those extra communicators is an explicit config setting on the process group that disables the creation of p2p communicators, documenting that unbatched p2p ops on this PG will be serialized under that setting.

As a side note, the NCCL team is actively working on reducing communicator init cost, so I would not be surprised to see improvements in upcoming releases.

Review context (torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp):

```cpp
    bool isSendRecvSelf,
    std::optional<const std::string> streamKey) {
    std::optional<const std::string> streamKey,
    bool onlyCached) {
```
Review comment (Member):

Thoughts on having a getOrCreateNCCLComm and then just a getNCCLComm? It's a bit unintuitive that this function does both, and splitting the behavior might be better than adding a bool to an already complicated function signature.
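The suggested split might look roughly like this. This is a hypothetical Python sketch of the shape of the API, not the actual C++ code; `_comms`, `get_nccl_comm`, and `get_or_create_nccl_comm` are illustrative names.

```python
_comms = {}  # hypothetical cache: key -> communicator placeholder

def get_nccl_comm(key):
    # Lookup-only: returns None rather than creating anything
    # (the behavior the `onlyCached` flag would select).
    return _comms.get(key)

def get_or_create_nccl_comm(key):
    # Lookup with lazy creation: the existing default behavior.
    if key not in _comms:
        _comms[key] = f"comm({key})"
    return _comms[key]
```

Splitting the two behaviors makes each call site state its intent in the function name instead of a boolean argument.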


```cpp
// Note on keys
// devKey identifies this gpu device and is used for accessing a nccl
// Communicator for this PG per device p2pKey identifies a pair of ranks doing
```
Review comment (Member):

Missing period/new line between device and p2pKey?

wconstab (Contributor, Author) replied:

Thanks; lintrunner totally hosed me here.
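For illustration, the two kinds of keys described in that comment could be built like this. This is a hypothetical Python sketch; the real code constructs these strings in C++, and the exact formats here are assumptions.

```python
def dev_key(device_index: int) -> str:
    # Identifies this GPU device; used to look up the per-device
    # communicator for the whole process group.
    return str(device_index)

def p2p_key(rank: int, peer: int) -> str:
    # Identifies a pair of ranks doing send/recv. Order-independent so
    # both sides of the exchange compute the same key.
    low, high = sorted((rank, peer))
    return f"{low}:{high}"
```

Because the pair key is normalized to low:high, sender and receiver agree on which cached communicator (and stream) to use.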

```cpp
auto ncclStream = ncclStreams_.at(p2pKey);
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
syncStream(device, ncclEvents_[p2pKey], ncclStream);
```
Review comment (Member):

In the old logic this is conditionally the p2pKey or the devKey -- is it intentional to always use the p2p key now?

wconstab (Contributor, Author) replied:

Yes, it is intentional to always use the p2pKey for the stream, based on the (wrong) assumption that using the same comm but a different stream would allow overlap between p2p ops involving different peers.

But I suspect I missed something; I probably should have kept this as devKey for batched-p2p ops and only used p2pKey for true p2p ops.
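The correction the author describes (devKey for batched p2p, p2pKey only for true p2p) could be sketched as follows. `stream_key` is a hypothetical helper, assuming the string key formats used elsewhere in this discussion.

```python
def stream_key(device_index: int, rank: int, peer: int, batched: bool) -> str:
    # Batched p2p ops run on the per-device PG stream, so the whole batch
    # is ordered as a group; a lone send/recv gets a per-pair stream so
    # different pairs can overlap (given they also use different comms).
    if batched:
        return str(device_index)
    low, high = sorted((rank, peer))
    return f"{low}:{high}"
```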


pavanbalaji commented Jul 22, 2024

> @pavanbalaji @wconstab Unfortunately, to overlap 2 NCCL comm ops, you need at least those 2 conditions:
>
> • Use different NCCL communicators
> • Place ops on different CUDA streams
>
> NCCL communicator will serialize the ops even if they are put on different streams (because they compete for the NCCL communicator internal resources: internal staging buffers etc...)

Hi @nvcastet - we should discuss this. It's not clear why NCCL needs to serialize point-to-point operations on the same communicator. I understand that collective operations need to be serialized, but p2p operations should be independent of each other. NCCL should be able to handle internal resources correctly in such cases. Is there a technical reason for this restriction or is it just an artifact of the current implementation? If it's an artifact of the current implementation, PyTorch shouldn't be working around that. We should fix it in NCCL.

nvcastet commented Jul 22, 2024

@pavanbalaji

> It's not clear why NCCL needs to serialize point-to-point operations on the same communicator.

A NCCL communicator will serialize ungrouped ops because they share internal resources (net buffers, etc.).
For the megatron-lm use case mentioned early on, we don't group p2p ops, in order to get finer-grained overlapping.

pavanbalaji (Contributor) commented:

> > It's not clear why NCCL needs to serialize point-to-point operations on the same communicator.
>
> NCCL communicator will serialize ungrouped ops because they share internal resources (net buffers etc...). For the megatron-lm use case mentioned early on we don't group p2p ops to get finer overlapping.

Hi @nvcastet - This seems overly restrictive and is different from what other communication libraries (such as MPI) provide. Creating a new communicator for every point-to-point pair we need to talk to is very expensive with respect to the number of resources used (and to performance in some cases).

nvcastet commented:
You only need to create a new communicator for pt-to-pt if you are going to overlap it with another NCCL op.
That is the current semantics of the NCCL library, which is what we need to look at for this PR.
I would encourage you to move this discussion to the NCCL repo by opening a discussion/RFE there so that the NCCL engineers can scope your proposal.

github-actions (Bot) commented:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions Bot added the Stale label Sep 24, 2024
@github-actions github-actions Bot closed this Oct 24, 2024
@kwen2501 kwen2501 added no-stale and removed Stale labels Oct 25, 2024
kwen2501 commented Oct 25, 2024

Hi @nvcastet, thanks for your comments. I'd like to follow up a bit.

  1. Does megatron use eager init (i.e. passing a device to the device_id of init_process_group) or lazy init?

For lazy init, we can keep the dedicated P2P comms; they will be required anyway, because we cannot assume the "whole" comm is ready at the time P2P is called. For eager init, since we know the whole comm is ready, we'd like to use it for P2P.

If megatron has been relying on lazy init (the traditional option), this change will not pose a perf regression for megatron. Does that make sense?

  2. Re relaxing the serialization in NCCL

I can understand why the serialization is needed, as you mentioned, intermediate buffers are not easy to schedule for sharing. Luckily, I think some recent NCCL advances may help to relax this serialization, in particular for P2P. Let's say zero-copy is enabled for P2P, be it network-based zero copy or GPU-GPU zero copy, these ops themselves will not need intermediate buffers, because data is directly fetched from user buffers. In this case, it would seem possible to allow multiple P2P ops to run on parallel streams? It seems to me that this may be even easier for network-based P2P because it may not even need to launch SMs in this case.

Cc: @wconstab @eqy @pavanbalaji

nvcastet commented Oct 30, 2024

> 1. Does megatron use eager init (i.e. passing a device to the device_id of init_process_group) or lazy init?

They used to have just lazy init but migrated to eager init to leverage the NCCL comm split feature.

> 2. Re relaxing the serialization in NCCL
>
> I can understand why the serialization is needed, as you mentioned, intermediate buffers are not easy to schedule for sharing. Luckily, I think some recent NCCL advances may help to relax this serialization, in particular for P2P. Let's say zero-copy is enabled for P2P, be it network-based zero copy or GPU-GPU zero copy, these ops themselves will not need intermediate buffers, because data is directly fetched from user buffers. In this case, it would seem possible to allow multiple P2P ops to run on parallel streams? It seems to me that this may be even easier for network-based P2P because it may not even need to launch SMs in this case.

Agreed: if NCCL removes the serialization it performs for p2p ops sharing the same communicator, that would solve the issue, and that would be great. For zero-copy to be beneficial while avoiding constant registrations, we would need pointer stability between iterations or the CUDA graph feature, right?


Labels

ciflow/trunk Trigger trunk jobs on your pull request no-stale oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category
