[WIP] Make some hard-coded changes for activation sharding #6161
Conversation
Offline sync with @JackCaoG (a hedged caching sketch follows this list):

- `np.unique` -> write it in pure Python? Or just use `torch.unique`? Same for `all`.
- `requires_pjrt` -> cache the env var once in the runtime init.
- Is `num_devices` only used for the assertion? If so, can we skip this check in the Dynamo code flow only?
- `global_runtime_device_count` -- can we also cache this?
- To check for correctness, look at the HLO sharding.
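For reference, a minimal sketch of the ideas above (pure-Python `unique`, caching the env-var check and the device count); `unique_ordered`, `using_pjrt`, and `cached_global_device_count` are hypothetical helpers, not the actual torch_xla code:

```python
import functools
import os


def unique_ordered(values):
    """Pure-Python stand-in for np.unique that keeps first-seen order
    (np.unique also sorts; use sorted(set(values)) if that matters)."""
    return list(dict.fromkeys(values))


@functools.lru_cache(maxsize=None)
def using_pjrt() -> bool:
    # Hypothetical helper: read the env var once and cache the answer
    # instead of re-checking it on every decorated call.
    return bool(os.environ.get("PJRT_DEVICE"))


@functools.lru_cache(maxsize=None)
def cached_global_device_count() -> int:
    import torch_xla.runtime as xr
    # Cache the runtime query so repeated calls do not hit the runtime
    # (and Dynamo does not have to trace through it) each time.
    return xr.global_runtime_device_count()
```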
TODOs
- Make these changes and update the PR (by Friday or Monday), making sure they do not break the existing SPMD implementation or the LLaMa 2 inference run (w/ activation sharding).
- Then check the correctness throughout next week.
cc @yeounoh
While I was making some of the mentioned fixes and re-running the tests and LLaMa 2, I was seeing a bunch of complaints from Dynamo saying it can't track much of the portion in […]

And also: […]

Most importantly, I saw this error: […]

This made me doubt that the tracing of the Dynamo custom op is happening correctly. After some digging, I believe the correct way of invoking this custom Dynamo op should be directly through the […]

But now we need to make some bigger changes in the […]

I'll continue to look into it tomorrow.
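For context, a hedged sketch of what invoking a custom op "directly through" the `torch.ops` namespace generally looks like in PyTorch; the `xla_demo` namespace and `mark_sharding_demo` op below are made up for illustration and are not the actual op from this PR:

```python
import torch
from torch.library import Library

# Register a custom op so it can be called as torch.ops.<namespace>.<op>,
# which is the form Dynamo expects to trace, rather than calling the Python
# wrapper function directly.
demo_lib = Library("xla_demo", "DEF")
demo_lib.define("mark_sharding_demo(Tensor t) -> Tensor")


def _mark_sharding_demo_impl(t):
    # Placeholder body; the real op would attach sharding annotations.
    return t.clone()


demo_lib.impl("mark_sharding_demo", _mark_sharding_demo_impl,
              "CompositeExplicitAutograd")

x = torch.randn(4, 4)
# Invoke via the registered op, not the Python helper.
y = torch.ops.xla_demo.mark_sharding_demo(x)
```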
I started to wonder if it would be easier to just let Dynamo treat the […]
That could be one option -- we would still need to figure out how to deal with the return type of […]
XLAShardedTensor is already a torch.Tensor, which is the same mechanism DTensor anchors on to be compatible with Dynamo. So I would expect, or at least hope, that it gets captured and treated as a regular tensor.
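Not the real `XLAShardedTensor`, just a minimal sketch of the subclass property being referred to (a `torch.Tensor` subclass is still a `torch.Tensor` instance, which is what DTensor relies on); whether Dynamo captures it without graph breaks is exactly the open question here:

```python
import torch


class DemoShardedTensor(torch.Tensor):
    # Hypothetical stand-in for XLAShardedTensor: wraps an existing tensor
    # without copying storage, so isinstance(t, torch.Tensor) stays True.
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_subclass(cls, elem)


t = DemoShardedTensor(torch.randn(2, 2))
print(isinstance(t, torch.Tensor))  # True -- type-wise it is a regular tensor
```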
Coming from offline sync with @yeounoh:
Opened an issue upstream at pytorch/pytorch#115970. I'll also reach out in PyTorch Slack on Monday.
This is completed with #6524, closing.