Gradient bucketing using a pre-defined bucket size cap #6417
amithrm wants to merge 10 commits into pytorch:master
Conversation
Do you mind adding a test case?
amithrm force-pushed from 11f466d to fdb0f9e
Added the test case and rebased. @JackCaoG @alanwaketan
```python
grad_bytes = grad.numel() * grad.element_size()

# Gradient is larger than bucket_cap, don't bucketize
if grad_bytes > bucket_cap:
```
Curious why you want to specialize this case?
If the grad_bytes (already in the tensor) is larger than the bucket cap, we send it straight away as a single tensor instead of bucketing.
Right, I understood the logic. But why? Does combining it with the bucket introduce some problems?
Yeah, it looks like you can get rid of this if statement (up to the continue), and the `if total > bucket_cap` check should take care of this condition when the bucket is empty.
The issue with combining this with the rest is that the buffer allocated in the underlying runtime may not have enough space to fit this large tensor. The idea is to have a buffer large enough to fit all the tensors in a bucket. It can happen that total_bytes is just below the maximum allowed, and adding this tensor to the bucket would spill over the maximum. Hence it should go alone, without bucketing.
See your concerns now! Fixed the code flow.
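The flow being agreed on above can be sketched in plain Python. This is a hypothetical illustration that operates on gradient byte sizes only, with no torch_xla calls; the `bucketize` helper and its return shape are made up for this example and are not the PR's actual code.

```python
def bucketize(grad_sizes, bucket_cap):
    """Group gradient byte sizes into buckets of at most bucket_cap bytes.

    A gradient larger than bucket_cap is sent alone, without bucketing,
    so a runtime buffer sized to bucket_cap is never overflowed.
    """
    buckets = []
    bucket, total = [], 0
    for grad_bytes in grad_sizes:
        # Oversized gradient: flush anything pending, then send it alone.
        if grad_bytes > bucket_cap:
            if bucket:
                buckets.append(bucket)
                bucket, total = [], 0
            buckets.append([grad_bytes])
            continue
        # Adding this gradient would spill over the cap: flush first.
        if total + grad_bytes > bucket_cap and bucket:
            buckets.append(bucket)
            bucket, total = [], 0
        bucket.append(grad_bytes)
        total += grad_bytes
    if bucket:
        buckets.append(bucket)
    return buckets
```

Note that flushing *before* adding (rather than after the total has already spilled) is what keeps every bucket within the cap, which is the concern raised above.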
```python
"""
count = xrt_world_size()
if count > 1:
  gradients = _fetch_gradients(optimizer)
```
Can we keep the original behavior? And maybe use a flag to turn this feature on?
OK, let me work on that.
Maybe we should introduce an argument `bucket_cap_mb` that turns this on, instead of an environment variable? `bucket_cap_mb=0` turns off bucketing and is the default?
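A minimal sketch of the suggested switch, assuming a hypothetical `reduce_gradients` helper; its return values merely tag which path was taken, standing in for the real collective calls:

```python
def reduce_gradients(gradients, bucket_cap_mb=0):
    """Dispatch to the plain or bucketed all-reduce path.

    bucket_cap_mb=0 (the default) preserves the original behavior:
    one all-reduce per gradient, no bucketing.
    """
    if bucket_cap_mb == 0:
        # Default: original unbucketed behavior.
        return [("plain", g) for g in gradients]
    bucket_cap = bucket_cap_mb * 1024 * 1024
    # Bucketed path: gradients grouped up to bucket_cap bytes (elided here).
    return [("bucketed", gradients, bucket_cap)]
```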
```python
# Bucketize till the total spills over
total += grad_bytes
if total > bucket_cap:
  all_reduce(
```
Need to check `if len(tensor_bucket):` because `tensor_bucket` can be empty at the start, when `grad_bytes > bucket_cap`.
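A tiny illustration of the requested guard. The `all_reduce` here is a stub that just records flushed buckets; the names mirror the review comment, not the actual PR code:

```python
flushed = []

def all_reduce(bucket):
    # Stub standing in for torch_xla's collective: record the call.
    flushed.append(list(bucket))

def flush(tensor_bucket):
    # Guard requested in the review: skip the reduce when the bucket is
    # empty (e.g. the first gradient already exceeded bucket_cap).
    if len(tensor_bucket):
        all_reduce(tensor_bucket)

flush([])      # empty bucket: no collective issued
flush([1, 2])  # non-empty bucket: reduced once
```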
```python
torch_xla._XLAC._xla_wait_device_ops(devices=devices)


def bucketed_allreduce(gradients):
```
Maybe name it similarly to the original function all_reduce? How about `all_reduce_bucketized`?
Also, do you need to pass `groups` and `pin_layout` as well?
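A hedged sketch of what the rename plus the pass-through could look like. The `all_reduce` below is a stub standing in for `torch_xla.core.xla_model.all_reduce`, and the bucketing details are simplified; only the `groups`/`pin_layout` forwarding is the point here:

```python
def all_reduce(reduce_type, tensors, groups=None, pin_layout=True):
    # Stub for torch_xla.core.xla_model.all_reduce: report what it got.
    return {"n": len(tensors), "groups": groups, "pin_layout": pin_layout}

def all_reduce_bucketized(gradients, bucket_cap, groups=None, pin_layout=True):
    """Bucketed wrapper that forwards groups/pin_layout to all_reduce.

    `gradients` is a list of dicts with a "bytes" size, standing in for
    real gradient tensors in this sketch.
    """
    results = []
    bucket, total = [], 0
    for g in gradients:
        # Flush before the bucket would spill over the cap.
        if total + g["bytes"] > bucket_cap and bucket:
            results.append(
                all_reduce("sum", bucket, groups=groups, pin_layout=pin_layout))
            bucket, total = [], 0
        bucket.append(g)
        total += g["bytes"]
    if bucket:
        results.append(
            all_reduce("sum", bucket, groups=groups, pin_layout=pin_layout))
    return results
```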
alanwaketan left a comment:
LGTM. Please address the other comments as well.
@JackCaoG do you know why the build failed with "ERROR: Error initializing RemoteModule"?
It is on a fork, hence it can't use the remote cache, but there was a bug where it still tried to query the credentials. I think we fixed this error today; it should start building without the cache. If you rebase, the CI should start running.
@JackCaoG looks like the build is still failing for some reason after rebasing. Maybe another rebase is needed?
The error still seems to be related to the fork. Let me grant both of you write access, then you can open the PR directly.
OK, I gave @amithrm write access.
Replaced by #7216 to avoid the build issues in CI testing.