[ShardedDDP] Sync buffers + small cleanup#112

Merged
blefaudeux merged 27 commits into master from oss_sharded_ddp
Sep 29, 2020

Conversation

@blefaudeux
Contributor

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guidelines?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If your PR was not discussed in a GitHub issue, there is a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Sep 24, 2020
@blefaudeux
Contributor Author

Long-lived branch, so it was rebased a couple of times against master; most of the commits listed above are unrelated.

# Check that the optimization process makes sense (ie. loss goes down for the same data)
optimizer.step()
new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
# assert new_eval.item() < output.item()
Contributor Author


This would depend on the learning rate and the random values, so it was not a very good test, and the assert was not actually being used.
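The flakiness described above comes from asserting "loss went down" with an arbitrary learning rate and random initialization. A deterministic variant avoids that: fix the seed and pick a problem/learning-rate pair where descent is guaranteed. A minimal torch-free sketch of the idea (the function and problem are hypothetical, not from this PR):

```python
import random

def loss_decreases_deterministically(lr=0.1, steps=5, seed=0):
    # Hypothetical sketch: on the convex problem f(w) = w^2, with a
    # fixed seed and a small enough learning rate, the loss must go
    # down, so the check is no longer LR/initialization dependent.
    random.seed(seed)
    w = random.uniform(1.0, 2.0)   # deterministic "random" init
    prev = w * w
    for _ in range(steps):
        w -= lr * 2 * w            # gradient step: df/dw = 2w
        cur = w * w
        assert cur < prev, "loss must strictly decrease"
        prev = cur
    return prev

final = loss_decreases_deterministically()
```

The commented-out assert in the diff above could only be re-enabled under controls like these.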

"""
assert self.module.training, "Cannot call reduce in eval"

def reduce_params(params: List[Parameter], params_rank: int) -> None:
Contributor Author


Some of these cases were never used here; I assume they came from a copy-paste of a version where this function was more generic than just reducing grads.

Contributor


Yeah, that was me; I mis-named it earlier.
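For context, the `reduce_params` helper in the diff above reduces each set of gradients to the rank that owns the corresponding optimizer shard; non-owning ranks can then drop theirs. A torch-free sketch of that ownership idea, with ranks simulated as list indices (all names here are hypothetical, not the PR's API):

```python
def reduce_grads_to_owner(grads_per_rank, params_rank):
    """Simulate dist.reduce(op=SUM): sum each gradient across 'ranks'
    and keep the result only on the owning rank, mirroring how
    ShardedDDP lets non-owning ranks release their gradients."""
    world_size = len(grads_per_rank)
    n_params = len(grads_per_rank[0])
    # element-wise sum across ranks, as a SUM reduce would produce
    summed = [sum(grads_per_rank[r][p] for r in range(world_size))
              for p in range(n_params)]
    out = []
    for rank in range(world_size):
        if rank == params_rank:
            out.append(summed)             # owner keeps reduced grads
        else:
            out.append([None] * n_params)  # non-owners drop theirs
    return out

# two simulated ranks, two parameters; rank 0 owns this shard
result = reduce_grads_to_owner([[1.0, 2.0], [3.0, 4.0]], params_rank=0)
```

In the real code this is a collective (`dist.reduce`) rather than a local sum, but the ownership pattern is the same.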

@blefaudeux
Contributor Author

ping reviewers, @msbaines @min-xu-ai

Contributor

@min-xu-ai left a comment


Looks good. Feel free to add a TODO to address my comments on potential perf optimization. Sorry for the delay.

map(
    lambda x: x.wait(),
    map(
        lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
Contributor


Does it make sense to combine some buffers and reduce the total number of broadcasts, like what we do above for reduce_grad?

Contributor Author


Yes, I think it would be an interesting option; I have an old branch doing that which I should revive. There is a trade-off though, because bucketing requires extra copies. In practice, since the broadcast became async I saw a big speed bump, so I wonder whether this bucketing strategy is still worth pursuing. I'll add a TODO, good idea.
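The trade-off discussed in this thread, fewer collective calls versus extra copies, can be sketched without torch: pack many small buffers into one flat buffer, issue a single broadcast instead of one per buffer, then copy back out. All names below are hypothetical; in torch this would use something like `torch._utils._flatten_dense_tensors` or a preallocated bucket plus `dist.broadcast`:

```python
def bucket_broadcast(buffers, broadcast_fn):
    """Pack many small buffers into one flat list, issue a single
    broadcast instead of len(buffers) of them, then unpack. The
    packing/unpacking copies are the cost of making fewer calls."""
    sizes = [len(b) for b in buffers]
    flat = [x for b in buffers for x in b]  # pack (copy #1)
    flat = broadcast_fn(flat)               # one collective call
    out, offset = [], 0
    for n in sizes:                         # unpack (copy #2)
        out.append(flat[offset:offset + n])
        offset += n
    return out

calls = []
def fake_broadcast(flat):
    calls.append(len(flat))                 # record the single call
    return flat

restored = bucket_broadcast([[1, 2], [3], [4, 5, 6]], fake_broadcast)
```

Whether this wins over the async per-buffer broadcasts in the diff depends on buffer count and size, which is exactly the open question above.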

"""
assert self.module.training, "Cannot call reduce in eval"

def reduce_params(params: List[Parameter], params_rank: int) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, it was me mis-named it before.

@blefaudeux blefaudeux merged commit 79ded82 into master Sep 29, 2020
@blefaudeux blefaudeux deleted the oss_sharded_ddp branch September 29, 2020 15:47

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support broadcast_buffers in OssDdp

3 participants