[amd] Add deterministic all-reduce kernel for AMD (ROCm) #15340
HaiShaw merged 11 commits into sgl-project:main from
Conversation
Summary of Changes

Hello @sunxxuns, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses the challenge of non-deterministic inference in distributed environments, particularly on AMD (ROCm) GPUs, by introducing a specialized deterministic all-reduce kernel. The kernel fixes the floating-point accumulation order, yielding reproducible results even with mixed-precision data types. The changes add the HIP kernel, integrate it into the SGLang framework with new Python bindings and environment variables for control, and provide testing and benchmarking tools to demonstrate its effectiveness and performance characteristics. This enhancement is vital for applications requiring strict reproducibility in distributed model inference.
Code Review
This pull request introduces a deterministic all-reduce kernel for AMD GPUs, which is a valuable addition for ensuring deterministic inference. The changes are well-structured, touching the C++ kernel, Python bindings, and server-side logic to enable and use the new feature. The inclusion of new environment variables for controlling this behavior is a good design choice. The addition of benchmark and test files is also great for validation.
My review focuses on improving code clarity and fixing a couple of issues in the new benchmark script. Specifically, I've pointed out an unused parameter and an incorrect test implementation in benchmark_ar.py. I've also suggested refactoring a complex condition in parallel_state.py for better readability and pointed out an unused variable in custom_all_reduce.py.
Overall, this is a solid contribution. Once the suggested changes are addressed, the PR should be in good shape.
```python
def reduce_scatter_then_all_gather(tensor, rank, world_size, custom_ar=None):
```
The custom_ar parameter is not used within the reduce_scatter_then_all_gather function. This can be misleading. Consider removing it to improve code clarity.
```diff
-def reduce_scatter_then_all_gather(tensor, rank, world_size, custom_ar=None):
+def reduce_scatter_then_all_gather(tensor, rank, world_size):
```
```python
# Test custom all-reduce determinism (if available)
results_custom_ar = []
latencies_custom_ar = []
if custom_ar is not None:
    for trial in range(num_trials):
        # Clone the same input for each trial
        inp_custom = base_input.clone()
        inp_flat_custom = inp_custom.view(-1)

        # Measure latency
        torch.cuda.synchronize()
        start = time.perf_counter()
        reduce_scatter_then_all_gather(inp_flat_custom, rank, world_size, custom_ar=custom_ar)
        torch.cuda.synchronize()
        end = time.perf_counter()
        latencies_custom_ar.append(end - start)

        # Store checksum and first values (like test_ar.py)
        checksum = inp_flat_custom.sum().item()
        first_vals = inp_flat_custom[:5].clone()
        results_custom_ar.append((checksum, first_vals))
```
This test block for "custom all-reduce determinism" seems to be incorrectly implemented. It calls reduce_scatter_then_all_gather instead of a method from the custom_ar object. This means it's re-running the "reduce-scatter + all-gather" benchmark, not testing the custom all-reduce implementation.
To correctly test the non-deterministic custom all-reduce, you should call a method like custom_ar.custom_all_reduce(). Since this method is out-of-place, you'll need to adjust the surrounding code to handle the returned tensor.
For example:

```python
# ...
if custom_ar is not None:
    for trial in range(num_trials):
        # Clone the same input for each trial
        inp_custom = base_input.clone()

        # Measure latency
        torch.cuda.synchronize()
        start = time.perf_counter()
        # custom_all_reduce is out-of-place
        result_custom = custom_ar.custom_all_reduce(inp_custom)
        torch.cuda.synchronize()
        end = time.perf_counter()
        latencies_custom_ar.append(end - start)

        # Store checksum and first values
        if result_custom is not None:
            inp_flat_custom = result_custom.view(-1)
            checksum = inp_flat_custom.sum().item()
            first_vals = inp_flat_custom[:5].clone()
            results_custom_ar.append((checksum, first_vals))
# ...
```

Note that `custom_all_reduce` might return None if it decides not to run, so you should handle that case. The current benchmark code seems to be based on a copy-paste from the reduce-scatter + all-gather test.
```python
# When viewed as inp.dtype, it has max_size / inp.element_size() elements
buffer_size_bytes = self.buffer.numel()  # This equals max_size (bytes)
input_size_bytes = inp.numel() * inp.element_size()
buffer_size_elements = buffer_size_bytes // inp.element_size()
```
```python
if not input_.is_contiguous() and not (
    input_.storage().nbytes() - input_.storage_offset() * input_.element_size()
    == input_.numel() * input_.element_size()
):
    basic_checks_passed = False
    reasons.append("not weakly contiguous")
```
The logic to check for weak contiguity is complex and hard to read inline. To improve readability and maintainability, consider extracting this check into a local helper function within the GroupCoordinator class or the module.
For example:

```python
def _is_weakly_contiguous(tensor: torch.Tensor) -> bool:
    """Check if a tensor is weakly contiguous."""
    return tensor.is_contiguous() or (
        tensor.storage().nbytes() - tensor.storage_offset() * tensor.element_size()
        == tensor.numel() * tensor.element_size()
    )

# ... inside all_reduce method ...
if not _is_weakly_contiguous(input_):
    basic_checks_passed = False
    reasons.append("not weakly contiguous")
```

This would make the condition much clearer and the code easier to maintain. Since parallel_state.py is a core module that is imported by custom_all_reduce.py, you can't import is_weak_contiguous from custom_all_reduce_utils without creating a circular dependency, so a local helper is appropriate here.
/tag-and-rerun-ci
Hi @sunxxuns, I am Thomas.
Thanks, just fixed the flag. Faster is expected: we are just using custom all-reduce with a fixed order here, so it's faster than dist.all_reduce at small message sizes, but actually slower than the default non-deterministic custom all-reduce, which will show up in the e2e comparison.
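As an aside, here is a minimal sketch of how such per-message-size latencies can be measured with torch.distributed (it times dist.all_reduce only, not the custom kernels, and is not the PR's actual benchmark script):

```python
import os
import time

import torch
import torch.distributed as dist


def time_all_reduce(num_elements: int, iters: int = 50) -> float:
    """Average all-reduce latency (seconds) for a bf16 tensor of num_elements."""
    x = torch.randn(num_elements, device="cuda").bfloat16()
    for _ in range(5):  # warmup so init/caching doesn't skew timing
        dist.all_reduce(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        dist.all_reduce(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    # Launch with: torchrun --nproc-per-node=8 this_script.py
    dist.init_process_group("nccl")  # RCCL backend on ROCm
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    for n in (1 << 10, 1 << 16, 1 << 22):  # small to large messages
        latency = time_all_reduce(n)
        if dist.get_rank() == 0:
            print(f"{n:>8} elems: {latency * 1e6:.1f} us")
    dist.destroy_process_group()
```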
Add a deterministic 1-stage all-reduce kernel for AMD GPUs that ensures consistent results across different batch sizes when using tensor parallelism.

Key changes:
- sgl-kernel: Add deterministic_all_reduce.hip with 1-stage kernel
- parallel_state.py: Use deterministic kernel on AMD when --enable-deterministic-inference
- server_args.py: Keep custom all-reduce enabled on AMD for deterministic inference
- custom_all_reduce.py: Add deterministic_all_reduce method and dispatch logic

The kernel uses fixed accumulation ordering (no atomics) to guarantee deterministic results. Performance is ~62% faster than reduce-scatter + all-gather. AMD only; the CUDA path is unchanged (still uses the NCCL tree algorithm).
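To illustrate the property this commit relies on (floating-point addition is not associative, so a fixed accumulation order is what makes results reproducible), here is a small Python sketch; it is not the HIP kernel itself:

```python
import torch

torch.manual_seed(0)
# Eight "per-rank" shards, as if gathered from an 8-GPU tensor-parallel group.
shards = [torch.randn(1024).bfloat16() for _ in range(8)]

# Fixed order, as the deterministic kernel uses: rank 0, 1, 2, ...
acc_fixed = torch.zeros(1024)
for s in shards:
    acc_fixed += s.float()

# Any other order (e.g. whatever atomics/scheduling would produce) can
# round differently and change low-order bits.
acc_other = torch.zeros(1024)
for s in reversed(shards):
    acc_other += s.float()

# Often False: same inputs, different summation order, different bits.
print(torch.equal(acc_fixed, acc_other))
```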
- Add MI350 (gfx950) installation instructions noting the pre-built package must be uninstalled before a source build
- Add a comprehensive ROCm/AMD Deterministic Inference section with:
  - Setup steps including aiter pre-compilation to avoid deadlock
  - Server launch command with SGLANG_PREFER_CUSTOM_ALLREDUCE_FOR_DETERMINISM
  - Test command for deterministic inference verification
- Update test and benchmark docstrings with setup instructions
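For illustration only, the MI350 flow described above might look roughly like the following; the package name is real, but the exact build invocation is an assumption, so follow the updated installation instructions rather than this sketch:

```bash
# Remove the pre-built kernel package before building from source (required
# for gfx950 per the docs commit above).
pip uninstall -y sgl-kernel

# Build the ROCm kernels from source; the actual command may differ from
# this assumed setuptools-style invocation.
cd sgl-kernel
python setup_rocm.py install
```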
These .hip files are auto-generated by hipify from their .cu counterparts and should not be committed to the repository. They are generated at build time when building for ROCm/AMD GPUs.
Update comments and documentation to clarify that the deterministic all-reduce uses the existing 1-stage kernel (cross_device_reduce_1stage) which is inherently deterministic due to fixed accumulation ordering. This is NOT a reduce-scatter + all-gather approach. Each GPU reads all data from all GPUs and reduces locally in a fixed order.
- Add SGLANG_USE_DETERMINISTIC_ALLREDUCE: disable deterministic AR while keeping other deterministic settings (default: true)
- Add SGLANG_FORCE_1STAGE_ALLREDUCE: force the 1-stage kernel without enabling other deterministic settings (for testing)
- Add [AR]-prefixed logging to show which all-reduce implementation and call path is being used (Aiter vs sglang, deterministic vs default)
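The flag semantics described above could be expressed roughly as follows; this is an illustrative sketch, not the actual SGLang dispatch code:

```python
import os


def _env_true(name: str, default: bool) -> bool:
    """Read a boolean environment flag, falling back to a default."""
    val = os.environ.get(name)
    return default if val is None else val not in ("0", "false", "False")


def use_deterministic_allreduce(deterministic_inference: bool) -> bool:
    # SGLANG_FORCE_1STAGE_ALLREDUCE forces the 1-stage kernel on its own,
    # even without --enable-deterministic-inference (for testing).
    if _env_true("SGLANG_FORCE_1STAGE_ALLREDUCE", default=False):
        return True
    # SGLANG_USE_DETERMINISTIC_ALLREDUCE (default: true) can disable the
    # deterministic AR while keeping the other deterministic settings.
    return deterministic_inference and _env_true(
        "SGLANG_USE_DETERMINISTIC_ALLREDUCE", default=True
    )
```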
…#15340) Co-authored-by: Thomas Wang <1am9trash@gmail.com>

This patch aligns the wheel build helper to setup_rocm.py according to the two recent changes: (1) deterministic allreduce from sgl-project#15340 and (2) fast topk from sgl-project#15172.



Summary
This PR enables deterministic inference on AMD GPUs by using the 1-stage all-reduce kernel, which is inherently deterministic (fixed accumulation order, no atomics).
Note: This is NOT a reduce-scatter + all-gather approach. The 1-stage kernel has each GPU read all data from all GPUs and reduce locally in a fixed order.
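For intuition, the 1-stage semantics can be sketched with ordinary torch.distributed collectives (the real implementation is a HIP kernel that reads peer GPU buffers directly; this is only an illustration):

```python
import torch
import torch.distributed as dist


def one_stage_all_reduce(x: torch.Tensor) -> torch.Tensor:
    """Every rank gathers all shards, then reduces locally in fixed rank order."""
    world_size = dist.get_world_size()
    shards = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(shards, x)  # each GPU sees every GPU's data
    out = torch.zeros_like(x, dtype=torch.float32)
    for shard in shards:  # fixed order: rank 0, 1, ..., N-1 on every rank
        out += shard.float()  # identical order => bitwise-identical results
    return out.to(x.dtype)
```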
Key Changes

Kernel Implementation:
- sgl-kernel/csrc/allreduce/deterministic_all_reduce.hip: Wrapper that forces the 1-stage kernel for determinism
- sgl-kernel/csrc/common_extension_rocm.cc: Register deterministic ops
- sgl-kernel/setup_rocm.py: Add kernel to ROCm build
- sgl-kernel/python/sgl_kernel/allreduce.py: Add Python bindings

SGLang Integration:
- python/sglang/srt/distributed/device_communicators/custom_all_reduce.py: Add `deterministic_all_reduce` method and dispatch logic
- python/sglang/srt/distributed/device_communicators/custom_all_reduce_ops.py: Add deterministic ops for HIP
- python/sglang/srt/distributed/parallel_state.py: Use deterministic kernel based on env flag
- python/sglang/srt/server_args.py: Keep custom AR enabled for AMD deterministic mode
- python/sglang/srt/environ.py: Add `SGLANG_USE_1STAGE_ALLREDUCE` env variable

Tests:
- sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py: Tests deterministic kernel consistency
- sgl-kernel/tests/test_amd_nccl_allreduce_determinism.py: Tests NCCL behavior (shows non-determinism)
- sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py: Benchmarks all methods

Environment Variable
- `SGLANG_USE_1STAGE_ALLREDUCE`: defaults to enabled when `--enable-deterministic-inference` is on. Set to `1` to force enable, `0` to force disable.

Usage
Basic deterministic inference
```bash
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --tp 8 \
  --attention-backend triton \
  --enable-deterministic-inference \
  --host 127.0.0.1 \
  --port 30000
```

Force 1-stage AR (for benchmarking, without other deterministic settings)

```bash
SGLANG_USE_1STAGE_ALLREDUCE=1 \
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --tp 8 \
  --attention-backend triton \
  --host 127.0.0.1 \
  --port 30000
```

Use default Aiter AR even with deterministic inference

```bash
SGLANG_USE_1STAGE_ALLREDUCE=0 \
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --tp 8 \
  --attention-backend triton \
  --enable-deterministic-inference \
  --host 127.0.0.1 \
  --port 30000
```

Test determinism
Log Messages
Look for `[AR]`-prefixed logs to identify which all-reduce is being used:
- `[AR] Using AiterCustomAllreduce (AMD default)` - Aiter's implementation
- `[AR] Using sglang CustomAllreduce (1-stage kernel)` - sglang's 1-stage implementation
- `[AR] All-reduce: 1-stage kernel (...)` - Using the 1-stage path
- `[AR] All-reduce: default` - Using the default Aiter path
- `python sgl-kernel/tests/test_amd_deterministic_custom_allreduce.py`
- `python3 -m sglang.test.test_deterministic --n-trials 50`