[SymmMem] Multi-root tile reduction #164757

kwen2501 wants to merge 2 commits into gh/kwen2501/273/base from
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164757
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit e17e8b3 with merge base a707042. This comment was automatically generated by Dr. CI and updates every 15 minutes.
weifengpy left a comment:

Looking good on the UX part.
```cpp
    root = rank;
  }
  i++;
}
```
Should we check that root != world_size here?
This implementation uses root == world_size to indicate that current rank does not need to reduce any tile. (Yet it still calls into this API to fulfill the collective requirement). You can see the "Note" above.
we don't have a test for it though (root==world_size), do we?
In test_multi_root_tile_reduce, when root_ratio is 2, we will exercise this case.
root_ratio=2 means only half of the ranks are roots; the rest of the ranks will pass root==world_size here to skip the reduction.
```cpp
- `reduce_op` is the reduction operation to perform. Currently only "sum" is supported.
*/
TORCH_CHECK(reduce_op == "sum", "tile_reduce: only sum is supported for now");
TORCH_CHECK(out_tile.dtype() == at::kFloat, "Only float is supported");
```
can you support at least BFloat16 also?
```cpp
for (auto& in_tile : in_tiles) {
  TORCH_CHECK(in_tile.dtype() == at::kFloat, "Only float is supported");
  c10d::symmetric_memory::rendezvous(in_tile, group_name);
  if (roots[i] == rank) {
```
we should check that roots[i] is valid (>=0 and < world_size)
```cpp
int nblocks = at::ceil_div(
    out_tile.numel() * out_tile.element_size(),
    (int64_t)THREADS_PER_BLOCK * 16);
nblocks = std::min(nblocks, 16);
```
Why limit it to 16? I think for the CUDA backend we limit at 24 at least; maybe we even need more for Blackwell.
|
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command
Details for Dev Infra team: raised by workflow job.
|
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):

Perform multiple tile reductions concurrently, with each tile reduced to a separate root.

- The number of concurrent reductions can be smaller than the world size, i.e. roots can be a subset of all ranks. But all ranks are still required to call into this API.
- Currently supports NVLink SHARP scope only.

Pull Request resolved: pytorch#164757
Approved by: https://github.com/weifengpy, https://github.com/fegin
ghstack dependencies: pytorch#162243
|
In retrospect, it would have been great if the test could be guarded with: |
|
Were those large precision errors, as in grossly incorrect results, or some accuracy mismatch? If the latter, feel free to submit PRs guarding the tests; if the former, can we disable those ops on hardware that doesn't support it? |
Below is an example of a mismatch:

Mismatched elements: 16384 / 4194304 (0.4%)

To execute this test, run the following from the base repo dir:

In certain cases, these test cases caused other problems on non-NVL platforms, which I'm still looking into. |
|
FYI: We also noticed that certain H100 systems that do not support NVLink SHARP hit an illegal memory access when running the multi_root tile unit test. |
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci