
[Gradient Compression] Allow BatchedPowerSGD to run vanilla allreduce for the first K iterations #51270

Closed
wayi1 wants to merge 4 commits into gh/SciPioneer/50/base from gh/SciPioneer/50/head

Conversation

@wayi1
Contributor

@wayi1 wayi1 commented Jan 28, 2021

Stack from ghstack:

Similar to #50973, allow the batched version to run vanilla allreduce for the first K iterations.

This may be useful for use cases whose accuracy requirements are not very strict, so that the batched version can still be applied.

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: D26077709
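
For context, here is a minimal sketch of the fallback logic this PR adds, written in the style of a DDP communication hook. The helper name `_allreduce_fut`, the `state.iter` bookkeeping, and the bucket accessor are illustrative assumptions, not the exact code in `torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py`:

```python
# Hedged sketch: run vanilla allreduce for the first `start_powerSGD_iter`
# iterations, then switch to the batched PowerSGD compression path.
import torch.distributed as dist


def _allreduce_fut(process_group, tensor):
    # Average the flattened gradient bucket with an async allreduce and
    # return it as a future, as expected from a DDP comm hook.
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    tensor.div_(group_to_use.size())
    return (
        dist.all_reduce(tensor, group=group_to_use, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )


def batched_powerSGD_hook(state, bucket):
    input_tensor = bucket.get_tensor()  # flattened gradients (accessor name has varied across releases)
    if state.iter < state.start_powerSGD_iter:
        # Warm-up phase: plain allreduce, no compression.
        state.iter += 1
        return _allreduce_fut(state.process_group, input_tensor)
    state.iter += 1
    # Compression phase (elided): view the flattened tensor as a square matrix,
    # allreduce the low-rank factors P and Q, reconstruct M ~= P @ Q^T, and
    # truncate the result back to the original length.
    ...
```

The non-batched hook from #50973 uses the same kind of guard; the only difference is the compression path taken after the warm-up.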

@facebook-github-bot
Contributor

facebook-github-bot commented Jan 28, 2021

💊 CI failures summary and remediations

As of commit adde81d (more details on the Dr. CI page):


  • 4/4 failures possibly* introduced in this PR
    • 2/4 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_python_doc_build (1/1)

Step: "Doc Build and Push" (full log | diagnosis details | 🔁 rerun)

Jan 31 07:08:34 Makefile:38: recipe for target 'html' failed
Jan 31 07:08:34 
Jan 31 07:08:34 copying static files... ... done
Jan 31 07:08:34 copying extra files... done
Jan 31 07:08:34 dumping search index in English (code: en)... done
Jan 31 07:08:34 dumping object inventory... done
Jan 31 07:08:34 build finished with problems, 1 warning.
Jan 31 07:08:34 /var/lib/jenkins/workspace/docs/src/pytorch-sphinx-theme/pytorch_sphinx_theme/search.html:21: RemovedInSphinx30Warning: To modify script_files in the theme is deprecated. Please insert a <script> tag directly in your theme instead.
Jan 31 07:08:34   <p class="last">
Jan 31 07:08:34 /var/lib/jenkins/workspace/docs/src/pytorch-sphinx-theme/pytorch_sphinx_theme/search.html:24: RemovedInSphinx30Warning: To modify script_files in the theme is deprecated. Please insert a <script> tag directly in your theme instead.
Jan 31 07:08:34   </p>
Jan 31 07:08:34 Makefile:38: recipe for target 'html' failed
Jan 31 07:08:34 make: *** [html] Error 1
Jan 31 07:08:34 ++ code=2
Jan 31 07:08:34 ++ '[' 2 -ne 0 ']'
Jan 31 07:08:34 ++ set +x
Jan 31 07:08:34 =========================
Jan 31 07:08:34 /var/lib/jenkins/workspace/docs/source/notes/broadcasting.rst:6: WARNING: 'any' reference target not found: numpy.doc.broadcasting
Jan 31 07:08:34 =========================
Jan 31 07:08:34 Docs build failed. If the failure is not clear, scan back in the log
Jan 31 07:08:34 for any WARNINGS or for the line build finished with problems
Jan 31 07:08:34 (tried to echo the WARNINGS above the ==== line)

1 failure not recognized by patterns:

Job: CircleCI pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2
Step: Run tests
Action: 🔁 rerun

Extra GitHub checks: 1 failed


This comment was automatically generated by Dr. CI (expand for details). Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@facebook-github-bot facebook-github-bot added the cla signed and oncall: distributed (Add this issue/PR to distributed oncall triage queue) labels Jan 28, 2021
wayi1 pushed a commit that referenced this pull request Jan 28, 2021
… for the first K iterations

Differential Revision: [D26077709](https://our.internmc.facebook.com/intern/diff/D26077709/)

ghstack-source-id: 120400709
Pull Request resolved: #51270
@wayi1 wayi1 requested a review from rohan-varma January 28, 2021 23:38
wayi1 pushed a commit that referenced this pull request Jan 29, 2021
… for the first K iterations

Pull Request resolved: #51270

ghstack-source-id: 120617938

Differential Revision: [D26077709](https://our.internmc.facebook.com/intern/diff/D26077709/)
Contributor

@rohan-varma rohan-varma left a comment


I think the diff below this one was reverted, and this one is also failing tests: https://app.circleci.com/pipelines/github/pytorch/pytorch/266089/workflows/a036f791-01c8-4538-90eb-a4e40234b8c3/jobs/10492151. Do you need to resubmit the previous diff first?

@wayi1
Contributor Author

wayi1 commented Jan 30, 2021

> I think the diff below this one was reverted, and this one is also failing tests: https://app.circleci.com/pipelines/github/pytorch/pytorch/266089/workflows/a036f791-01c8-4538-90eb-a4e40234b8c3/jobs/10492151. Do you need to resubmit the previous diff first?

Created #51400 for the resubmission, and will submit that PR first.

@wayi1 wayi1 requested a review from rohan-varma January 30, 2021 02:02
wayi1 pushed a commit that referenced this pull request Jan 31, 2021
… for the first K iterations

Pull Request resolved: #51270

ghstack-source-id: 120725858

Differential Revision: [D26077709](https://our.internmc.facebook.com/intern/diff/D26077709/)
Comment thread on torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py

7) Computes M, which is approximately equal to PQ^T.
8) Truncates the input tensor to the original length.

Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
Contributor


Do we have the default value for `state.start_powerSGD_iter` documented in the docs for `PowerSGDState`? Would be nice to ensure we have that, and maybe also specify the default start iteration here.
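
For illustration, a hedged usage sketch of registering the batched hook with an explicit `start_powerSGD_iter`; the import path and parameter names follow the `powerSGD_hook` module, but the default value of `start_powerSGD_iter` (the reviewer's question) and the exact constructor signature should be checked against the `PowerSGDState` documentation:

```python
# Usage sketch (assumptions noted inline): register batched PowerSGD on a DDP
# model so the first 1000 iterations fall back to vanilla allreduce.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")  # assumes the usual env:// rendezvous variables are set
local_rank = dist.get_rank() % torch.cuda.device_count()  # assumes one process per GPU

model = DDP(nn.Linear(1024, 1024).to(local_rank), device_ids=[local_rank])

state = powerSGD.PowerSGDState(
    process_group=None,        # use the default process group
    start_powerSGD_iter=1000,  # vanilla allreduce for the first 1000 iterations
)
model.register_comm_hook(state, powerSGD.batched_powerSGD_hook)
```

With `start_powerSGD_iter` set this way, compression only kicks in after the warm-up, which is the behavior this PR enables for the batched hook.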

@facebook-github-bot
Contributor

This pull request has been merged in c080780.

@facebook-github-bot facebook-github-bot deleted the gh/SciPioneer/50/head branch February 5, 2021 15:21
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
… for the first K iterations (pytorch#51270)

Summary:
Pull Request resolved: pytorch#51270

Similar to pytorch#50973, allow the batched version to run vanilla allreduce for the first K iterations.

This may be useful for use cases whose accuracy requirements are not very strict, so that the batched version can still be applied.

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression pytorch#47202
ghstack-source-id: 120725858

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

baseline: f248001754
batched PowerSGD: f246960752

The training time was reduced from 54m48s to 30m33s, and the accuracy is approximately the same: 44.21 vs 44.35

Reviewed By: rohan-varma

Differential Revision: D26077709

fbshipit-source-id: 6afeefad7a3fbdd7da2cbffb56dfbad855a96cb5

Labels

cla signed, Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue)


3 participants