I am trying to implement a one-shot allreduce in sglang. You can see my code in this PR.
I want to use the algorithm from allreduce_bench.py. To fit everything into sgl-kernel, I rewrote the API in C++. My version passes the correctness test, but its performance is much worse than what allreduce_bench.py achieves.
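For reference, here is a minimal CPU simulation of the data flow in a generic one-shot allreduce: every rank reads the full input buffer of every peer and reduces it locally in a single step. This is a sketch under that assumption only. It is not the CUDA kernel from allreduce_bench.py or from my PR, and `one_shot_allreduce` and `peer_elements_read_per_rank` are hypothetical names used for illustration.

```python
from typing import List


def one_shot_allreduce(buffers: List[List[float]]) -> List[List[float]]:
    """Simulate a generic one-shot allreduce across len(buffers) ranks.

    Assumption: each rank can directly read every peer's buffer (e.g. via
    peer-to-peer access), so the whole reduction happens in one step and
    every rank computes the full sum on its own.
    """
    world_size = len(buffers)
    n = len(buffers[0])
    assert all(len(b) == n for b in buffers), "all ranks need equal-size buffers"

    outputs = []
    for _rank in range(world_size):
        # Each rank independently sums the data of all ranks, including its own.
        out = [0.0] * n
        for peer in range(world_size):
            src = buffers[peer]
            for i in range(n):
                out[i] += src[i]
        outputs.append(out)
    return outputs


def peer_elements_read_per_rank(world_size: int, n: int) -> int:
    """Elements each rank reads from its peers in a one-shot allreduce.

    This grows as (world_size - 1) * n, which is why one-shot tends to pay
    off for small messages but falls behind bandwidth-optimal algorithms
    (e.g. reduce-scatter + all-gather) as messages grow.
    """
    return (world_size - 1) * n


if __name__ == "__main__":
    ranks = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    print(one_shot_allreduce(ranks))  # -> [[9.0, 12.0], [9.0, 12.0], [9.0, 12.0]]
    print(peer_elements_read_per_rank(8, 1024))  # -> 7168
```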
Here is the performance table on A100:
| msg_size | torch eager time | msccl eager time | msccl graph time | pynccl graph time |
|---|---|---|---|---|
| 2.0 KiB | 90.7264 | 46.592 | 24.2074 | 44.2675 |
| 4.0 KiB | 47.616 | 29.5936 | 22.8352 | 44.6669 |
| 8.0 KiB | 47.8208 | 30.72 | 23.6544 | 47.4112 |
| 16.0 KiB | 49.664 | 32.256 | 25.6717 | 48.0051 |
| 32.0 KiB | 47.8208 | 35.1232 | 29.5014 | 49.0291 |
| 64.0 KiB | 52.4288 | 42.0864 | 36.567 | 55.5418 |
| 128.0 KiB | 66.6624 | 57.856 | 51.3843 | 66.6214 |
| 256.0 KiB | 81.92 | 86.528 | 79.1142 | 83.2717 |
| 512.0 KiB | 106.291 | 133.734 | 132.27 | 93.9725 |
| 1.0 MiB | 117.965 | 264.704 | 236.073 | 121.631 |
I got this table by running allreduce_bench.py. Is there anything I should know to bring my C++ version up to this performance?
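To make the table easier to compare, the small Python sketch below computes, for each message size, how much faster the msccl graph path is than the pynccl graph path, and the first size at which pynccl graph wins. The numbers are copied from the table; `ROWS`, `speedups`, and `crossover_kib` are hypothetical names for this illustration, and times are left unitless because the table does not state a unit.

```python
# (msg_size in KiB, msccl graph time, pynccl graph time), copied from the table.
ROWS = [
    (2, 24.2074, 44.2675),
    (4, 22.8352, 44.6669),
    (8, 23.6544, 47.4112),
    (16, 25.6717, 48.0051),
    (32, 29.5014, 49.0291),
    (64, 36.567, 55.5418),
    (128, 51.3843, 66.6214),
    (256, 79.1142, 83.2717),
    (512, 132.27, 93.9725),
    (1024, 236.073, 121.631),
]


def speedups(rows):
    """Return (size_kib, pynccl_time / msccl_time); values > 1 mean msccl graph is faster."""
    return [(size, nccl / msccl) for size, msccl, nccl in rows]


def crossover_kib(rows):
    """Return the smallest size (KiB) where pynccl graph beats msccl graph, or None."""
    for size, msccl, nccl in rows:
        if nccl < msccl:
            return size
    return None


if __name__ == "__main__":
    for size, ratio in speedups(ROWS):
        print(f"{size:>5} KiB  msccl graph speedup over pynccl graph: {ratio:.2f}x")
    print("pynccl graph is faster starting at", crossover_kib(ROWS), "KiB")
```

Read this way, msccl graph is ahead up to 256 KiB and pynccl graph is ahead from 512 KiB on, which is the pattern one would expect from a one-shot algorithm.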