
Refactor allreduce for supporting prefill case #2453

Closed

TennyWang1223 wants to merge 26 commits into main from refactor_ar


Conversation

@TennyWang1223
Contributor

Motivation

Refactor the custom allreduce implementation to decouple its C++ layer from PyTorch and its Python-side IPC exchange from RCCL/gloo, making the module more portable and self-contained. Additionally, increase the max buffer size to support prefill workloads with larger tensors.

Technical Details

1. IPC buffer management refactoring
Introduce IPCBuffer and IPCBufferPool classes to encapsulate the IPC buffer lifecycle. IPCBuffer abstracts over two allocation modes: uncached (hipMalloc) for synchronization metadata and cached (torch.empty) for D2D relay. IPCBufferPool manages named buffers and provides IPC handle exchange for both eager mode (pre-registered buffers) and graph mode (dynamically captured addresses). See the IPCBuffer sketch after this list.
2. Decouple C++ layer from torch::Tensor
All C++ interfaces in custom_all_reduce.cu, .cuh, and .h now take raw pointers (int64_t / void*), element counts, dtype codes, and an explicit hipStream_t instead of torch::Tensor parameters and return values, so the C++ code compiles without linking libtorch. The Python layer extracts these primitives via tensor.data_ptr(), tensor.numel(), tensor.dtype, and torch.cuda.current_stream().cuda_stream before calling into C++, and the _is_weak_contiguous check also moves to the Python side (see the raw-pointer call sketch after this list).
3. Replace RCCL/gloo-based IPC handle broadcast with TCP store
IPCBufferPool._gather_ipc_meta now uses torch.distributed.TCPStore set/get (a pure-TCP key-value store) instead of dist.broadcast_object_list (which routes through the gloo collective backend). An assertion verifies that the underlying store is a TCPStore, ensuring no collective communication backend is involved, and store.get() blocks until the key is available, providing natural barrier semantics (see the TCPStore sketch after this list).
4. Increase max_size to support prefill
max_size is raised from 128 MB to 1 GB to accommodate prefill-stage tensor sizes.
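
For point 1 above, here is a minimal sketch of the buffer abstraction in Python. The class names IPCBuffer and IPCBufferPool come from this PR; the field names, method names, and the _alloc_uncached placeholder are illustrative assumptions, not the actual code in aiter/dist/device_communicators/custom_all_reduce.py.

```python
from dataclasses import dataclass
from typing import Dict, Optional

import torch


def _alloc_uncached(nbytes: int) -> int:
    # Placeholder: in the real module this would go through the C++ extension,
    # which allocates with hipMalloc (bypassing the caching allocator) and
    # returns a device pointer used for synchronization metadata.
    raise NotImplementedError("stand-in for the hipMalloc-backed allocation")


@dataclass
class IPCBuffer:
    name: str
    nbytes: int
    cached: bool  # True -> torch.empty (D2D relay), False -> hipMalloc (sync metadata)
    tensor: Optional[torch.Tensor] = None
    dev_ptr: int = 0

    def allocate(self) -> None:
        if self.cached:
            # Cached mode: backed by the PyTorch caching allocator.
            self.tensor = torch.empty(self.nbytes, dtype=torch.uint8, device="cuda")
            self.dev_ptr = self.tensor.data_ptr()
        else:
            # Uncached mode: raw device allocation outside the caching allocator.
            self.dev_ptr = _alloc_uncached(self.nbytes)


class IPCBufferPool:
    """Owns named IPC buffers and hands out their handles for exchange."""

    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.world_size = world_size
        self.buffers: Dict[str, IPCBuffer] = {}

    def create(self, name: str, nbytes: int, cached: bool) -> IPCBuffer:
        buf = IPCBuffer(name=name, nbytes=nbytes, cached=cached)
        buf.allocate()
        self.buffers[name] = buf
        return buf
```

In eager mode the pool would exchange handles for such pre-registered buffers once at initialization; in graph mode it would additionally collect addresses captured during HIP graph capture, as described above.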
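
For point 2, a sketch of the Python shim that unpacks tensors into the primitives the torch-free C++ layer now expects. The binding name _C.all_reduce_raw, its argument order, and the dtype-code mapping are hypothetical stand-ins for the real interface.

```python
import torch

_C = None  # placeholder for the compiled extension module exposing the raw-pointer entry points

# Assumed dtype encoding; the actual codes are defined by the C++ side.
_DTYPE_CODES = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}


def _is_weak_contiguous(t: torch.Tensor) -> bool:
    # Check moved from C++ to Python: the tensor must densely cover the tail of
    # its storage even if it is a view with a nonzero storage offset.
    return t.is_contiguous() or (
        t.untyped_storage().nbytes() - t.storage_offset() * t.element_size()
        == t.numel() * t.element_size()
    )


def custom_all_reduce(fa_ptr: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    assert _is_weak_contiguous(inp)
    _C.all_reduce_raw(                                # hypothetical binding name
        fa_ptr,                                       # opaque handle to the C++ allreduce object
        inp.data_ptr(),                               # input device pointer (int64_t in C++)
        out.data_ptr(),                               # output device pointer
        inp.numel(),                                  # element count instead of a torch::Tensor
        _DTYPE_CODES[inp.dtype],                      # dtype code instead of a ScalarType
        torch.cuda.current_stream().cuda_stream,      # raw hipStream_t passed as an integer
    )
```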
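
For point 3, a sketch of the TCPStore-based exchange: each rank publishes its serialized handles under a per-rank key and then blocks on get() for every other rank's key. The key layout, the pickling, and the function signature are assumptions rather than the exact _gather_ipc_meta implementation.

```python
import pickle
from typing import Any, List

import torch.distributed as dist


def gather_ipc_meta(store, rank: int, world_size: int, meta: Any, tag: str) -> List[Any]:
    # No collective backend involved: the exchange must go through a plain
    # TCP key-value store, not RCCL/gloo.
    assert isinstance(store, dist.TCPStore), "IPC handle exchange requires a TCPStore"

    store.set(f"{tag}_rank{rank}", pickle.dumps(meta))

    gathered = []
    for r in range(world_size):
        # store.get() blocks until rank r has published its key, which gives
        # the exchange natural barrier semantics.
        gathered.append(pickle.loads(store.get(f"{tag}_rank{r}")))
    return gathered
```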
Files changed (8 files, +1042 / -691):

  • csrc/kernels/custom_all_reduce.cu — full rewrite, torch-free implementation
  • csrc/include/custom_all_reduce.h — raw pointer interfaces
  • csrc/include/custom_all_reduce.cuh — remove transitive torch dependency
  • csrc/include/rocm_ops.hpp — update pybind macro signatures
  • csrc/pybind/custom_all_reduce_pybind.cu — adjust includes
  • aiter/ops/custom_all_reduce.py — Python op stubs with raw pointer types
  • aiter/dist/device_communicators/custom_all_reduce.py — IPCBuffer, IPCBufferPool, TCPStore exchange, increased max_size
  • op_tests/multigpu_tests/test_car_rccl_latency.py — latency comparison test

Test Plan

  • Run test_custom_allreduce.py on 8× MI355 GPUs with both eager and graph modes to verify correctness (fp16, bf16).
  • Run test_car_rccl_latency.py on 8× MI355 GPUs to compare latency against RCCL allreduce.

Test Result

Allreduce correctness tests pass on 8× MI355. Latency comparison with RCCL:

| Size   | Shape          | AITER (us) | RCCL (us) |
|--------|----------------|------------|-----------|
| 128 MB | (8192, 8192)   | 788.7      | 867.3     |
| 256 MB | (16384, 8192)  | 1472.8     | 1535.0    |
| 512 MB | (32768, 8192)  | 2841.8     | 2872.5    |
| 1 GB   | (65536, 8192)  | 5547.2     | 5545.2    |

AITER custom allreduce matches or outperforms RCCL across all tested sizes on MI355.
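
For reference, a minimal sketch (not the actual test_car_rccl_latency.py) of how per-call latencies like these can be measured with CUDA events inside an initialized 8-rank process group; the warmup and iteration counts are arbitrary.

```python
import torch
import torch.distributed as dist


def time_allreduce_us(fn, tensor, warmup: int = 5, iters: int = 20) -> float:
    # Time `iters` back-to-back calls of an allreduce callable with CUDA events
    # and return the average latency in microseconds.
    for _ in range(warmup):
        fn(tensor)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(tensor)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000.0 / iters


# Example: the RCCL baseline for the 128 MB row, an (8192, 8192) bf16 tensor.
# x = torch.randn(8192, 8192, dtype=torch.bfloat16, device="cuda")
# rccl_us = time_allreduce_us(lambda t: dist.all_reduce(t), x)
```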

Submission Checklist

root added 3 commits March 24, 2026 09:50
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
@TennyWang1223 TennyWang1223 requested review from a team and valarLip March 24, 2026 10:12
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label         | Tests                                           |
|---------------|-------------------------------------------------|
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325  |
| ci:sglang     | SGLang integration tests                        |
| ci:atom       | ATOM benchmark (DeepSeek-R1 + GPT-OSS)          |
| ci:vllm       | vLLM benchmark                                  |
| ci:all        | All of the above                                |

Add labels via the sidebar or gh pr edit 2453 --add-label <label>

root added 5 commits March 25, 2026 03:58
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
@TennyWang1223
Contributor Author

Support aiter tensor: the C++ interface now takes aiter tensors as input and output parameters instead of raw pointers; class pointers and IPCHandle pointers remain unchanged.

@TennyWang1223
Contributor Author

MI300 test result

| Size   | Shape          | AITER (us) | RCCL (us) |
|--------|----------------|------------|-----------|
| 128 MB | (8192, 8192)   | 974.4      | 841.2     |
| 256 MB | (16384, 8192)  | 1910.3     | 1598.7    |
| 512 MB | (32768, 8192)  | 3792.6     | 3132.2    |
| 1 GB   | (65536, 8192)  | 6120.2     | 6126.5    |

MI308 test result

| Size   | Shape          | AITER (us) | RCCL (us) |
|--------|----------------|------------|-----------|
| 128 MB | (8192, 8192)   | 1056.7     | 949.5     |
| 256 MB | (16384, 8192)  | 2059.3     | 1743.2    |
| 512 MB | (32768, 8192)  | 4075.7     | 3344.4    |
| 1 GB   | (65536, 8192)  | 6592.5     | 6598.0    |

It looks like the medium-sized cases still need optimization on gfx942.

@amd-ruitang3
Contributor

move "torch.tensor -> pybind aiter_tensor_t" to dtypes.py

TennyWang1223 and others added 10 commits March 27, 2026 03:56
valarLip and others added 4 commits March 31, 2026 02:48
[Kernel][Perf] Make allreduce fusion kernels support arbitrary hidden_dim

Previously the fused allreduce+rmsnorm+quant kernels only supported
N=512/1024/2048/4096 via compile-time template dispatch. This made
models with other hidden_dim (e.g. GLM-5 N=6144, GPT-OSS N=2880)
fall back to the slower non-fused path.

Changes:
- Convert HIDDEN_DIM/BLOCK_SIZE from template parameter to runtime
  parameter in 1stage/2stage/split fusion kernels
- Use __launch_bounds__(1024,1) with runtime thread count
- Fix block_reduce for non-power-of-2 warp counts (round up
  reduce_width for shfl_xor correctness)
- Pad 1stage launch threads to WARP_SIZE multiples with active guard
- Use dynamic shared memory for 2stage kernel
- Generalize step2 dispatch (local_device_load_rmsnorm) to support
  any N where n_packs >= 64, removing n_bytes%1024 alignment requirement
- Replace silent printf errors with throw for unsupported shapes
- Add AITER_AR_1STAGE env override for benchmarking
- Improve test_fused_ar_rms.py: add error column, --test flag,
  multi-shape support, markdown summary table

Now supports any N that satisfies: N % pack_size == 0 and
N / pack_size <= 1024 (i.e. N <= 8192 for bf16).
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
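
A small illustration (not part of the commit) of the shape constraint stated in the commit message above, assuming pack_size is 8 for bf16 (16-byte packs):

```python
def fused_ar_rms_supported(n: int, pack_size: int = 8) -> bool:
    # Constraint quoted above: N % pack_size == 0 and N / pack_size <= 1024,
    # i.e. N <= 8192 for bf16 where pack_size is assumed to be 8.
    return n % pack_size == 0 and n // pack_size <= 1024


# fused_ar_rms_supported(2880) -> True (GPT-OSS), fused_ar_rms_supported(6144) -> True (GLM-5)
```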
samremes pushed a commit that referenced this pull request Mar 31, 2026
* fea(ar): refactor custom allreduce

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* fea: support prefill

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* add latency cmp with rccl

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* fix: remove ck in new kernel

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* fix: ruff check

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* fix: test script format

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* fix: ruff check

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* fix: pa_metadata macro err

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* fea(car): support aiter tensor

Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>

* fix: move pybind aiter tensor to dtypes.py

Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>

* add aiter_tensor_module

* update

* update

* update

* update

* update

* update

* fix: fused_ar_rms gpt n=2880 case

Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>

* [Kernel][Perf] Make allreduce fusion kernels support arbitrary hidden_dim

Previously the fused allreduce+rmsnorm+quant kernels only supported
N=512/1024/2048/4096 via compile-time template dispatch. This made
models with other hidden_dim (e.g. GLM-5 N=6144, GPT-OSS N=2880)
fall back to the slower non-fused path.

Changes:
- Convert HIDDEN_DIM/BLOCK_SIZE from template parameter to runtime
  parameter in 1stage/2stage/split fusion kernels
- Use __launch_bounds__(1024,1) with runtime thread count
- Fix block_reduce for non-power-of-2 warp counts (round up
  reduce_width for shfl_xor correctness)
- Pad 1stage launch threads to WARP_SIZE multiples with active guard
- Use dynamic shared memory for 2stage kernel
- Generalize step2 dispatch (local_device_load_rmsnorm) to support
  any N where n_packs >= 64, removing n_bytes%1024 alignment requirement
- Replace silent printf errors with throw for unsupported shapes
- Add AITER_AR_1STAGE env override for benchmarking
- Improve test_fused_ar_rms.py: add error column, --test flag,
  multi-shape support, markdown summary table

Now supports any N that satisfies: N % pack_size == 0 and
N / pack_size <= 1024 (i.e. N <= 8192 for bf16).

* fix: add param support_prefill in ar

Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>

* fix: test_fused_ar_rms.py

Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>

* fix: test_fused_ar_rms.py

Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>

---------

Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
Co-authored-by: root <root@hjbog-srdc-24.amd.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
Co-authored-by: amd-ruitang3 <rui.tang2@amd.com>
Co-authored-by: amd-ruitang3 <145657428+amd-ruitang3@users.noreply.github.com>
daydayup-lh pushed a commit that referenced this pull request Apr 1, 2026
hubertlu-tw added a commit to hubertlu-tw/sglang that referenced this pull request Apr 3, 2026
…used AR+RMSNorm

- parallel_state.py: Remove hardcoded hidden_dim allowlist {512,1024,2048,4096}
  for 1-stage kernel selection; keep 128KB byte threshold. AITER's C++ dispatch
  already gates which dims are supported (ROCm/aiter#2453).
- benchmark_fused_ar_rms_amd.py: Add hidden_dim=2880 (GPT-OSS) to default
  decode and prefill shapes.
- test_aiter_allreduce_fusion_amd.py: Add multi-hidden-dim correctness test
  covering 2880/4096/5120/6144/7168/8192, and bit-exact residual accuracy
  regression test for ROCm/aiter#2586.
- Add PR documentation with A/B test results (GSM8K +2.3pp, TPOT -3.7%).

Made-with: Cursor
hubertlu-tw added a commit to hubertlu-tw/sglang that referenced this pull request Apr 3, 2026
…used AR+RMSNorm

- parallel_state.py: Remove hardcoded hidden_dim allowlist {512,1024,2048,4096}
  for 1-stage kernel selection; keep 128KB byte threshold. AITER's C++ dispatch
  already gates which dims are supported (ROCm/aiter#2453).
- benchmark_fused_ar_rms_amd.py: Add hidden_dim=2880 (GPT-OSS) to default
  decode and prefill shapes.
- test_aiter_allreduce_fusion_amd.py: Add multi-hidden-dim correctness test
  covering 2880/4096/5120/6144/7168/8192, and bit-exact residual accuracy
  regression test for ROCm/aiter#2586.

Made-with: Cursor
@sunway513
Collaborator

can we get this PR merged in? @TennyWang1223 cc @zufayu

@sunway513
Collaborator

Hi @TennyWang1223, sgl-project/sglang#23580 reports a HIP graph capture invalidation in AiterCustomAllreduce (a helper kernel is launched on an internal stream during capture → hipErrorStreamCaptureInvalidated → SIGABRT in all 8 TP processes at decode replay).

Could you confirm whether this PR's refactor of custom_all_reduce.cu / .cuh already addresses that helper-kernel-on-internal-stream issue, or whether a follow-up commit would be needed before merge?

Tracking issue with full context: #2941 (target v0.1.14). Without this PR (or an equivalent fix), AITER allreduce stays disabled in SGLang production via PR sgl-project/sglang#23581.

Thanks!

@TennyWang1223
Contributor Author

This PR has already been merged into main; due to a GitHub bug it still shows as unmerged here. The AITER code in use when SGLang reported the bug should therefore already include these changes, so merging this PR by itself wouldn't resolve that issue. I'll manually close this PR later. As for sgl-project/sglang#23580, I'll investigate the root cause now.
@sunway513

@sunway513
Collaborator

Hi @TennyWang1223 — small follow-up to confirm intent. The PR shows as closed without merge in both the GitHub UI and the API:

state: CLOSED
mergedAt: null
closedAt: 2026-04-29T04:47:49Z

Branch refactor_ar (head 733e87fb) is currently diverged from main: ahead by 26, behind by 258, and grepping main's commit log for "#2453" or "Refactor allreduce" returns no matches.

Two possibilities:

  1. The squash-merge landed under a different commit message and the GitHub link to Refactor allreduce for supporting prefill case #2453 just got lost — in which case could you point me to the squash commit SHA so I can verify it's reachable from main?
  2. You closed without merge intentionally (e.g. the refactor became unnecessary after the ROCm 7.2.1 runtime fix) — in which case great, we'll close [Track v0.1.14] Fix AiterCustomAllreduce HIP graph capture invalidation (SGLang #23580) #2941 with that conclusion and update the v0.1.13-rc1 release notes accordingly.

Either way is fine, just want to make sure downstream consumers reading the closed PR get the right signal. Thanks!

@TennyWang1223
Contributor Author

The PR was actually merged via squash-merge as commit 8cfe5e281 ("Refactor allreduce for supporting prefill case (#2453)"), authored on 2026-04-01. The closed-without-merge state in the PR card appears to be a GitHub UI artifact.

Verification (terminal output attached):

$ git fetch origin
$ git log origin/main --oneline | grep "#2453"
$ git merge-base --is-ancestor 8cfe5e281 origin/main && echo "on main"

See also: 8cfe5e281 — the main branch tag is shown on the commit page.
