[piecewise] Support custom all-reduce #14193

Closed
ByronHsu wants to merge 7 commits into main from byron/pcg-ca

Conversation


@ByronHsu ByronHsu commented Dec 1, 2025

Summary

Fix custom all-reduce support in piecewise CUDA graph by properly registering IPC buffers during graph capture.

Motivation

Custom all-reduce fails with the piecewise CUDA graph:

$ python -m sglang.launch_server --model-path Qwen/Qwen3-8B --enable-piecewise-cuda-graph --tp 4

RuntimeError: an illegal memory access was encountered

Root Cause

1. Missing IPC buffer registration

In the regular CudaGraphRunner, the graph_capture() context internally calls ca_comm.capture(), which:

  • Sets _IS_CAPTURING = True during capture
  • Calls register_graph_buffers() after capture to register IPC addresses
    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()

During capture, when _IS_CAPTURING = True, buffer addresses are collected:

    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);

After capture, register_graph_buffers() exchanges IPC handles between ranks and opens peer pointers. Without this, the all-reduce kernel dereferences invalid pointers → illegal memory access. (See the details of how IPC handles work in custom all-reduce + CUDA graph.)
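For intuition, here is a minimal sketch of that exchange, assuming a torch.distributed process group is already initialized. The function name and arguments below are illustrative, not sglang's actual API; in sglang the heavy lifting happens inside the C++ custom all-reduce extension.

import torch.distributed as dist

def register_graph_buffers_sketch(local_handles, local_offsets, world_size):
    # Each rank serializes the IPC handles/offsets of every buffer its
    # captured graph touched (collected in graph_unreg_buffers_ above) ...
    all_handles = [None] * world_size
    all_offsets = [None] * world_size
    dist.all_gather_object(all_handles, local_handles)
    dist.all_gather_object(all_offsets, local_offsets)
    # ... then the extension opens each peer handle (cudaIpcOpenMemHandle)
    # and stores the resulting device pointers in RankData, so the replayed
    # graph dereferences valid addresses on every rank.
    return all_handles, all_offsets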

2. Incorrect warmup iterations (3 → 2)

Since warmup_torch_compile() already runs the model once, three additional iterations mean:

  • Run 1: Warmup → Run 2: Capture → Run 3: Replay (during capture!)

Replay during capture tries to use IPC addresses before register_graph_buffers() is called (it's called after the context exits).
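To make the ordering concrete, here is a minimal sketch using plain PyTorch CUDA graph APIs; model, static_input, and ca_comm stand in for the objects described above.

import torch

g = torch.cuda.CUDAGraph()
with ca_comm.capture():                # sets _IS_CAPTURING = True
    with torch.cuda.graph(g):
        out = model(static_input)      # all-reduce records unregistered buffers
    # Replaying g HERE would dereference IPC addresses that peer ranks have
    # not opened yet -> illegal memory access.
# ca_comm.capture() has now called register_graph_buffers(); replay is safe.
g.replay()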

Modifications

  1. Add a ca_comm.capture() context for proper IPC buffer registration:
        ca_comm = self.model_runner.tp_group.ca_comm
        maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()

        with enable_piecewise_cuda_graph(), maybe_ca_context:
  2. Reduce the iteration count from 3 to 2 to prevent replay during capture (a consolidated sketch of the resulting flow follows this list):
        # Run 2 times (not 3) since warmup_torch_compile() already triggered first_run.
        for _ in range(2):
  3. Remove the disable_ca_comm and use_original_ca_comm workaround that previously disabled custom all-reduce entirely during piecewise operations. ([piecewise] Refactor VLM to support input embed buffer and remove external embedder hack #14155)
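Putting items 1 and 2 together, the corrected capture flow looks roughly like this; enable_piecewise_cuda_graph() and the attributes mirror the snippets above, while _capture_one_iteration() is a hypothetical stand-in for the real loop body.

from contextlib import nullcontext

def capture(self):
    ca_comm = self.model_runner.tp_group.ca_comm
    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()

    with enable_piecewise_cuda_graph(), maybe_ca_context:
        # warmup_torch_compile() already ran the model once, so two passes
        # suffice: pass 1 warms up, pass 2 captures. A third pass would
        # replay inside maybe_ca_context, before register_graph_buffers().
        for _ in range(2):
            self._capture_one_iteration()
    # maybe_ca_context has exited: IPC buffers are registered, replay is safe.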

Debugging Notes

# This client script sends two requests of length 1025. It can crash Qwen VL.

import requests

url = "http://127.0.0.1:30000/generate"

seq_lens = [1025] * 2

for seq_len in seq_lens:
    data = {
        "input_ids": [0] * seq_len,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 32,
        },
    }

    response = requests.post(url, json=data)
    print(response.json())
  1. The Qwen text model works without crashing:
# Server
$ python -m sglang.launch_server --model Qwen/Qwen2.5-7B-Instruct --tp 4 --enable-piecewise-cuda-graph --disable-radix-cache
 
# Client 1 (gsm8k) 
 $ python few_shot_gsm8k.py  
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
/tmp/test.jsonl: 732kB [00:00, 17.5MB/s]                                                                           
100%|████████████████████████████████████████████████████████████████████████████| 200/200 [00:26<00:00,  7.53it/s]
Accuracy: 0.870
Invalid: 0.000
Latency: 26.877 s
Output throughput: 1226.728 token/s

# Client 2
$ python client.py 
...No crash
  2. The Qwen VL model crashes when we run the client above, but does not crash on gsm8k (just luck):
# Server
$ python -m sglang.launch_server --model Qwen/Qwen2.5-VL-7B-Instruct --tp 4 --enable-piecewise-cuda-graph --disable-radix-cache

# Client 1 (gsm8k)

$ python few_shot_gsm8k.py 
100%|████████████████████████████████████████████████████████████████████████████| 200/200 [02:26<00:00,  1.37it/s]
Accuracy: 0.685
Invalid: 0.005
Latency: 146.104 s
Output throughput: 266.640 token/s

$ python client.py
# ... Crash

The error message is:

[2025-12-09 02:28:12 TP0] Prefill batch, #new-seq: 1, #new-token: 1025, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-12-09 02:28:15] INFO:     127.0.0.1:35976 - "POST /generate HTTP/1.1" 200 OK
[2025-12-09 02:28:15 TP0] Prefill batch, #new-seq: 1, #new-token: 1025, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[rank1]:[E1209 02:28:15.281862353 ProcessGroupNCCL.cpp:2057] [PG ID 2 PG GUID 3 Rank 1] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /pytorch/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x80 (0x7fd07bf7cb80 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x11fb7 (0x7fd07c308fb7 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_cuda.so)
frame #2: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x50 (0x7fd01de3cbc0 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x68 (0x7fd01de4c298 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::Watchdog::runLoop() + 0x969 (0x7fd01de50499 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::Watchdog::run() + 0xdf (0x7fd01de5240f in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0xdc253 (0x7fd1fb968253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #7: <unknown function> + 0x94ac3 (0x7fd1fe4e5ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #8: <unknown function> + 0x1268c0 (0x7fd1fe5778c0 in /usr/lib/x86_64-linux-gnu/libc.so.6)

I also used a CUDA core dump to pinpoint the error to custom all-reduce.

  3. Qwen Coder crashed at init time:
$  python3 -m sglang.launch_server --model-path Qwen/Qwen3-Coder-30B-A3B-Instruct --tp 4 --host 0.0.0.0 --enable-piecewise-cuda-graph
  
  File "/root/sglang/python/sglang/srt/models/qwen2_moe.py", line 596, in forward
    def forward(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 414, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.314", line 828, in forward
    submod_96 = self.submod_96(getitem_192, l_self_modules_layers_modules_47_modules_self_attn_modules_o_proj_parameters_weight_, l_self_modules_layers_modules_47_layer_communicator_post_attention_layernorm_parameters_weight_, getitem_4, l_self_modules_layers_modules_47_modules_mlp_modules_gate_parameters_weight_, l_self_modules_layers_modules_47_modules_mlp_modules_experts_parameters_w13_weight_, l_self_modules_layers_modules_47_modules_mlp_modules_experts_parameters_w2_weight_, l_self_modules_norm_parameters_weight_);  getitem_192 = l_self_modules_layers_modules_47_modules_self_attn_modules_o_proj_parameters_weight_ = l_self_modules_layers_modules_47_layer_communicator_post_attention_layernorm_parameters_weight_ = getitem_4 = l_self_modules_layers_modules_47_modules_mlp_modules_gate_parameters_weight_ = l_self_modules_layers_modules_47_modules_mlp_modules_experts_parameters_w13_weight_ = l_self_modules_layers_modules_47_modules_mlp_modules_experts_parameters_w2_weight_ = l_self_modules_norm_parameters_weight_ = None
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py", line 124, in __call__
    runtime_shape = args[self.sym_shape_indices[0]]
                         ~~~~~~~~~~~~~~~~~~~~~~^^^
IndexError: list index out of range

[2025-12-09 04:23:30] Received sigquit from a child process. It usually means the child failed.
[1]    2112048 killed     python3 -m sglang.launch_server --model-path Qwen/Qwen3-Coder-30B-A3B-Instruc

@gemini-code-assist

Summary of Changes

Hello @ByronHsu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the SGLang framework by enabling the use of custom allreduce operations within piecewise CUDA graphs. This change addresses previous compatibility issues that required disabling custom allreduce during graph capture, leading to a more unified and potentially more efficient execution flow for distributed models. The modifications simplify the codebase by removing explicit context managers for allreduce communication and refine the CUDA graph capture process to properly account for these operations.

Highlights

  • Custom Allreduce Integration: Custom allreduce operations are now fully supported and integrated within piecewise CUDA graphs, removing previous limitations that required explicit disabling during graph capture.
  • Simplified Communication Management: The explicit disable_ca_comm and use_original_ca_comm context managers have been removed, streamlining the handling of distributed communication during CUDA graph capture and execution.
  • Refined CUDA Graph Capture: The CUDA graph capture process has been updated to properly register IPC buffers for custom allreduce operations and the capture loop count has been adjusted for efficiency.
  • Multimodal Embedding Compatibility: The multimodal embedding routine no longer requires special handling to manage custom allreduce, indicating its seamless integration across the framework.

@ByronHsu changed the title from "[piecewise] Support custom allreduce" to "[piecewise] Support custom all-reduce" on Dec 1, 2025

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the handling of custom allreduce communication within the piecewise CUDA graph capture mechanism. The previous temporary context managers (disable_ca_comm, use_original_ca_comm) have been removed and replaced with a more integrated approach using nullcontext and ca_comm.capture(). This simplifies the code and improves maintainability. Additionally, an unnecessary warmup run during CUDA graph capture has been removed, enhancing efficiency. The changes also include removing unused imports and adjusting the mm_utils.py file to reflect the removal of get_tp_group and use_original_ca_comm.

@ByronHsu ByronHsu marked this pull request as draft December 1, 2025 06:50

ByronHsu commented Dec 9, 2025

CUDA CoreDump Repro

  1. Start the server:
CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \
CUDA_COREDUMP_SHOW_PROGRESS=1 \
CUDA_COREDUMP_GENERATION_FLAGS='skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory,skip_constbank_memory' \
CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" CUDA_LAUNCH_BLOCKING=1 python -m sglang.launch_server --model Qwen/Qwen2.5-VL-7B-Instruct --tp 4 --enable-piecewise-cuda-graph --disable-radix-cache
  2. Send client requests:
import requests

url = "http://127.0.0.1:30000/generate"

seq_lens = [1025] * 2

for seq_len in seq_lens:
    data = {
        "input_ids": [0] * seq_len,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 32,
        },
    }

    response = requests.post(url, json=data)
    print(response.json())
$ python client.py
  3. The server starts to dump:
[2025-12-09 18:59:27 TP0] Prefill batch, #new-seq: 1, #new-token: 1025, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[18:59:28.674952] coredump: Starting GPU coredump generation
[18:59:28.706925] coredump: Starting GPU coredump generation
[18:59:28.707040] coredump: SM 1/132 is not used by any context
[18:59:28.707046] coredump: SM 2/132 is not used by any context
[18:59:28.707049] coredump: SM 3/132 is not used by any context
...
  4. Use cuda-gdb to find the CUDA exception:
*[byron/pcg-ca][~/sglang]$ cuda-gdb   
NVIDIA (R) cuda-gdb 12.9
Portions Copyright (C) 2007-2025 NVIDIA Corporation
Based on GNU gdb 14.2
Copyright (C) 2023 Free Software Foundation, Inc.
License GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>
This is free software: you are free to change and redistribute it.
There is NO WARRANTY, to the extent permitted by law.
Type "show copying" and "show warranty" for details.
This CUDA-GDB was configured as "x86_64-pc-linux-gnu".
Type "show configuration" for configuration details.
For bug reporting instructions, please see:
<https://forums.developer.nvidia.com/c/developer-tools/cuda-developer-tools/cuda-gdb>.
Find the CUDA-GDB manual and other documentation resources online at:
    <https://docs.nvidia.com/cuda/cuda-gdb/index.html>.

For help, type "help".
Type "apropos word" to search for commands related to "word".
(cuda-gdb) target cudacore /tmp/cuda_coredump_memx-cla-24-sr1.xpop.twttr.net.2161772.1765306609
Opening GPU coredump: /tmp/cuda_coredump_memx-cla-24-sr1.xpop.twttr.net.2161772.1765306609
[Current focus set to CUDA kernel 0, grid 155886, block (20,0,0), thread (288,0,0), device 3, sm 0, warp 9, lane 0]

CUDA Exception: Warp Illegal Address
The exception was triggered at PC 0x7f0cf1bcec70  void sglang::cross_device_reduce_2stage<__nv_bfloat16, 4>(sglang::RankData*, sglang::RankSignals, sglang::Signal*, __nv_bfloat16*, int, int)
#0  0x00007f0cf1bcecd0 in void sglang::cross_device_reduce_2stage<__nv_bfloat16, 4>(sglang::RankData*, sglang::RankSignals, sglang::Signal*, __nv_bfloat16*, int, int)
   <<<(36,1,1),(512,1,1)>>> ()
(cuda-gdb) 
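Once the coredump is loaded, two standard cuda-gdb commands (documented in the cuda-gdb manual, nothing sglang-specific) help inspect the fault:

(cuda-gdb) info cuda kernels   # list the kernels present in the coredump
(cuda-gdb) bt                  # backtrace of the focused (faulting) thread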

@ByronHsu ByronHsu closed this Jan 7, 2026
@zhyncs zhyncs deleted the byron/pcg-ca branch January 8, 2026 00:31