Support collective matmul optimization in mp #8855

Merged
yaochengji merged 10 commits into master from chengji/enable-cm on Mar 21, 2025

Conversation

@yaochengji (Collaborator)

This PR:

  • adds the new environment variable ENABLE_COLLECTIVE_MATMUL_IN_MP to turn on the XLA config required for the collective matmul optimization on TPU
  • adds channel_id and use_global_device_ids to xm.all_gather and xm.reduce_scatter, which are also required to enable collective matmul (see the usage sketch below)
  • adds a utility function to find the best ring order for v5p x 8 and v6e x 8
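
A minimal usage sketch, not taken from this PR's code: the channel_id and use_global_device_ids keyword names come from the description above, while the env-var placement, shapes, and per-process wiring are illustrative assumptions.

# Hedged sketch only; the keyword names below are taken from this PR's
# description, everything else (shapes, env-var placement) is illustrative.
import os
os.environ["ENABLE_COLLECTIVE_MATMUL_IN_MP"] = "1"  # turn on the XLA config

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def _mp_fn(index):
  device = xm.xla_device()
  x = torch.randn(128, 128, device=device)
  # channel_id / use_global_device_ids tag the collective with a global
  # channel, which the TPU compiler needs to consider collective matmul.
  gathered = xm.all_gather(
      x, dim=0,
      channel_id=1,                 # assumed keyword added by this PR
      use_global_device_ids=True)   # assumed keyword added by this PR
  torch_xla.sync()
  return gathered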

@tengyifei (Collaborator) left a comment

question: do you think it's possible to test that collective matmul is actually triggered? maybe we could look for the "decomposed_reduce_scatter_while" etc signatures in the optimized HLO from the XLA dump?

@yaochengji (Collaborator, Author)

question: do you think it's possible to test that collective matmul is actually triggered? maybe we could look for the "decomposed_reduce_scatter_while" etc signatures in the optimized HLO from the XLA dump?

Thanks for your review, Yifei! I need to wait for tomorrow's libtpu nightly to enable collective matmul.

@tengyifei (Collaborator)

I need to wait for tomorrow's libtpu nightly to enable collective matmul.

Gotcha. I think it's useful to add a test once we roll in a new libtpu. Otherwise it's hard to tell if this worked. LGTM otherwise.

@yaochengji (Collaborator, Author)

Gotcha. I think it's useful to add a test once we roll in a new libtpu. Otherwise it's hard to tell if this worked. LGTM otherwise.

The problem is that collective-matmul-related operations only appear after the TPU compiler's optimization passes. We cannot use the utility function torch_xla._XLAC._get_xla_tensors_hlo the way other tests do.

@yaochengji (Collaborator, Author) commented Mar 20, 2025

@tengyifei I hit

ImportError: /home/runner/.local/lib/python3.10/site-packages/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN5torch4lazy13MetricFnValueEd

in the CI test. What I updated today is the libtpu version. Is there anything else I should do?

Comment on lines +134 to +135
_TPU_V5P = "v5p"
_TPU_V6E = "v6e"
@pgmoka (Collaborator)

Drive-by question: is there a way to do this without hard coding for the different chips? Is it possible to get the information for each <chip>_create_replica_groups case in a way that doesn't require hard coding each case?

@yaochengji (Collaborator, Author)

It's a very good question, @pgmoka. JAX tried to do this but could not find the best ring order automatically, so it also keeps hard-coded orderings for different chips: https://github.com/jax-ml/jax/blob/39e8ee93b015372049c833bf45f49105026d9e8a/jax/_src/mesh_utils.py#L31-L36.

Ideally, finding the best communication order should be the XLA compiler's responsibility. I was told the XLA compiler team is working on it, so I decided not to spend too much time on that.
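
For illustration, a hypothetical shape of such a hard-coded helper (the constants mirror the _TPU_V5P / _TPU_V6E strings in the diff above; the orderings are placeholders, not the tuned ring orders this PR actually uses):

# Hypothetical sketch; the device orderings below are placeholders and do
# not reproduce the ring orders from this PR.
_TPU_V5P = "v5p"
_TPU_V6E = "v6e"

_RING_ORDER = {
    _TPU_V5P: (0, 1, 2, 3, 4, 5, 6, 7),  # placeholder ordering
    _TPU_V6E: (0, 1, 2, 3, 4, 5, 6, 7),  # placeholder ordering
}

def create_replica_groups(chip, num_devices=8):
  """Return one replica group following the per-chip ring order, if known."""
  order = _RING_ORDER.get(chip)
  if order is None or len(order) != num_devices:
    return [list(range(num_devices))]  # fall back to the default order
  return [list(order)]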

Collaborator

Gotcha. That makes sense. Is there perhaps a GitHub Issue we can add here for future tracking?

@tengyifei (Collaborator)

in the CI test. What I updated today is the libtpu version. Is there anything else I should do?

Usually that means the C++ ABI versions of PyTorch and PyTorch/XLA mismatch. I have no idea why this could happen, though. I think both libraries should use the C++11 ABI now.

@tengyifei (Collaborator)

It could be a bad commit merged into PyTorch master.

I just hit retry.

@tengyifei (Collaborator)

Oh, we'd need to update the libtpu minor version as well as JAX. Otherwise the pip package resolution falls back to a stable version of torch_xla, which is incompatible with PyTorch nightly:

-_libtpu_version = '0.0.11'
-_jax_version = '0.5.2'
-_jaxlib_version = '0.5.2'
+_libtpu_version = '0.0.12'
+_jax_version = '0.5.3'
+_jaxlib_version = '0.5.3'

@tengyifei (Collaborator)

Pushed an update

@tengyifei (Collaborator)

The problem is that collective-matmul-related operations only appear after the TPU compiler's optimization passes. We cannot use the utility function torch_xla._XLAC._get_xla_tensors_hlo the way other tests do.

Ah! There's a trick. We could run the test with the XLA_FLAGS="--xla_dump_to=SOME_TEMP_PATH" env var, and then the optimized HLO of each graph will be saved there. The test itself could clear the dump dir before running the graph involving the collective matmul, then call torch_xla.sync() (which should compile synchronously), and then check the contents of the dump directory. There would be a "[....].after_optimizations.hlo.txt" file which contains the post-optimization HLO text; c.f. http://shortn/_4vbbWu9SA1
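
A rough, untested sketch of that check (the dump path, file-name glob, and signature string are assumptions based on this thread):

# Untested sketch of the dump-dir check described above; paths and the
# signature string are assumptions from this thread.
import glob
import os
import shutil

import torch_xla

DUMP_DIR = "/tmp/xla_dump"  # must match XLA_FLAGS="--xla_dump_to=/tmp/xla_dump"

def assert_collective_matmul_triggered(run_graph):
  # Clear the dump dir so only HLO from the graph under test is present.
  shutil.rmtree(DUMP_DIR, ignore_errors=True)
  os.makedirs(DUMP_DIR, exist_ok=True)

  run_graph()
  torch_xla.sync()  # compiles (and therefore dumps) the graph synchronously

  hlo_text = ""
  for path in glob.glob(os.path.join(DUMP_DIR, "*after_optimizations*.txt")):
    with open(path) as f:
      hlo_text += f.read()
  # "decomposed_reduce_scatter_while" is the signature mentioned earlier
  # in this thread.
  assert "decomposed_reduce_scatter_while" in hlo_text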

I was hoping to add a small example along those lines, but I failed to run your test locally with this error :/

RuntimeError: Bad StatusOr access: FAILED_PRECONDITION: during context [Unknown]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.

Failed after pipeline-start
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspaces/torch/pytorch/xla/test/test_mp_all_gather.py", line 181, in <module>
    torch_xla.launch(_mp_fn, args=())
  File "/usr/local/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 245, in launch
    xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 43, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: FAILED_PRECONDITION: during context [Unknown]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.

@yaochengji (Collaborator, Author)

http://shortn/_4vbbWu9SA1

But the problem is that there might be multiple dumped after-opt HLO graphs, and it's usually difficult to know which is the correct one. It would be much better if we could get it in a programmatic way.

@yaochengji (Collaborator, Author)

I was hoping to add a small example but I failed to run your test locally with this error

Thanks, I will take a look at this. It's weird to me because I didn't change all-reduce.

@tengyifei (Collaborator)

It would be much better if we could get it in a programmatic way.

Yes, in fact JAX supports it. I think it may require some compilation-options plumbing. But I was also hoping that clearing the dump dir could work out.

Let me rebuild things to be sure...

@tengyifei (Collaborator)

Ok, another problem is that JAX needs to be updated to 0.5.4. http://shortn/_V0ooFDrrSS has all the pin update steps.

@tengyifei (Collaborator)

Ok, I've also discovered that the tests/run_tests.sh in https://github.com/pytorch/xla/pull/8855/files#diff-f10e8a68326c09caad4d8f8f3c2479deb3d5e26bb6d3e23e9e038daeb5ee107b isn't run in TPU CI; it's only run in CPU and GPU CI. Therefore, we're not testing this mp stuff on TPU at all.

When I tried running the test_mp_reduce_scatter.py test on a TPU v6e-8, it failed with

F0320 22:59:52.430697  201496 tpu_layout_assignment.cc:517] Check failed: shape.IsArray()

For now I've added these two tests to the TPU CI run_test.sh and we can see whether they fail. Unfortunately I don't have the expertise to debug this error.

@yaochengji (Collaborator, Author)

I put the collective matmul test in a standalone file.

isn't run in TPU CI; it's only run in CPU and GPU CI. Therefore, we're not testing this mp stuff on TPU at all.

I suggest we address this in another PR.

@tengyifei (Collaborator)

Ok, with the latest commit I was able to run the test code with XLA_FLAGS='--xla_dump_to=/tmp/tmpsy8qbu6i' python3 test_mp_collective_matmul.py, and it created some optimized HLOs. However, I can't find a while op in the HLO. I only found all-reduce and reduce-scatter. This could mean that the collective matmul wasn't activated.

@yaochengji (Collaborator, Author)

However, I can't find a while op in the HLO. I only found all-reduce and reduce-scatter. This could mean that the collective matmul wasn't activated.

I'm sorry, the name of the test might not be appropriate. Here I only intend to test the collectives with the newly added parameters, which are necessary for collective matmul. To actually enable collective matmul, we would need to add a matmul with large enough shapes to the test.

You might find the appropriate test here: https://github.com/pytorch/xla/blob/chengji/cm/test/torch_distributed/cm_perf.py.

I didn't add a real collective matmul in this PR because I consider it an optimization of the underlying XLA compiler.
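
For reference, a minimal sketch of the kind of pattern that could trigger the optimization (this is not cm_perf.py; shapes and names are illustrative, and the new keywords are the ones added by this PR):

# Illustrative sketch, not cm_perf.py. With large enough shapes, the TPU
# compiler may rewrite this all-gather + matmul into a collective matmul.
import torch
import torch_xla.core.xla_model as xm

def gathered_matmul(x_shard, w):
  # x_shard: per-device shard of the activations; w: replicated weight.
  x = xm.all_gather(
      x_shard, dim=0,
      channel_id=1,                 # assumed keyword added by this PR
      use_global_device_ids=True)   # assumed keyword added by this PR
  return x @ w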

@tengyifei (Collaborator)

I see. Do you think there are downsides to adding a large enough matmul in test_mp_collective_matmul.py and checking that the post-optimization HLO indeed contains a decomposed reduce-scatter? The benefit would be to catch, e.g., a bad libtpu roll that incorrectly disables the optimization. Or maybe you're worried that we'd be testing an implementation detail?

Another approach could be testing the graph execution time of cm_perf.py, which should be pretty consistent on the v4-8 CI (I suppose it could only check the running time if it detects it's running on v4-8). LMK if you'd like to do this in a follow-up PR.

@yaochengji (Collaborator, Author)

Do you think there are downsides to adding a large enough matmul

Because the reference implementation usually runs on CPU, the CI test might be too slow if the matmul is too large.

@tengyifei (Collaborator)

Oh, that's a good point. I think it's still useful to add a performance test. But feel free to do it in a follow up PR.

yaochengji merged commit 2ad5b5d into master on Mar 21, 2025
@yaochengji (Collaborator, Author)

@tengyifei I forgot to say that we need extra environment variables to run that example:

ENABLE_COLLECTIVE_MATMUL_IN_MP=1 DISABLE_NUMERIC_CC_TOKEN=1 python test/torch_distributed/cm_perf.py
