Implement mark_sharding as a custom dynamo op #6524
Conversation
I wonder why we want to implement `mark_sharding` as a custom dynamo op.
This is a follow-up PR to #6161, which originally aimed to support Dynamo + SPMD activation sharding (i.e. allow `mark_sharding` to be captured in a Dynamo-traced graph).
To reviewers: I have some refactoring I want to do here based on my offline sync with Jack and Jiewen. Please don't review this yet; I'll tag you here once it's ready for review (ETA tomorrow). Thanks!
Based on our conversation yesterday, we ideally want to keep the same existing `mark_sharding` API. cc @JackCaoG
By adding the new custom op, we hit an error. Full error logs are at https://gist.github.com/wonjoolee95/937575e2f5498a2179554974e0a52264. This error is similar to the one we saw in our prior attempt to make `mark_sharding` work with Dynamo.

Can you print the …?
Posting an update after an offline sync with Jack: the current issue under discussion is an error that is happening at exactly this line in the new test:
```python
mesh_shape = list(mesh.mesh_shape)
axis_names = str(mesh.axis_names)
partition_spec = '(1, 0)'
torch.ops.xla.dynamo_mark_sharding(linear.fc2.weight, device_ids,
                                   mesh_shape, axis_names, partition_spec)
```
Shouldn't `dynamo_mark_sharding` be inside the linear? What we are trying to verify is whether the custom op can be captured by Dynamo.
Yep, that is correct -- this was to see if the "normal" case can be successful with the new custom op. Added a new test to reflect that.
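For illustration only, a test along those lines might look like the minimal sketch below. This is not the PR's actual test; the module, shapes, mesh values, and the `openxla` backend name are assumptions.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_mark_sharding  # registers the custom op


class SimpleLinear(nn.Module):
  """Hypothetical module that shards its own activation inside forward."""

  def __init__(self):
    super().__init__()
    self.fc2 = nn.Linear(128, 64)

  def forward(self, x, device_ids, mesh_shape, axis_names, partition_spec):
    y = self.fc2(x)
    # Dynamo has to capture this call as part of the traced graph.
    torch.ops.xla.dynamo_mark_sharding(y, device_ids, mesh_shape, axis_names,
                                       partition_spec)
    return y


device = xm.xla_device()
model = SimpleLinear().to(device)
compiled = torch.compile(model, backend='openxla')
x = torch.randn(8, 128, device=device)
# 4 devices arranged in a 2x2 mesh; axis_names and partition_spec are
# string-encoded, matching the style in the test snippet above.
out = compiled(x, [0, 1, 2, 3], [2, 2], "('x', 'y')", '(0, 1)')
```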
Due to Dynamo capturing a Python native list type as an internal immutable list type, the original approach breaks. I can follow up with PyTorch to see if we can do something to get this resolved, but for now, we can stick to having a separate custom op to enable Dynamo mark sharding. With these changes, I can confirm that LLaMa 2 Dynamo + SPMD with activation sharding is passing: pytorch-tpu/llama#53. @JackCaoG, this should be ready to be reviewed now. Thanks!
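For context, a minimal sketch of registering a string-typed custom op in this style, under a hypothetical `mylib` namespace rather than the PR's actual `xla` registration code:

```python
import ast

import torch
from torch.library import Library, impl

# Hypothetical library namespace; the PR registers under the "xla" namespace.
MY_LIB = Library("mylib", "DEF")

# Types that the schema language can't express (tuples of mixed types) are
# encoded as plain strings and decoded inside the op implementation.
MY_LIB.define(
    "demo_mark_sharding(Tensor input, int[] device_ids, int[] mesh_shape, "
    "str axis_names, str partition_spec) -> ()")


@impl(MY_LIB, "demo_mark_sharding", "CompositeExplicitAutograd")
def demo_mark_sharding(input, device_ids, mesh_shape, axis_names,
                       partition_spec):
  # ast.literal_eval safely recovers the original Python objects.
  axis_names = ast.literal_eval(axis_names)
  partition_spec = ast.literal_eval(partition_spec)
  # A real implementation would build a Mesh and call xs.mark_sharding here.
```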
```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met
import torch_xla.experimental.dynamo_mark_sharding
```
We can move this under `torch_xla.distributed.spmd`.
Thanks for the review, Yeounoh. I'll go ahead and merge this for now and open a follow-up PR soon.
Follow-up to #6161, this PR implements `mark_sharding` as a custom dynamo op.
The newly introduced custom op's function signature is:
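(The exact schema string is elided from this page; the following is a reconstruction inferred from the arguments used in the test snippet above, so treat it as an assumption rather than the exact registered schema.)

```
dynamo_mark_sharding(Tensor input, int[] device_ids, int[] mesh_shape, str axis_names, str partition_spec) -> ()
```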
As per https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md, the variable types that PyTorch accepts when registering a function are very limited. It does not recognize `Tuple`, `str[]`, or any of the custom types that we want for `mark_sharding`, such as a partition spec of type `Tuple[Union[Tuple, int, str, None]]`. So when we define our custom op `dynamo_mark_sharding`, we had to make it a little hacky: we pass each such value as a string, then use Python's `ast.literal_eval` (which is an injection-safe `eval`) to convert the string version of the variables back into their "real" types. An example of how `ast.literal_eval` works is shown below.

Companion PR in LLaMa 2: pytorch-tpu/llama#53
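To make the round-trip concrete, here is a small self-contained example; the specific values are illustrative:

```python
import ast

# String-encoded inputs, as passed across the custom op boundary...
partition_spec_str = '(1, 0)'
axis_names_str = "('x', 'y')"

# ...decoded back into real Python objects on the other side.
partition_spec = ast.literal_eval(partition_spec_str)  # -> (1, 0)
axis_names = ast.literal_eval(axis_names_str)          # -> ('x', 'y')
assert partition_spec == (1, 0) and axis_names == ('x', 'y')

# Unlike eval, literal_eval only accepts Python literals, so arbitrary
# code in the string raises an exception instead of executing:
# ast.literal_eval("__import__('os')")  # raises ValueError
```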
TODO