
[c10d] Add hccl distributed backend to c10d data structures #146478

Closed
ankurneog wants to merge 2 commits into pytorch:main from ankurneog:c10d_add_hccl_to_backends

Conversation


@ankurneog ankurneog commented Feb 5, 2025

MOTIVATION

Intel Gaudi is an out-of-tree PyTorch accelerator with its own device/dispatch key, hpu.
With this change we add entries for Gaudi's distributed backend, hccl, to the c10d Backend data structures.
This ensures there is no naming conflict if a new in-tree accelerator is ever introduced with the same backend name.

Out-of-tree backends are registered by calling

def register_backend(

Successful registration adds the backend name to the list:

backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI]

We bind the process-group creator constructs at run time, so other distributed backends with the same backend name can safely add their device type to the dictionary

backend_capability: dict[str, list[str]] = {

and add another entry to the dictionary with the same backend name (but a different device name):

default_device_backend_map: dict[str, str] = {

In addition, out-of-tree devices can use backend_list to check for successful backend registration, e.g. in APIs like is_hccl_available.
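To make the flow above concrete, here is a minimal, self-contained sketch of these data structures and the registration step. It is illustrative plain Python, not the actual torch/distributed/distributed_c10d.py code (the real register_backend takes more arguments), and dummycl is a made-up backend name:

```python
# Simplified sketch of the c10d registration flow described above.
# The names mirror torch/distributed/distributed_c10d.py in spirit only.

backend_list = ["undefined", "gloo", "nccl", "xccl", "ucc", "mpi", "hccl"]

# backend name -> device types it can serve
backend_capability = {
    "gloo": ["cpu", "cuda"],
    "nccl": ["cuda"],
    "hccl": ["hpu"],  # Gaudi's entry: same pattern, different device
}

# device type -> default backend for that device
default_device_backend_map = {
    "cpu": "gloo",
    "cuda": "nccl",
    "hpu": "hccl",
}

def register_backend(name, devices):
    """Run-time registration of an out-of-tree backend (illustrative)."""
    name = name.lower()
    if name in backend_list:
        # This is where the removed assert would have fired; warning and
        # returning keeps a repeated registration from aborting import.
        print(f"warning: backend {name} already registered")
        return
    backend_list.append(name)
    backend_capability[name] = list(devices)

def is_backend_available(name):
    """Check whether a backend name was successfully registered."""
    return name.lower() in backend_list

register_backend("dummycl", ["dummy_device"])  # hypothetical backend
print(is_backend_available("hccl"))     # True
print(is_backend_available("dummycl"))  # True
```

A second register_backend("dummycl", ...) call would warn rather than fail, which is the behavioural change this PR ended up landing.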

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o


pytorch-bot bot commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146478

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b0e395c with merge base 4106aa3 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels Feb 5, 2025
@ankurneog ankurneog marked this pull request as draft February 5, 2025 12:10
@ankurneog ankurneog changed the title [c10d][intel gaudi] Add hccl distributed backend to c10d data structures [c10d] Add hccl distributed backend to c10d data structures Feb 5, 2025
@ankurneog ankurneog changed the title [c10d] Add hccl distributed backend to c10d data structures [WIP][c10d] Add hccl distributed backend to c10d data structures Feb 5, 2025
@AnantGulati
Contributor

LGTM

@ankurneog ankurneog changed the title [WIP][c10d] Add hccl distributed backend to c10d data structures [c10d] Add hccl distributed backend to c10d data structures Feb 6, 2025
@ankurneog ankurneog marked this pull request as ready for review February 6, 2025 09:58
@ankurneog
Author

@kwen2501: Can you please help with the review? Thanks.

@mikaylagawarecki mikaylagawarecki added the triaged label Feb 7, 2025
@ankurneog
Author

@H-Huang, @kwen2501: Can you please help with the review? Thanks.

Member

@H-Huang H-Huang left a comment

Thanks for the PR. My main concern is that since Gaudi is an out-of-tree accelerator, we are adding device-specific logic that is not included or tested within PyTorch. Ideally the user experience should be:

import Gaudi
and all the relevant APIs for PTD should still work.

When importing Gaudi, it should perform backend registration for the user, like enforcing that hccl is available and registering the backend. Therefore we shouldn't need code to live in the PyTorch repo if it can all be done out-of-tree. Does that make sense?

Member

Instead of adding is_hccl_available(), couldn't we just use the method defined below, is_backend_available("hccl")?

Author

Instead of adding is_hccl_available(), couldn't we just use the method defined below, is_backend_available("hccl")?

Thanks for your comment. This was added based on customer feedback: customers demand parity with other accelerators, in particular Nvidia's, so we needed to provide an is_hccl_available() API to maintain CUDA parity.

Contributor

I understand the desire to have parity with CUDA. However, we have a goal of reducing coupling by designing more generic APIs for vendor plugins. I'd also rather not add more of this type of function. I think it's reasonable for customers to use is_backend_available, given they already have to make a code change to swap CUDA for hccl.
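The one-line swap this comment asks for can be illustrated with a stand-in module. Real code would call torch.distributed.is_backend_available, which does exist; the SimpleNamespace here is only so the snippet runs without PyTorch installed:

```python
import types

# Stand-in for torch.distributed (illustrative; real code would use
# `import torch.distributed as dist`).
dist = types.SimpleNamespace(backend_list=["gloo", "nccl", "hccl"])
dist.is_backend_available = lambda b: b.lower() in dist.backend_list

# Backend-specific style the PR originally proposed:
#     if dist.is_hccl_available(): ...
# Generic style the reviewers prefer, one string away from the CUDA code:
if dist.is_backend_available("hccl"):
    print("hccl is registered")
```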

@ankurneog
Author

Thanks for the PR. My main concern is that since Gaudi is an out-of-tree accelerator, we are adding device-specific logic that is not included or tested within PyTorch. Ideally the user experience should be:

import Gaudi and all the relevant APIs for PTD should still work.

When importing Gaudi, it should perform backend registration for the user, like enforcing that hccl is available and registering the backend. Therefore we shouldn't need code to live in the PyTorch repo if it can all be done out-of-tree. Does that make sense?

@H-Huang: Thank you for your comments and suggestions. Actually, that is exactly how the backend registration happens now: as soon as the Gaudi library is imported, the registration happens automatically.

However, we see a risk of potential conflicts if an in-tree device is added with the same backend name. If you go over my changes, I have removed an assert that checks whether the backend name is already present in the lists; such cases might break our registration. With these entries we ensure that developers are aware that a backend named hccl already exists. Even though Gaudi is an out-of-tree device, it has its own name, hpu, registered for dispatch.

Regarding the API, customers demand parity with other accelerators, in particular Nvidia's, so we need to provide an is_hccl_available() API to maintain CUDA parity.

Hope I was able to convey the motivation behind the change. Let me know your views.

Contributor

@H-Huang do you happen to know if we support having multiple vendor backends listed here pointing to Custom, or is it going to cause a conflict next time someone adds one?

Member

On second look, it doesn't look like we really use backend_type_map anymore. I think it was originally used so we could map the backend to a certain device to be used for barrier and object collectives but it looks like _get_object_coll_device handles that for us. So we should remove this data structure.

Member

issue to track: #147044

@H-Huang
Member

H-Huang commented Feb 13, 2025

Hi @ankurneog, thanks for the reply. I'm still not convinced most of the changes are needed.

we see a risk of potential conflicts if an in-tree device is added with the same backend name. If you go over my changes

The PyTorch policy is to be as device-agnostic as possible in our logic, so the decision to add an in-tree device that happens to have the same name is unlikely. Even so, if you import your library, it would still override the backend: if the default backend is "cpu:gloo,cuda:nccl,hpu:<non_hccl_backend>", you could overwrite this to use hccl when you import your library.
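The override described here can be sketched as follows. parse_backend_spec is a hypothetical helper (the real entry point is the device:backend string accepted by init_process_group), and some_backend stands in for the placeholder backend name:

```python
def parse_backend_spec(spec: str) -> dict:
    """Illustrative parser for a 'device:backend,device:backend' string
    of the form init_process_group accepts as its backend argument."""
    mapping = {}
    for pair in spec.split(","):
        device, _, backend = pair.partition(":")
        mapping[device] = backend
    return mapping

# Overriding the hpu entry at import time, as described above:
default = parse_backend_spec("cpu:gloo,cuda:nccl,hpu:some_backend")
default["hpu"] = "hccl"
print(default)  # {'cpu': 'gloo', 'cuda': 'nccl', 'hpu': 'hccl'}
```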

Regarding the API, customers demand parity with other accelerators, in particular Nvidia's, so we need to provide an is_hccl_available() API to maintain CUDA parity.

You could provide this as part of the Gaudi library, or even monkey-patch it onto torch.distributed when you import Gaudi (though I wouldn't fully endorse that approach). Nonetheless, I don't think it's a big ask for the user to change from is_nccl_available to is_backend_available('hccl').
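The monkey-patching route mentioned here could look like the sketch below. It uses a stand-in namespace instead of the real torch.distributed so it runs anywhere; an actual Gaudi plugin would patch the real module at import time:

```python
import types

# Stand-in for torch.distributed (illustrative only).
torch_distributed = types.SimpleNamespace(
    is_backend_available=lambda b: b.lower() in ("gloo", "nccl", "hccl"),
)

# What `import gaudi` could do to give customers CUDA-style parity
# without adding is_hccl_available to PyTorch itself:
torch_distributed.is_hccl_available = (
    lambda: torch_distributed.is_backend_available("hccl")
)

print(torch_distributed.is_hccl_available())  # True
```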

I have removed an assert that checks whether the backend name is already present in the lists. Such cases might break our registration.

I'm okay with this; we could log a warning instead.
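Assuming a hypothetical helper name, the warning-instead-of-assert idea could look like:

```python
import warnings

backend_list = ["gloo", "nccl"]

def _add_backend_name(name: str) -> None:
    # Old behaviour (roughly): assert name not in backend_list
    # New behaviour: warn and continue, so a repeated registration by an
    # out-of-tree device does not abort the library import.
    if name in backend_list:
        warnings.warn(f"Backend {name} is already registered")
        return
    backend_list.append(name)

_add_backend_name("hccl")
_add_backend_name("hccl")  # second call warns instead of raising
```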

@ankurneog
Author

Hi @ankurneog, thanks for the reply. I'm still not convinced most of the changes are needed.

we see a risk of potential conflicts if an in-tree device is added with the same backend name. If you go over my changes

The PyTorch policy is to be as device-agnostic as possible in our logic, so the decision to add an in-tree device that happens to have the same name is unlikely. Even so, if you import your library, it would still override the backend: if the default backend is "cpu:gloo,cuda:nccl,hpu:<non_hccl_backend>", you could overwrite this to use hccl when you import your library.

Regarding the API, customers demand parity with other accelerators, in particular Nvidia's, so we need to provide an is_hccl_available() API to maintain CUDA parity.

You could provide this as part of the Gaudi library, or even monkey-patch it onto torch.distributed when you import Gaudi (though I wouldn't fully endorse that approach). Nonetheless, I don't think it's a big ask for the user to change from is_nccl_available to is_backend_available('hccl').

I have removed an assert that checks whether the backend name is already present in the lists. Such cases might break our registration.

I'm okay with this; we could log a warning instead.

@H-Huang: I fully understand the concern here, but in the absence of true device abstraction in the Python frontend, how would we ensure that we catch new additions to this list, or the addition of an assert such as the one I removed? Such entries would prevent our library from being loaded; that's the risk we are trying to mitigate here. Maybe we can clean these up once we have a solution that removes all these entries altogether? Let me know your view.

@ankurneog
Author

@H-Huang: Can you please provide your views on this? Thanks.

@ankurneog
Author

@H-Huang: Gentle reminder, can you share your views so that we can reach closure here? Thank you.

@H-Huang
Member

H-Huang commented Feb 24, 2025

Hi @ankurneog, I still don't think we need backend-specific additions; the extension points are enough. If we were to add it, it would need to be tested as part of the PyTorch CI to prevent breakage.

You can still update the PR to only remove the assert; we can land that. We will try to remove the backend-specific logic as well (e.g. #147635).

@ankurneog ankurneog force-pushed the c10d_add_hccl_to_backends branch from 9fc146e to 11f4b16 Compare February 27, 2025 04:30
@ankurneog
Author

Hi @ankurneog, I still don't think we need backend-specific additions; the extension points are enough. If we were to add it, it would need to be tested as part of the PyTorch CI to prevent breakage.

You can still update the PR to only remove the assert; we can land that. We will try to remove the backend-specific logic as well (e.g. #147635).

@H-Huang: Sure, sounds good. I have updated the PR to reflect the change. My only request is to put a mechanism in place that prevents further additions of backends, or changes that may break out-of-tree devices. Thanks for taking the time to review the changes.

H-Huang previously approved these changes Feb 27, 2025
Member

@H-Huang H-Huang left a comment

Thanks for the changes!

@H-Huang
Member

H-Huang commented Feb 27, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 27, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@malfet
Contributor

malfet commented Mar 2, 2025

@pytorchbot revert -m "This seems to break ROCM tests, see https://hud.pytorch.org/pytorch/pytorch/commit/dae3fbfe9720e83e7e81d41430fb5067221bbed7" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@ankurneog your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Mar 2, 2025
@pytorch-bot pytorch-bot bot dismissed stale reviews from H-Huang and guangyey March 2, 2025 21:22

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@ankurneog
Author

@H-Huang: Can you please help land this PR? The failures look unrelated.

@ankurneog ankurneog requested review from H-Huang and guangyey March 3, 2025 02:24
@H-Huang H-Huang added the ciflow/rocm Trigger "default" config CI on ROCm label Mar 3, 2025
Member

@H-Huang H-Huang left a comment

@H-Huang
Member

H-Huang commented Mar 3, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit to min-jean-cho/pytorch that referenced this pull request Mar 5, 2025
…146478)

# MOTIVATION

Intel Gaudi is an out-of-tree PyTorch accelerator with its own device/dispatch key, ```hpu```. With this change we add entries for Gaudi's distributed backend, ```hccl```, to the c10d Backend data structures. This ensures there is no naming conflict if a new in-tree accelerator is ever introduced with the same backend name.

Out-of-tree backends are registered by calling https://github.com/pytorch/pytorch/blob/fd0cd6a08f706b7bb1dedb296217b6441e4fb9ff/torch/distributed/distributed_c10d.py#L302

Successful registration adds the backend name to the list:
https://github.com/pytorch/pytorch/blob/fd0cd6a08f706b7bb1dedb296217b6441e4fb9ff/torch/distributed/distributed_c10d.py#L265

We bind the process-group creator constructs at run time, so other distributed backends with the same backend name can safely add their device type to the dictionary

https://github.com/pytorch/pytorch/blob/fd0cd6a08f706b7bb1dedb296217b6441e4fb9ff/torch/distributed/distributed_c10d.py#L274

and add another entry to the dictionary with the same backend name (but a different device name):
https://github.com/pytorch/pytorch/blob/fd0cd6a08f706b7bb1dedb296217b6441e4fb9ff/torch/distributed/distributed_c10d.py#L268

In addition, out-of-tree devices can use ```backend_list``` to check for successful backend registration, e.g. in APIs like ```is_hccl_available```.

Pull Request resolved: pytorch#146478
Approved by: https://github.com/H-Huang

Labels

ci-no-td, ciflow/rocm, ciflow/trunk, Merged, oncall: distributed, open source, release notes: distributed (c10d), Reverted, triaged
