
Gradient bucketing using a pre-defined bucket size cap #6417

Closed
amithrm wants to merge 10 commits into pytorch:master from amithrm:bucket_allreduce

Conversation

@amithrm (Contributor) commented Jan 30, 2024

No description provided.

@JackCaoG JackCaoG requested a review from alanwaketan January 30, 2024 18:25
@alanwaketan (Collaborator)

Do you mind adding a test case?

@amithrm (Contributor, Author) commented Mar 4, 2024

Added the test case and rebased @JackCaoG @alanwaketan

Comment thread on torch_xla/core/xla_model.py (outdated)

    grad_bytes = grad.numel() * grad.element_size()

    # Gradient is larger than bucket_cap, don't bucketize
    if grad_bytes > bucket_cap:
Collaborator:

Curious why you want to specialize this case?

Contributor (Author):

If grad_bytes (the size of the tensor we already have) is larger than the bucket cap, we send it straight away as a single tensor instead of bucketing it.

Collaborator:

Right, I understood the logic. But why? Does combining it with the bucket introduce some problem?

@jeffhataws (Collaborator) commented Mar 16, 2024

Yeah, it looks like you can get rid of this if statement (up to the continue); the "if total > bucket_cap" check should take care of this condition when the bucket is empty.

Contributor (Author):

The issue with combining this with the rest is that the buffer allocated in the underlying runtime may not have enough space to fit this large tensor. The idea is to have a large buffer that can fit all the tensors. It can happen that total_bytes is just below the maximum allowed, and adding this tensor to the bucket would spill over that maximum. Hence it should go alone, without bucketing.
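The flow being debated here can be sketched roughly as follows. This is an illustrative simplification, not the PR's actual code: gradients are modeled by their byte sizes, and `all_reduce` is a recording stub standing in for the real xm.all_reduce collective.

```python
# Hypothetical sketch of the bucketing logic under discussion: a gradient whose
# byte size already exceeds bucket_cap is flushed and reduced alone, so the
# runtime never has to allocate a coalescing buffer larger than the cap.
reduced_groups = []

def all_reduce(tensors):
    # Stand-in for xm.all_reduce; records each group of "tensors" reduced.
    reduced_groups.append(list(tensors))

def bucketed_allreduce(grad_sizes, bucket_cap):
    """grad_sizes: gradient sizes in bytes (stand-ins for the tensors)."""
    tensor_bucket, total = [], 0
    for grad_bytes in grad_sizes:
        # Gradient is larger than bucket_cap: flush the pending bucket,
        # then send the large gradient alone instead of bucketizing it.
        if grad_bytes > bucket_cap:
            if tensor_bucket:
                all_reduce(tensor_bucket)
                tensor_bucket, total = [], 0
            all_reduce([grad_bytes])
            continue
        # Bucketize till the total spills over the cap.
        total += grad_bytes
        if total > bucket_cap:
            all_reduce(tensor_bucket)
            tensor_bucket, total = [], grad_bytes
        tensor_bucket.append(grad_bytes)
    if tensor_bucket:
        all_reduce(tensor_bucket)

bucketed_allreduce([10, 20, 100, 30, 40], bucket_cap=50)
# The 100-byte gradient goes alone; the others are coalesced under the cap.
```

With the special case removed, the oversized gradient would land in a bucket whose backing buffer exceeds the cap, which is exactly the runtime-allocation concern raised above.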

Contributor (Author):

I see your concerns now! Fixed the code flow.

    """
    count = xrt_world_size()
    if count > 1:
      gradients = _fetch_gradients(optimizer)
Collaborator:

Can we keep the original behavior? And maybe use a flag to turn this feature on?

Contributor (Author):

OK, let me work on that.

Collaborator:

Maybe we should introduce an argument "bucket_cap_mb" that turns this on, instead of an environment variable, with bucket_cap_mb=0 turning bucketing off as the default?
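The suggested opt-in interface might look like the following. This is a hypothetical sketch of the suggestion only (the function name `reduce_gradients` and the stubbed reduction are illustrative, not the PR's code): bucketing is enabled by a nonzero bucket_cap_mb argument, and the default of 0 preserves the original unbucketed behavior.

```python
# Hypothetical sketch: a bucket_cap_mb keyword argument gates the feature.
calls = []  # records which code path ran, for illustration

def reduce_gradients(gradients, bucket_cap_mb=0):
    if bucket_cap_mb == 0:
        # Default: original behavior, a single reduction over all gradients.
        calls.append(("unbucketed", list(gradients)))
        return
    # Opt-in: convert the cap to bytes and take the bucketed path.
    bucket_cap = bucket_cap_mb * 1024 * 1024
    calls.append(("bucketed", bucket_cap))

reduce_gradients([1, 2, 3])                   # default, unbucketed path
reduce_gradients([1, 2, 3], bucket_cap_mb=8)  # opt-in bucketing, 8 MB cap
```

An argument is easier to test and document than an environment variable, and the zero default keeps existing callers unchanged.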

Contributor (Author):

Done.

Comment thread on torch_xla/core/xla_model.py (outdated)

    # Bucketize till the total spills over
    total += grad_bytes
    if total > bucket_cap:
      all_reduce(
Collaborator:

Need to check "if len(tensor_bucket):" because tensor_bucket can be empty at the start, when grad_bytes > bucket_cap.
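A minimal illustration of this edge case (hypothetical code, with `flush` standing in for the all-reduce over the bucket): when the first gradient alone exceeds the cap, `total > bucket_cap` is already true while tensor_bucket is still empty, so an unguarded flush would issue a collective over zero tensors.

```python
# Why the emptiness guard matters: sketch with the suggested len() check.
flushes = []

def flush(bucket):
    # Stand-in for reducing one bucket; an empty bucket would be a bug.
    assert len(bucket) > 0, "all_reduce on an empty bucket"
    flushes.append(list(bucket))

def bucketize(grad_sizes, bucket_cap):
    tensor_bucket, total = [], 0
    for grad_bytes in grad_sizes:
        total += grad_bytes
        # The suggested guard: only flush when the bucket is non-empty.
        if total > bucket_cap and len(tensor_bucket):
            flush(tensor_bucket)
            tensor_bucket, total = [], grad_bytes
        tensor_bucket.append(grad_bytes)
    if tensor_bucket:
        flush(tensor_bucket)

bucketize([100, 10, 20], bucket_cap=50)  # first gradient exceeds the cap
```

Without the `len(tensor_bucket)` term, the very first iteration would attempt to flush an empty bucket.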

torch_xla._XLAC._xla_wait_device_ops(devices=devices)


def bucketed_allreduce(gradients):
Collaborator:

Maybe name it similarly to the original function all_reduce? How about all_reduce_bucketized?

Also, do you need to pass "groups" and "pin_layout" as well?
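The naming and parameter suggestion amounts to mirroring the existing collective's interface. A hypothetical sketch (the `all_reduce` here is a recording stub, not the real XLA op, and the signature is illustrative):

```python
# Hypothetical sketch: all_reduce_bucketized mirrors all_reduce's signature
# and forwards groups and pin_layout to each underlying reduction.
seen = []

def all_reduce(reduce_type, tensors, groups=None, pin_layout=True):
    # Stand-in for the real collective; records its arguments.
    seen.append((reduce_type, tuple(tensors), groups, pin_layout))

def all_reduce_bucketized(reduce_type, gradients, groups=None, pin_layout=True):
    # Forward the collective's options unchanged to the bucketed reduction(s).
    all_reduce(reduce_type, gradients, groups=groups, pin_layout=pin_layout)

all_reduce_bucketized("sum", [1, 2], groups=[[0, 1]], pin_layout=False)
```

Keeping the bucketed variant's signature parallel to all_reduce means callers can switch between the two without changing how replica groups or layout pinning are specified.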

@alanwaketan (Collaborator) left a comment:

LGTM. Please address other comments as well.

@amithrm amithrm force-pushed the bucket_allreduce branch from 777b97f to 31dd451 Compare May 28, 2024 20:20
@jeffhataws (Collaborator)

@JackCaoG do you know why the build failed with "ERROR: Error initializing RemoteModule"?

@JackCaoG (Collaborator) commented May 28, 2024

The PR is on a fork and hence can't use the remote cache, but there was a bug where it still tried to query the credentials. I think we fixed this error today; it should start building without the cache. If you rebase, the CI should start running.

@amithrm amithrm force-pushed the bucket_allreduce branch from 31dd451 to 05e2367 Compare May 29, 2024 02:56
@jeffhataws (Collaborator)

@JackCaoG looks like the build is still failing for some reason after rebasing. Maybe another rebase is needed?

@JackCaoG (Collaborator)

The error still seems to be related to the fork. Let me grant both of you write access; then you can open the PR directly.

@JackCaoG (Collaborator)

OK, I gave @amithrm write access.

@jeffhataws (Collaborator)

Replaced by #7216 to avoid the build issues in CI testing.

@jeffhataws jeffhataws closed this Jun 7, 2024
