Implement autograd functions for c10d communication operations #40762

Closed

emcastillo wants to merge 6 commits into pytorch:master from emcastillo:dist-backward

Conversation

@emcastillo
Collaborator

@emcastillo emcastillo commented Jun 30, 2020

Closes #40702, Fixes #40690

Currently WIP, but I would appreciate some feedback. The functions should be double-differentiable.

Contrary to https://github.com/pytorch/pytorch/blob/b35cdc5200af963e410c0a25400fd07f30b89bca/torch/nn/parallel/_functions.py, this PR generates a list of tensors instead of aggregating the received data into a single tensor. Is this behavior correct?

Thanks!
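
For context, a minimal sketch of how the proposed API might be used once landed (the module path torch.distributed.nn.functional matches this PR, but the exact signature and the group setup here are assumptions):

```python
import torch
import torch.distributed as dist
import torch.distributed.nn.functional as distF  # module added by this PR

# Assumes a process group was already initialized, e.g.:
# dist.init_process_group("gloo", rank=rank, world_size=world_size)

x = torch.randn(4, requires_grad=True)

# Unlike torch/nn/parallel/_functions.py, the result is a list of tensors,
# one per rank, rather than a single aggregated tensor.
gathered = distF.all_gather(x)

loss = torch.stack(list(gathered)).sum()
loss.backward()  # gradients flow back through the collective into x.grad
```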

@dr-ci

dr-ci Bot commented Jun 30, 2020

💊 CI failures summary and remediations

As of commit 2c01091 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

This comment was automatically generated by Dr. CI.

@emcastillo
Collaborator Author

I have stopped working on this for a few days due to some other tasks, will resume ASAP :)

@mrshenli
Contributor

mrshenli commented Sep 1, 2020

Hey @emcastillo, do you plan to continue working on this, or is this task up for grabs? No rush, just want to quickly check with you about the plan.

@emcastillo
Collaborator Author

Hi @mrshenli ,
I am planning to finish this by next week :)

@emcastillo emcastillo changed the title [WIP] Implement autograd functions for c10d communication operations Implement autograd functions for c10d communication operations Sep 11, 2020
@emcastillo
Collaborator Author

Hi @mrshenli, can you please take a look? Thanks!

@emcastillo
Collaborator Author

I think I will support these calls in this PR and add the multi-GPU calls later.
I am a bit concerned about the correctness of the backward passes, so a double check would be great here.

Thanks!
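
For reference, the general shape of these autograd-aware collectives is roughly as follows: a simplified sketch of a broadcast Function with a reduce-to-source backward, not the exact code in this PR (class layout and argument order are illustrative, and the PR aims for double-differentiability, which would require the backward itself to call autograd-aware collectives; this sketch uses the raw ops for brevity):

```python
import torch
import torch.distributed as dist

class _Broadcast(torch.autograd.Function):
    """Sketch: broadcast in forward, reduce grads back to src in backward."""

    @staticmethod
    def forward(ctx, src, group, tensor):
        ctx.src = src
        ctx.group = group
        ctx.rank = dist.get_rank(group=group)
        tensor = tensor.clone()  # avoid mutating the autograd input in-place
        dist.broadcast(tensor, src, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank's output depends on the source tensor, so gradients
        # from all ranks are summed onto the source rank.
        gx = grad_output.clone()
        dist.reduce(gx, ctx.src, op=dist.ReduceOp.SUM, group=ctx.group)
        if ctx.rank != ctx.src:
            gx.zero_()  # non-source inputs were overwritten; no gradient
        return None, None, gx
```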

@codecov

codecov Bot commented Sep 11, 2020

Codecov Report

Merging #40762 (2c01091) into master (a72c6fd) will decrease coverage by 0.16%.
The diff coverage is 41.07%.

@@            Coverage Diff             @@
##           master   #40762      +/-   ##
==========================================
- Coverage   80.69%   80.52%   -0.17%     
==========================================
  Files        1905     1906       +1     
  Lines      206789   206901     +112     
==========================================
- Hits       166873   166613     -260     
- Misses      39916    40288     +372     

@VitalyFedyunin VitalyFedyunin added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Sep 13, 2020
Collaborator

@albanD albanD left a comment

From the autograd point of view this looks good.
Some details about inputs being modified in-place or just read, to make sure your function is not "simplified away", but nothing crucial.

A few side questions though:

  • How is this going to be used?
  • I guess this is going to come, but we definitely need tests for these. gradcheck and gradgradcheck should work fine here? (A small sketch follows after this list.)
  • It would be nice if you could add type annotations for the user-facing API (and it would make the code easier to read as well because I was expecting src to be a Tensor in these functions but it is not).
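
For illustration, such a check might look like this (a sketch with a stand-in op, since the PR's actual collectives need an initialized process group):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Stand-in differentiable op; real tests would call the PR's collective
# functions inside an initialized process group.
def op(x):
    return (x * x).sum()

inp = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
assert gradcheck(op, (inp,))      # numerically verifies first derivatives
assert gradgradcheck(op, (inp,))  # numerically verifies second derivatives
```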

@emcastillo
Collaborator Author

@albanD thanks for the awesome review! I am looking into all the comments and will update it soon :)

@mrshenli
Contributor

Sorry for dropping the ball on this, will review today.

@emcastillo
Collaborator Author

I think the failures are unrelated now.

Contributor

@mrshenli mrshenli left a comment

LGTM! Thanks for adding this!!!

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Collaborator

@albanD albanD left a comment

Am I reading it correctly that none of the Functions ever change their inputs in-place?
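
(For background: if a Function did modify an input in-place, autograd would need to be told via ctx.mark_dirty. A generic sketch, unrelated to this PR's actual code:)

```python
import torch

class ScaleInPlace(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x.mul_(2.0)        # mutates the input in-place
        ctx.mark_dirty(x)  # required so autograd tracks the mutation
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2.0  # d(2x)/dx = 2
```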

Collaborator

nit: Is this import actually necessary?

Collaborator

Do you actually need this import, since you access them through torch below?

Collaborator Author

torch.distributed.nn can't be accessed if it's not directly imported :(
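
(Behavior at the time of this PR: a quick illustration of why the explicit import is needed; Python only binds a submodule attribute once something imports it.)

```python
import torch

# "import torch" does not pull in every submodule; the attribute only
# exists after the submodule has been imported somewhere.
try:
    torch.distributed.nn  # AttributeError if nothing imported it yet
except AttributeError:
    import torch.distributed.nn  # explicit import binds the attribute

torch.distributed.nn  # now resolves
```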

Collaborator

Is that expected? Or something we should fix?

Collaborator Author

I have no idea and left it as it was originally devised. I don't mind fixing it in this PR if you think that's OK,
or we can open a new PR so we can discuss this.
Thanks! I will address all the comments during the weekend.

Collaborator

I do agree this is beyond the scope of this PR. Just wondering if it was an oversight or a design choice :D @mrshenli ?

Contributor

Are there any plans to consolidate this with the APIs in distributed_c10d.py?

@emcastillo
Collaborator Author

The torch.distributed.nn import fails on Windows because it tries to directly import the RPC interface, which is not currently supported there.
I added a check to detect when this is not available and skip the tests; I think the import issue itself should be addressed in a different PR.
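
A sketch of that guard pattern (assumption: standard unittest skip machinery; the names are illustrative, not the PR's actual helpers):

```python
import unittest

try:
    import torch.distributed.nn  # fails on Windows via the RPC import
    HAS_DISTRIBUTED_NN = True
except (ImportError, RuntimeError):
    HAS_DISTRIBUTED_NN = False

@unittest.skipIf(not HAS_DISTRIBUTED_NN,
                 "torch.distributed.nn is not available on this platform")
class TestDistributedNNAutograd(unittest.TestCase):
    def test_import_guard(self):
        self.assertTrue(HAS_DISTRIBUTED_NN)
```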

This one should be ready for landing @mrshenli @albanD
Sorry for the delay and thanks!

@albanD
Collaborator

albanD commented Jan 15, 2021

LGTM thanks for the update!

@albanD
Collaborator

albanD commented Jan 15, 2021

@mrshenli you already have a diff for this so I'll let you do the land.
Let me know if you don't have time and I will commandeer the diff from you.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

@facebook-github-bot
Contributor

@mrshenli merged this pull request in 233e4eb.

@emcastillo
Collaborator Author

Thank you everyone!

HennerM added a commit to Emmi-AI/noether that referenced this pull request Mar 4, 2026
…96)

We provide convenience methods for some of the communication
collectives; one main reason for that is that they work in both
distributed and non-distributed code. However, in the non-distributed
path we were still doing some work, such as moving the data from CPU to
GPU or vice versa. We can avoid doing this work entirely; reduce and
gather ops should be no-ops if we only have one process.

Another simplification is to remove the custom autograd function for
gather. PyTorch added support for gradients in all_gather some time ago
(pytorch/pytorch#40762), so we can just use the normal all_gather
functions.
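
A sketch of the described single-process fast path (the helper name gather_all is hypothetical; torch.distributed.nn.functional.all_gather is the gradient-aware collective from pytorch/pytorch#40762):

```python
from typing import List

import torch
import torch.distributed as dist

def gather_all(t: torch.Tensor) -> List[torch.Tensor]:
    # Single-process path: a no-op, as the commit message describes.
    if (not dist.is_available() or not dist.is_initialized()
            or dist.get_world_size() == 1):
        return [t]
    # Distributed path: gradient-aware all_gather from pytorch#40762.
    import torch.distributed.nn.functional as distF
    return list(distF.all_gather(t))
```
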
kinggongzilla pushed a commit to Emmi-AI/noether that referenced this pull request Mar 5, 2026

HowardWHW pushed a commit to HowardWHW/noether that referenced this pull request Mar 8, 2026

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
…ch#40762)

Summary:
Closes pytorch#40702, Fixes pytorch#40690

Currently WIP, but I would appreciate some feedback. The functions should be double-differentiable.

Contrary to https://github.com/pytorch/pytorch/blob/716b2a6d69546db2aa3e91cfd88e92350cf0bf46/torch/nn/parallel/_functions.py, this PR generates a list of tensors instead of aggregating the received data into a single tensor. Is this behavior correct?

Thanks!

Pull Request resolved: pytorch#40762

Reviewed By: glaringlee

Differential Revision: D24758889

Pulled By: mrshenli

fbshipit-source-id: 79285fb4b791cae3d248f34e2aadb11c9ab10cce

Labels

cla signed · Merged · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

  • Implement autograd functions for c10d communication operations
  • Using all_gather() in the forward pass in DDP throws RuntimeError

8 participants