Fix segfault in cudaMemcpyBatchAsync on CUDA 13.0#23136
Fix segfault in cudaMemcpyBatchAsync on CUDA 13.0#23136Kangyan-Zhou merged 8 commits intosgl-project:mainfrom
Conversation
CUDA 13.0 removed the failIdx parameter from cudaMemcpyBatchAsync, causing a segfault due to argument mismatch when called via dlsym. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces support for the updated cudaMemcpyBatchAsync signature in CUDA 13.0, which removed the failIdx parameter. The implementation uses conditional compilation to select the appropriate function signature and call site. Feedback indicates that using compile-time macros for runtime symbol loading via dlsym breaks binary portability between CUDA versions; instead, the runtime version should be checked. Additionally, a potential memory safety issue was identified where the attrs_idxs array size may not match the number of copies, leading to undefined behavior.
Replace compile-time #if CUDA_VERSION with runtime driver_version check to select the correct function signature. This ensures binary portability across CUDA 12.8 and 13.0 environments. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
Thanks @yhyang201 ! |
Port PR sgl-project#23136 (Yuhao Yang): cudaMemcpyBatchAsync lost its failIdx parameter in CUDA 13, so the dlsym-based call was passing the stream handle at the wrong slot and segfaulting inside cuMemcpyBatchAsync_v2. Use driver_version at runtime to dispatch to either the CUDA 12 or CUDA 13 signature. With the segfault fixed, move the 7 hicache tests that were parked under test/manual in PR sgl-project#23119 and subsequent cu13 flake sweeps back into test/registered so they run in CI again: - hicache/test_hicache_storage.py - hicache/test_hicache_storage_3fs_backend.py - hicache/test_hicache_storage_file_backend.py - hicache/test_hicache_storage_mooncake_backend.py - hicache/test_hicache_storage_runtime_attach_detach.py - hicache/test_hicache_variants.py - 4-gpu-models/test_qwen35_hicache.py TODO "move back after fixed" docstrings are stripped and the register_cuda_ci call that was dropped from the mooncake backend test on its way to manual is restored. Co-Authored-By: Yuhao Yang <yhyang201@users.noreply.github.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
|
||
| // CUDA 13.0 removed the failIdx parameter from cudaMemcpyBatchAsync. | ||
| // Use runtime version to select the correct signature for binary portability. | ||
| const bool use_v13_signature = driver_version >= 13000; |
There was a problem hiding this comment.
cudaMemcpyBatchAsync is a libcudart (runtime) symbol, so the ABI of the function dlsym'd into the process is owned by whichever libcudart is actually loaded — not by the host's kernel driver. cudaDriverGetVersion() reports the driver version, which in containerized setups routinely diverges from the runtime: a cu12 runtime (e.g. lmsysorg/sglang:dev, cu12.9) paired with a cu13-capable host driver is common. In that case driver_version = 13000 steers us to the 8-param v13 call, but the dlsym'd symbol is the 9-param v12 variant — the stream argument lands in a wrong slot and we segfault. Same class of crash this PR is trying to fix.
Reproduced on a cu13 host / cu12.9 container:
cudaDriverGetVersion() = 13000
cudaRuntimeGetVersion() = 12090
call with v12 dispatch -> cudaSuccess, exit 0
call with v13 dispatch -> Segmentation fault (core dumped), exit 139
The existing cudaDriverGetVersion gate on the capability check (< 12080 -> fallback) is fine — that's the right knob for "is the driver new enough to support this at all". It's just the signature selection that needs to follow the runtime.
Suggested fix:
| const bool use_v13_signature = driver_version >= 13000; | |
| // CUDA 13.0 removed the failIdx parameter from cudaMemcpyBatchAsync. The ABI | |
| // of the dlsym'd symbol is determined by the libcudart loaded in this process, | |
| // not the host driver — a cu12 runtime on a cu13 driver host (common in | |
| // containers) still exposes the 9-param v12 signature. Dispatching on the | |
| // driver version here would segfault in that case (verified empirically). | |
| // Use cudaRuntimeGetVersion so the signature follows the runtime. | |
| int runtime_version = 0; | |
| cudaError_t runtime_version_err = cudaRuntimeGetVersion(&runtime_version); | |
| if (runtime_version_err != cudaSuccess) { | |
| fallback_to_page_copy(); | |
| return; | |
| } | |
| const bool use_v13_signature = runtime_version >= 13000; |
FYI I've already applied this fix on the port of your PR in #23172 (3d3428e4f) if you want to cherry-pick. Thanks for the original fix!
Addresses the review on sgl-project#23172: sgl-project#23172 (comment) cudaMemcpyBatchAsync is a libcudart (runtime) symbol; the ABI of the function dlsym'd into this process is owned by the libcudart that's actually loaded, not by the host's kernel driver. Dispatching on cudaDriverGetVersion() breaks in the common container case where a cu12 runtime is paired with a cu13-capable host driver: driver=13000 steers us to the 8-param v13 call, but the symbol resolves to v12 (9 params with failIdx), so the stream argument lands in a wrong slot and we segfault — the exact crash this fix was supposed to prevent. Reproduced on ion-user-9 with lmsysorg/sglang:dev (cu12.9 runtime): cudaDriverGetVersion() = 13000 cudaRuntimeGetVersion() = 12090 v12 dispatch of dlsym'd symbol: cudaSuccess, exit 0 v13 dispatch of dlsym'd symbol: Segmentation fault (core dumped) Switching the signature-selection to cudaRuntimeGetVersion makes the choice follow the loaded libcudart, which is what actually determines the ABI. The existing cudaDriverGetVersion guard above is kept — it remains the right knob for the capability check since cudaMemcpyBatch requires a 12.8+ driver regardless of the runtime version. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
|
|
/tag-and-rerun-ci |
Addresses code-review feedback on the sibling PR sgl-project#23183: sgl-project#23183 (comment) The runtime version is constant for the process lifetime, so cache the cudaRuntimeGetVersion result and the derived use_v13_signature as static locals (thread-safe static init in C++11+). Keeps the KV-transfer hot path free of a redundant runtime-API call per invocation. Other diff in this commit is clang-format reflowing the v12/v13 dlsym call sites to the repo's column-limit style — no semantic change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>
Motivation
CUDA 13.0 removed the
failIdxparameter fromcudaMemcpyBatchAsync(8 params), but the code was using the CUDA 12.8 signature (9 params) viadlsym. This caused thestreamargument to be misaligned — the runtime received a stack pointer as the CUDA stream handle, resulting in a segfault insidecuMemcpyBatchAsync_v2.Fixed by using runtime
driver_versiondetection to select the correct function signature, ensuring binary portability across CUDA 12.8 and 13.0 environments.This may not be the optimal approach — the main intent of this PR is to identify the root cause of the segfault and provide a working fix.
cc @alisonshao @Fridge003
Modifications
Accuracy Tests
Environment
Result
All 192 tests in
sgl-kernel/tests/test_kvcacheio.pypass. Previously crashed attest_transfer_kv_pf_direct(~37%).Before (segfault at the first
test_transfer_kv_pf_directcase, ~37%):After (192/192 passed):
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ci