Fix CUDA 13 cudaMemcpyBatchAsync segfault and restore hicache CI #23172

Closed

Kangyan-Zhou wants to merge 4 commits into sgl-project:main from Kangyan-Zhou:cuda13_memcpy_hicache_restore

Conversation

@Kangyan-Zhou (Collaborator)

Motivation

Ports @yhyang201's fix from #23136 and follows up by re-enabling the hicache tests that were parked under test/manual/ while the CUDA 13 segfault was unresolved.

CUDA 13.0 removed the failIdx parameter from cudaMemcpyBatchAsync (9 params → 8). The dlsym path in sgl-kernel/csrc/kvcacheio/transfer.cu was hard-coded to the CUDA 12.8 signature, so the CUDA stream argument landed in the wrong slot and the runtime segfaulted inside cuMemcpyBatchAsync_v2. The fix uses driver_version at runtime to dispatch to either the CUDA 12 or CUDA 13 signature, preserving binary portability.
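
For reference, a sketch of the two signatures in play, written as function-pointer typedefs. The parameter-name comments follow the CUDA documentation; the v13 shape matches the typedef quoted in the review thread below, and v12 adds the failIdx slot back:

    // CUDA 12.8: 9 parameters, including the failIdx out-param.
    using FnV12 = cudaError_t (*)(
        void* const*,           // dsts
        const void* const*,     // srcs
        const size_t*,          // sizes
        size_t,                 // count
        cudaMemcpyAttributes*,  // attrs
        size_t*,                // attrsIdxs
        size_t,                 // numAttrs
        size_t*,                // failIdx -- removed in CUDA 13.0
        cudaStream_t);          // stream
    // CUDA 13.0: 8 parameters, failIdx gone. Calling a v13 runtime
    // through the v12 typedef passes &failIdx where the runtime
    // expects the stream, hence the segfault.
    using FnV13 = cudaError_t (*)(
        void* const*, const void* const*, const size_t*, size_t,
        cudaMemcpyAttributes*, size_t*, size_t, cudaStream_t);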

With the segfault fixed, the hicache tests that were temporarily moved to test/manual/ in #23119 and the follow-on cu13 flake sweeps can run in CI again.

Modifications

  • sgl-kernel/csrc/kvcacheio/transfer.cu: runtime-version dispatch between the v12 and v13 cudaMemcpyBatchAsync signatures (ported from Fix segfault in cudaMemcpyBatchAsync on CUDA 13.0 #23136).
  • Move 7 hicache tests from test/manual/ back to test/registered/:
    • hicache/test_hicache_storage.py
    • hicache/test_hicache_storage_3fs_backend.py
    • hicache/test_hicache_storage_file_backend.py
    • hicache/test_hicache_storage_mooncake_backend.py (also restores the register_cuda_ci(est_time=236, suite="stage-b-test-2-gpu-large") call that was dropped on the way to manual)
    • hicache/test_hicache_storage_runtime_attach_detach.py
    • hicache/test_hicache_variants.py
    • 4-gpu-models/test_qwen35_hicache.py
  • Strip the "TODO: move back after fixed" docstrings added when the files were parked.

Accuracy Tests

Credit to @yhyang201 for the original H200 CU13 run (from #23136): all 192 tests in sgl-kernel/tests/test_kvcacheio.py pass; previously the suite segfaulted at test_transfer_kv_pf_direct (~37% of the way through the run).

Checklist

cc @yhyang201 @Fridge003 @alisonshao

Port PR sgl-project#23136 (Yuhao Yang): cudaMemcpyBatchAsync lost its failIdx
parameter in CUDA 13, so the dlsym-based call was passing the stream
handle at the wrong slot and segfaulting inside cuMemcpyBatchAsync_v2.
Use driver_version at runtime to dispatch to either the CUDA 12 or
CUDA 13 signature.

With the segfault fixed, move the 7 hicache tests that were parked
under test/manual in PR sgl-project#23119 and subsequent cu13 flake sweeps back
into test/registered so they run in CI again:

- hicache/test_hicache_storage.py
- hicache/test_hicache_storage_3fs_backend.py
- hicache/test_hicache_storage_file_backend.py
- hicache/test_hicache_storage_mooncake_backend.py
- hicache/test_hicache_storage_runtime_attach_detach.py
- hicache/test_hicache_variants.py
- 4-gpu-models/test_qwen35_hicache.py

TODO "move back after fixed" docstrings are stripped and the
register_cuda_ci call that was dropped from the mooncake backend test
on its way to manual is restored.

Co-Authored-By: Yuhao Yang <yhyang201@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gemini-code-assist Bot (Contributor) left a comment


Code Review

This pull request addresses compatibility issues with CUDA 13.0 by dynamically selecting the correct signature for cudaMemcpyBatchAsync using dlsym. It also re-enables several HiCache tests that were previously disabled due to segmentation faults in CUDA 13 environments. Feedback suggests using the CUDA runtime version instead of the driver version to determine the function signature to avoid potential mismatches in containerized environments. Additionally, it is recommended to update the CUDA 13 function pointer signature to include const qualifiers for better type safety.

Comment thread on sgl-kernel/csrc/kvcacheio/transfer.cu (Outdated)

// CUDA 13.0 removed the failIdx parameter from cudaMemcpyBatchAsync.
// Use runtime version to select the correct signature for binary portability.
const bool use_v13_signature = driver_version >= 13000;

Severity: medium

The signature of cudaMemcpyBatchAsync is determined by the version of the CUDA runtime library (libcudart), not the driver version. In environments where the driver is upgraded but the runtime remains on an older version (e.g., Docker containers with a fixed runtime but host-provided driver), using driver_version here may lead to a signature mismatch and potential segfault if the runtime is actually < 13.0. Consider using cudaRuntimeGetVersion to determine the correct signature for the runtime function obtained via dlsym.
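
A hypothetical two-line probe makes the reviewer's point concrete (the example values mirror the container repro posted later in this thread):

    int drv = 0, rt = 0;
    cudaDriverGetVersion(&drv);   // host driver, e.g. 13000
    cudaRuntimeGetVersion(&rt);   // loaded libcudart, e.g. 12090
    // Dispatching on drv would pick the 8-param v13 call even though
    // the dlsym'd symbol still has the 9-param v12 ABI.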

Comment on lines +921 to +929
using FnV13 = cudaError_t (*)(
void* const*,
const void* const*,
const size_t*,
size_t,
cudaMemcpyAttributes*,
size_t*,
size_t,
cudaStream_t);

Severity: medium

The FnV13 signature for cudaMemcpyBatchAsync in CUDA 13.0+ uses const for the attrs and attrIdxs parameters. While passing non-const pointers to const parameters is valid at the call site, updating the function pointer definition improves type safety and better reflects the actual API signature.

      using FnV13 = cudaError_t (*)(
          void* const*,
          const void* const*,
          const size_t*,
          size_t,
          const cudaMemcpyAttributes*,
          const size_t*,
          size_t,
          cudaStream_t);

@Kangyan-Zhou (Collaborator, Author)

/tag-and-rerun-ci

When a USE_VENV=false runner had flashinfer-cubin installed
("already installed, keeping it"), `uv pip uninstall flashinfer-python`
left the flashinfer/data/ subdirectory behind (flashinfer-cubin still
owned entries below it). The next `uv pip install -e python[dev,runai,tracing]`
then failed with:

  error: Failed to install: flashinfer_python-0.6.7.post3-py3-none-any.whl
    Caused by: failed to create directory
    `/usr/local/lib/python3.10/dist-packages/flashinfer/data`: File exists

Seen on stage-a-test-1-gpu-small in
https://github.com/sgl-project/sglang/actions/runs/24634237642/job/72027123887

Two-layer fix:

1. ci_install_dependency.sh (in-flight safeguard): right after the
   flashinfer uninstall step, if <site-packages>/flashinfer/ still
   exists, rm -rf it and force flashinfer-cubin to reinstall.
   `uv pip install -e python[...]` then resolves both flashinfer-python
   and flashinfer-cubin (both declared in pyproject.toml) and repopulates
   flashinfer/data/ cleanly. This makes the PR self-healing on its
   first run without depending on a prior job's post-cleanup.

2. ci_cleanup_venv.sh (post-job hygiene): the USE_VENV=false arm used
   to `exit 0` immediately. It now uninstalls the flashinfer trio and
   purges residual flashinfer/, flashinfer_cubin/, flashinfer_jit_cache/
   trees from system site-packages so the next job's runner starts
   clean even if the in-flight safeguard ever regresses. Cached wheels
   under ~/.cache/flashinfer-wheels/ keep the reinstall fast.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Addresses the review on sgl-project#23172:
sgl-project#23172 (comment)

cudaMemcpyBatchAsync is a libcudart (runtime) symbol; the ABI of the
function dlsym'd into this process is owned by the libcudart that's
actually loaded, not by the host's kernel driver. Dispatching on
cudaDriverGetVersion() breaks in the common container case where a
cu12 runtime is paired with a cu13-capable host driver: driver=13000
steers us to the 8-param v13 call, but the symbol resolves to v12
(9 params with failIdx), so the stream argument lands in the wrong slot
and we segfault — the exact crash this fix was supposed to prevent.

Reproduced on ion-user-9 with lmsysorg/sglang:dev (cu12.9 runtime):

    cudaDriverGetVersion()  = 13000
    cudaRuntimeGetVersion() = 12090
    v12 dispatch of dlsym'd symbol: cudaSuccess, exit 0
    v13 dispatch of dlsym'd symbol: Segmentation fault (core dumped)

Switching the signature-selection to cudaRuntimeGetVersion makes the
choice follow the loaded libcudart, which is what actually determines
the ABI. The existing cudaDriverGetVersion guard above is kept — it
remains the right knob for the capability check since cudaMemcpyBatch
requires a 12.8+ driver regardless of the runtime version.
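
For illustration, a minimal sketch of the corrected dispatch under these assumptions. The function name memcpy_batch_dispatch and the surrounding plumbing are hypothetical, not the literal transfer.cu code; the typedefs mirror the ones discussed in the review thread, with the const qualifiers from the review suggestion applied to the v13 shape:

    #include <cuda_runtime.h>
    #include <dlfcn.h>

    using FnV12 = cudaError_t (*)(
        void* const*, const void* const*, const size_t*, size_t,
        cudaMemcpyAttributes*, size_t*, size_t,
        size_t*,  // failIdx, v12 only
        cudaStream_t);
    using FnV13 = cudaError_t (*)(
        void* const*, const void* const*, const size_t*, size_t,
        const cudaMemcpyAttributes*, const size_t*, size_t, cudaStream_t);

    cudaError_t memcpy_batch_dispatch(
        void* const* dsts, const void* const* srcs, const size_t* sizes,
        size_t count, cudaMemcpyAttributes* attrs, size_t* attr_idxs,
        size_t num_attrs, cudaStream_t stream) {
      static void* sym = dlsym(RTLD_DEFAULT, "cudaMemcpyBatchAsync");
      if (sym == nullptr) return cudaErrorSymbolNotFound;
      // The ABI is owned by the loaded libcudart, so select on the
      // runtime version, not cudaDriverGetVersion().
      int rt = 0;
      cudaRuntimeGetVersion(&rt);
      if (rt >= 13000) {
        return reinterpret_cast<FnV13>(sym)(
            dsts, srcs, sizes, count, attrs, attr_idxs, num_attrs, stream);
      }
      size_t fail_idx = 0;
      return reinterpret_cast<FnV12>(sym)(
          dsts, srcs, sizes, count, attrs, attr_idxs, num_attrs,
          &fail_idx, stream);
    }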

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Kangyan-Zhou added a commit to yhyang201/sglang that referenced this pull request Apr 19, 2026
After the main `uv pip install -e python[...]` step, runners that carried
state from the pre-sgl-project#23119 (cu129) era keep `nvidia-cuda-runtime-cu12`
installed as an orphan (Required-by: empty) alongside the cu13 runtime.
Its libcudart.so.12 sits under `nvidia/cuda_runtime/lib/` while cu13's
lives under `nvidia/cu13/lib/`. Both dirs end up on LD_LIBRARY_PATH, so
cudnn_frontend_shim.h's probe

    for lib in ["libcudart.so.12", "libcudart.so.13"]:
        dlopen(lib)

loads both and throws:

    RuntimeError: Multiple libcudart libraries found:
    libcudart.so.12 and libcudart.so.13

Tests hit this during server setUpClass → CUDA graph capture (e.g.
test_nvfp4_gemm_sm120.py on stage-b-test-1-gpu-small). The same failure
reproduces on main, so this is not PR-specific — it's a leftover cleanup
step the cu13 migration missed.
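
The duplicate-runtime state is easy to detect in isolation. A standalone sketch (illustrative C++, not code from the repo) that mirrors the shim's probe:

    #include <dlfcn.h>
    #include <cstdio>

    int main() {
      int found = 0;
      const char* libs[] = {"libcudart.so.12", "libcudart.so.13"};
      for (const char* lib : libs) {
        if (void* handle = dlopen(lib, RTLD_LAZY)) {
          std::printf("loadable: %s\n", lib);
          ++found;
          dlclose(handle);
        }
      }
      // Two hits is the state the shim rejects with
      // "Multiple libcudart libraries found".
      return found > 1 ? 1 : 0;
    }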

Fix: uninstall nvidia-cuda-runtime-cu12 right after the main install.
Its install dir is disjoint from cu13's so the uninstall doesn't touch
any files shared with cu13 packages (a blunter sweep of all
`nvidia-*-cu12` breaks torch because several pairs share dirs under
`nvidia/<name>/lib/` and uninstalling one deletes files that the cu13
variant still references through its RECORD).

Reproduced and verified on 5090-novita-ci-runner-d (runner-1 container):

    before: libcudart.so.12 + libcudart.so.13 both loadable
    after : only libcudart.so.13 loadable, torch.cuda.randn works

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Kangyan-Zhou (Collaborator, Author)

Superseded by a fresh non-fork PR pushed to the upstream repo — lets the pr-test workflow run in parallel-dispatch mode (test_parallel_dispatch=true, skip_stage_health_check=true) without the fork-PR restrictions. The new PR will replace this one.


Labels

hicache (Hierarchical Caching for SGLang), run-ci, sgl-kernel
