Skip to content

Commit 3d3428e

Browse files
Kangyan-Zhouclaude
andcommitted
Use cudaRuntimeGetVersion for cudaMemcpyBatchAsync ABI dispatch
Addresses the review on #23172: #23172 (comment) cudaMemcpyBatchAsync is a libcudart (runtime) symbol; the ABI of the function dlsym'd into this process is owned by the libcudart that's actually loaded, not by the host's kernel driver. Dispatching on cudaDriverGetVersion() breaks in the common container case where a cu12 runtime is paired with a cu13-capable host driver: driver=13000 steers us to the 8-param v13 call, but the symbol resolves to v12 (9 params with failIdx), so the stream argument lands in a wrong slot and we segfault — the exact crash this fix was supposed to prevent. Reproduced on ion-user-9 with lmsysorg/sglang:dev (cu12.9 runtime): cudaDriverGetVersion() = 13000 cudaRuntimeGetVersion() = 12090 v12 dispatch of dlsym'd symbol: cudaSuccess, exit 0 v13 dispatch of dlsym'd symbol: Segmentation fault (core dumped) Switching the signature-selection to cudaRuntimeGetVersion makes the choice follow the loaded libcudart, which is what actually determines the ABI. The existing cudaDriverGetVersion guard above is kept — it remains the right knob for the capability check since cudaMemcpyBatch requires a 12.8+ driver regardless of the runtime version. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent eba2695 commit 3d3428e

1 file changed

Lines changed: 13 additions & 3 deletions

File tree

sgl-kernel/csrc/kvcacheio/transfer.cu

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -812,9 +812,19 @@ inline void transfer_kv_page_first_direct_impl(
812812
return;
813813
}
814814

815-
// CUDA 13.0 removed the failIdx parameter from cudaMemcpyBatchAsync.
816-
// Use runtime version to select the correct signature for binary portability.
817-
const bool use_v13_signature = driver_version >= 13000;
815+
// CUDA 13.0 removed the failIdx parameter from cudaMemcpyBatchAsync. The ABI
816+
// of the dlsym'd symbol is determined by the libcudart loaded in this process,
817+
// not the host driver — a cu12 runtime on a cu13 driver host (common in
818+
// containers) still exposes the 9-param v12 signature. Dispatching on the
819+
// driver version here would segfault in that case (verified empirically).
820+
// Use cudaRuntimeGetVersion so the signature follows the runtime.
821+
int runtime_version = 0;
822+
cudaError_t runtime_version_err = cudaRuntimeGetVersion(&runtime_version);
823+
if (runtime_version_err != cudaSuccess) {
824+
fallback_to_page_copy();
825+
return;
826+
}
827+
const bool use_v13_signature = runtime_version >= 13000;
818828

819829
size_t num_copies = 0;
820830
std::vector<void*> batch_srcs;

0 commit comments

Comments
 (0)