[ShardedDDP] Sync buffers + small cleanup (#112)
Conversation
Long-lived branch, so it was rebased a couple of times against master; most of the commits listed above are unrelated.
```python
# Check that the optimization process makes sense (ie. loss goes down for the same data)
optimizer.step()
new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
# assert new_eval.item() < output.item()
```
This would be LR- and random-seed-dependent, so it was not a very good test; it was not actually being used.
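A single-step comparison like the one commented out above is brittle for exactly that reason. One less LR-sensitive variant (illustrative only, not from the PR; a toy quadratic objective stands in for the real DDP forward pass) is to compare the loss averaged over a window of steps:

```python
def loss(w):
    # Toy quadratic objective standing in for the real forward pass
    return (w - 3.0) ** 2

def grad(w):
    return 2.0 * (w - 3.0)

def mean_loss_over_steps(w, lr, n_steps):
    # Average the loss over a window of steps instead of comparing two
    # single evaluations, which is less sensitive to LR and random init.
    total = 0.0
    for _ in range(n_steps):
        total += loss(w)
        w -= lr * grad(w)
    return total / n_steps, w

first, w = mean_loss_over_steps(0.0, lr=0.1, n_steps=10)
second, _ = mean_loss_over_steps(w, lr=0.1, n_steps=10)
assert second < first  # averaged loss went down across windows
```

The windowed average decreases monotonically for any reasonable LR on a convex toy problem, so the assertion is far less likely to flake than a one-step check.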
| """ | ||
| assert self.module.training, "Cannot call reduce in eval" | ||
|
|
||
| def reduce_params(params: List[Parameter], params_rank: int) -> None: |
Some of these cases were never used here; I assume they came from a copy-paste where this function was more generic than just reducing grads.
Yeah, that was me; I mis-named it before.
Ping reviewers: @msbaines @min-xu-ai
min-xu-ai left a comment
Looks good. Feel free to add a TODO to address my comments on potential perf optimization. Sorry for the delay.
```python
map(
    lambda x: x.wait(),
    map(
        lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
```
Does it make sense to combine some buffers and reduce the total number of broadcasts, like what we do above for reduce_grad?
Yes, I think it would be an interesting option; I have an old branch doing that which I should revive. There's a trade-off though, because bucketing requires extra copies. In practice, once the broadcast became async I saw a big speed bump, so I wonder whether the bucketing strategy is still worth following. I'll add a TODO, good idea.
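For the TODO, the bucketing side of that trade-off could look something like this greedy packing step (a hypothetical helper sketched with plain element counts instead of real tensors; the real version would flatten each bucket's buffers into one tensor, issue a single `dist.broadcast` per bucket, then copy back):

```python
from typing import List

def bucket_by_size(numels: List[int], max_bucket_numel: int) -> List[List[int]]:
    # Greedily pack buffer sizes into buckets so that each bucket needs a
    # single collective instead of one broadcast per buffer. A buffer larger
    # than the cap simply gets a bucket of its own.
    buckets: List[List[int]] = []
    current: List[int] = []
    current_numel = 0
    for n in numels:
        if current and current_numel + n > max_bucket_numel:
            buckets.append(current)
            current, current_numel = [], 0
        current.append(n)
        current_numel += n
    if current:
        buckets.append(current)
    return buckets

# Five buffers packed under a 100-element cap -> three broadcasts instead of five
print(bucket_by_size([40, 50, 20, 90, 10], max_bucket_numel=100))
# → [[40, 50], [20], [90, 10]]
```

The copies in and out of the flat bucket tensor are the cost mentioned above; with async broadcasts already overlapping, bucketing only pays off when the per-collective launch overhead dominates.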
| """ | ||
| assert self.module.training, "Cannot call reduce in eval" | ||
|
|
||
| def reduce_params(params: List[Parameter], params_rank: int) -> None: |
There was a problem hiding this comment.
yeah, it was me mis-named it before.