Support collective matmul optimization in mp #8855
Conversation
tengyifei left a comment:
question: do you think it's possible to test that collective matmul is actually triggered? Maybe we could look for the "decomposed_reduce_scatter_while" etc. signatures in the optimized HLO from the XLA dump?
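A rough sketch of what such a check might look like, assuming the dump is produced with `XLA_FLAGS=--xla_dump_to=<dir>`; the file-name pattern and signature string come from this discussion and are not verified against a specific XLA version:

```python
import glob
import os

def collective_matmul_triggered(dump_dir: str) -> bool:
    """Scan an XLA dump directory for the collective-matmul loop signature."""
    # XLA conventionally names post-optimization HLO dumps *after_optimizations*.
    pattern = os.path.join(dump_dir, "*after_optimizations*.txt")
    for path in glob.glob(pattern):
        with open(path) as f:
            if "decomposed_reduce_scatter_while" in f.read():
                return True
    return False
```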
Thanks for your review, Yifei! I need to wait for tomorrow's libtpu nightly to enable collective matmul.
(force-pushed from 150779b to 676f5b8)
Gotcha. I think it's useful to add a test once we roll in a new libtpu. Otherwise it's hard to tell if this worked. LGTM otherwise.
The problem is that collective-matmul-related operations only appear in the HLO after the TPU compiler's optimization passes, so we cannot use the existing utility functions to check for them.
@tengyifei this was hit in the CI test. What I have updated today is the libtpu version. Is there anything else I should do?
```python
_TPU_V5P = "v5p"
_TPU_V6E = "v6e"
```
Drive-by question: is there a way to do this without hard-coding for the different chips? Is it possible to get the information for each `<chip>_create_replica_groups` case in a way that doesn't require hard coding for each case?
It's a very good question, @pgmoka. JAX tried to do this but could not find the best ring order automatically; it also keeps hard-coding for different chips: https://github.com/jax-ml/jax/blob/39e8ee93b015372049c833bf45f49105026d9e8a/jax/_src/mesh_utils.py#L31-L36.
Ideally, finding the best communication order should be the XLA compiler's responsibility. I was told the XLA compiler team is working on it, so I decided not to spend too much time on that.
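For illustration, a minimal sketch of the hard-coded per-chip dispatch this thread is describing; the ring orders below are placeholders, not the actual tuned orders from this PR:

```python
_TPU_V5P = "v5p"
_TPU_V6E = "v6e"

def create_replica_groups(tpu_version: str, device_ids: list) -> list:
    # Each chip generation gets a hand-tuned ring order; the two orders
    # below are illustrative placeholders only.
    if tpu_version == _TPU_V5P:
        return [list(device_ids)]
    if tpu_version == _TPU_V6E:
        return [list(reversed(device_ids))]
    raise ValueError(f"no hand-tuned ring order for {tpu_version}")
```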
Gotcha. That makes sense. Is there perhaps a GitHub Issue we can add here for future tracking?
Usually that means the C++ ABI versions of PyTorch and PyTorch/XLA mismatch. I have no idea why this could happen, though; I think both libraries should use the C++11 ABI now.
It could be a bad commit merged into PyTorch master. I just hit retry.
Oh, we'd need to update the libtpu minor version as well as JAX. Otherwise, pip package resolution falls back to a stable version of torch_xla which is incompatible with PyTorch nightly.
Pushed an update.
Ah! There's a trick. We could run the test with the … I was hoping to add a small example, but I failed to run your test locally with this error :/
But the problem is that there might be multiple dumped after-optimization HLO graphs, and it's usually difficult to know which one is the correct one. It would be much better if we could get it in a programmatic way.
Thanks, I will take a look at this. It's weird to me because I didn't change all-reduce.
Yes, in fact JAX supports it. I think it may require some compilation-option plumbing. But I was also hoping that clearing the dump dir could work out. Let me rebuild things to be sure.
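A sketch of the clear-the-dump-dir idea: use a fresh directory per run so every dumped graph found afterwards belongs to this run. This assumes `XLA_FLAGS` is set before the XLA runtime initializes:

```python
import os
import tempfile

# A fresh dump directory per test run means every *after_optimizations*
# file found afterwards belongs to this run. Set XLA_FLAGS before any
# torch_xla import so the runtime picks it up.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"

# ... import torch_xla and run the collective test here ...
```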
Ok, another problem is that JAX needs to be updated to …
Ok, I've also discovered that the … When I tried running the … For now I've added these two tests to the TPU CI.
I put the collective matmul test in a standalone file.
I suggest we address this in another PR.
Ok, with the latest commit I was able to run the test code with …
I'm sorry, the name of the test might not be appropriate. Here I only intend to test the collectives with the newly added parameters, which are necessary for collective matmul. To enable collective matmul, we would need to add a matmul with large enough shapes to the test. You can find such a test here: https://github.com/pytorch/xla/blob/chengji/cm/test/torch_distributed/cm_perf.py. I didn't add the real collective matmul in this PR because it's an optimization of the underlying XLA compiler.
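Sketched below is the kind of test being described: an all-gather feeding a large matmul, which gives the compiler communication and computation to overlap. The shapes and structure are assumptions for illustration, not the contents of `cm_perf.py`:

```python
import torch
import torch_xla.core.xla_model as xm

def all_gather_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # All-gather the sharded activation along dim 0, then matmul with the
    # local weight; with collective matmul enabled, XLA can overlap the
    # communication with the computation.
    gathered = xm.all_gather(x, dim=0)
    return gathered @ w
```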
I see. Do you think there are downsides to adding a large enough matmul in the …? Another approach could be testing the graph execution time of …
Because the reference implementation usually runs on CPU; if the matmul is too large, the CI test might be too slow.
Oh, that's a good point. I think it's still useful to add a performance test, but feel free to do it in a follow-up PR.
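For such a follow-up performance test, something along these lines could compare average step time with the optimization on vs. off; the helper below is hypothetical, not part of this PR:

```python
import time
import torch_xla.core.xla_model as xm

def timed_steps(step_fn, iters: int = 10) -> float:
    # Average wall-clock time per step: mark_step flushes the traced graph,
    # wait_device_ops blocks until the device finishes all pending work.
    xm.wait_device_ops()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
        xm.mark_step()
    xm.wait_device_ops()
    return (time.perf_counter() - start) / iters
```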
@tengyifei I forgot to say we need an extra environment variable to run that example.
This PR:
- adds `ENABLE_COLLECTIVE_MATMUL_IN_MP` to turn on the XLA config which is required for the collective matmul optimization on TPU
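A minimal usage sketch; the value `"1"` and the need to set the variable before torch_xla initializes are assumptions, not confirmed by the PR:

```python
import os

# Assumed usage: set before torch_xla initializes the runtime. The accepted
# values of this variable are an assumption here.
os.environ["ENABLE_COLLECTIVE_MATMUL_IN_MP"] = "1"

import torch_xla.core.xla_model as xm  # import after setting the env var
```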