Support activation sharding #53
Closed
wonjoo-wj wants to merge 1 commit into optimize_spmd_sharding from
Conversation
miladm
reviewed
Mar 5, 2024
```python
device_ids = torch.arange(num_devices)
data_model_mesh = xs.Mesh(device_ids, (4, 1, 2))
xs.mark_sharding(output, data_model_mesh, (0, 1, 2), use_dynamo_custom_op=True)
# num_devices = xr.global_runtime_device_count()
```
@wonjoolee95 can you add a discussion of the reason to keep `output = self.wo(output)` on line 204 when we want to make the dynamo_mark_sharding call afterwards?
Collaborator
This is the original implementation; we are only adding the sharding annotations. The output projection has to stay as-is.
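For context, here is a minimal sketch of how the annotation in the diff sits after the output projection. The `(4, 1, 2)` mesh and `(0, 1, 2)` partition spec are taken from the diff above and assume 8 devices; the helper function name, the `xs` import path, and the surrounding module layout are illustrative assumptions, not code from this PR.

```python
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs  # assumed import path for the `xs` alias used in the diff


def project_and_shard(wo: torch.nn.Linear, attn_out: torch.Tensor) -> torch.Tensor:
    """Apply the output projection, then annotate the activation for sharding."""
    output = wo(attn_out)  # the projection itself is unchanged; the annotation comes after it

    num_devices = xr.global_runtime_device_count()
    device_ids = torch.arange(num_devices)
    # The (4, 1, 2) mesh and (0, 1, 2) partition spec mirror the diff and assume
    # num_devices == 8; the mesh axis sizes must multiply to the device count.
    data_model_mesh = xs.Mesh(device_ids, (4, 1, 2))
    xs.mark_sharding(output, data_model_mesh, (0, 1, 2), use_dynamo_custom_op=True)
    return output
```

The ordering matters because the annotation has to see the projected activation; keeping `output = self.wo(output)` first and marking the result afterwards is what lets dynamo trace the custom sharding op on that tensor.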
yeounoh
reviewed
Mar 20, 2024
```python
# custom python dynamo mark sharding
import torch_xla.experimental.dynamo_mark_sharding
device_ids = [0]
mesh_shape = [1, 1, 1]
```
Collaborator
wait, why is the mesh shape 1, 1, 1?
yeounoh
requested changes
Mar 20, 2024
Collaborator
yeounoh
left a comment
This is for testing -- @wonjoolee95
- Could you remove the "[WONJOO]" debugging probes and make this look more formal, if useful?
- Also, you have to change the mesh shape. I don't think we are sharding the activation in the current impl?
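To make the mesh-shape point concrete: a `[1, 1, 1]` mesh covers a single device slot, so `mark_sharding` effectively replicates the tensor rather than sharding the activation. Below is a hedged sketch of building the mesh over all available devices instead; it reuses `xr.global_runtime_device_count()` from the diff, while the particular 3-axis split is an illustrative choice, not a sharding strategy settled on in this PR.

```python
import torch_xla.runtime as xr

# A [1, 1, 1] mesh has one device slot, so the annotation cannot shard anything.
# Spanning the real device count makes the annotation meaningful.
num_devices = xr.global_runtime_device_count()
device_ids = list(range(num_devices))
# Illustrative data/model split; axis sizes must multiply to num_devices
# (this particular split assumes an even device count).
mesh_shape = [num_devices // 2, 1, 2]
```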
Collaborator
Author
Moving this over to #55
With pytorch/xla#6524, we can now support activation sharding.
This is an example of the changes needed to call the new dynamo Python custom op.
Note that this requires pytorch/xla#6524 to be merged first. Tested locally; llama 2 inference with dynamo+SPMD (with activation sharding) succeeds: https://gist.github.com/wonjoolee95/a290a68f29c52bd395b16ae6df651531.
Also note that performance has not been measured, so we still need to find the optimal sharding strategy; this change only verifies functionality.
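For readers who want to reproduce the flow, here is a minimal sketch of dynamo+SPMD inference with an activation-sharding annotation. It assumes SPMD is enabled via `xr.use_spmd()`, that the model is compiled with the `openxla` dynamo backend, and that the sharding helpers live under `torch_xla.distributed.spmd`; the tiny linear model, mesh axes, and function names are illustrative and not taken from the llama 2 run linked above, and the `use_dynamo_custom_op` kwarg comes from the pytorch/xla#6524 dependency.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs  # assumed import path for the sharding helpers

xr.use_spmd()  # SPMD mode must be enabled before XLA tensors are created

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices, 1))  # illustrative 2-axis mesh


def projected_forward(linear: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
    out = linear(x)
    # Traceable activation-sharding annotation on the projected output.
    xs.mark_sharding(out, mesh, (0, 1), use_dynamo_custom_op=True)
    return out


device = xm.xla_device()
linear = torch.nn.Linear(8, 8).to(device)
compiled = torch.compile(projected_forward, backend="openxla")

x = torch.randn(4, 8, device=device)
out = compiled(linear, x)
```

As the description notes, only the functionality is exercised here; which mesh shape and partition spec give the best performance is left open.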