
[Bug]: Mooncake EP hangs on GB200 MNNVL cluster — gdr_buffer allocated with cudaMalloc instead of cuMemCreate(FABRIC), causing 'Requested address not found' in nvlink_transport #1627

@tzulingk

Description

Environment

  • GPU: NVIDIA GB200 (p6e-gb200.36xlarge on AWS, 4 GPUs/node)
  • Cluster: Multi-node NVLink fabric (MNNVL) with Kubernetes ComputeDomain
  • Mooncake version: v0.3.9 (built from source: git clone --branch v0.3.9 https://github.com/kvcache-ai/Mooncake.git, then BUILD_WITH_EP=1 ./scripts/build_wheel.sh)
  • CUDA: 13.0.1
  • Usage: --moe-a2a-backend mooncake + --elastic-ep-backend mooncake in SGLang (EP all-to-all, not P/D disaggregation)

Bug Description

When running SGLang with --moe-a2a-backend mooncake --elastic-ep-backend mooncake on a GB200 MNNVL cluster (with a Kubernetes ComputeDomain providing the NVLink fabric), mooncake EP hangs at startup during the P2P handshake in init_torch_distributed and never recovers.


What Happens

Step 1 — The Mooncake transfer engine initializes, finds 0 HCAs (the EFA devices are not mounted in the pod), and falls back to the NVLink transport:

W topology.cpp:152] No RDMA devices found, check your device installation
I transfer_engine_impl.cpp:232] Topology discovery complete. Found 0 HCAs.
I transfer_engine_impl.cpp:253] Using cross-node NVLink transport (MC_FORCE_MNNVL or no HCA detected)
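
For reference, the empty-topology condition can be confirmed inside the pod with a standalone libibverbs probe (a minimal sketch, not Mooncake's actual topology code):

// hca_probe.cpp — minimal sketch (compile with -libverbs): counts the RDMA
// devices visible in the pod, mirroring what topology.cpp reports.
#include <infiniband/verbs.h>
#include <cstdio>

int main() {
    int num_devices = 0;
    ibv_device **list = ibv_get_device_list(&num_devices);
    if (list == nullptr || num_devices == 0) {
        // Matches "Found 0 HCAs": with no EFA devices mounted, the transfer
        // engine falls back to the cross-node NVLink transport.
        std::printf("No RDMA devices found\n");
    } else {
        for (int i = 0; i < num_devices; ++i)
            std::printf("HCA %d: %s\n", i, ibv_get_device_name(list[i]));
    }
    if (list != nullptr) ibv_free_device_list(list);
    return 0;
}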

Step 2 — EP communication buffers (gdr_buffer, allocated via cudaMalloc in mooncake_ep_buffer.cpp:47) are registered with the NVLink transport. Since they are not cuMemCreate(CU_MEM_HANDLE_TYPE_FABRIC) allocations, they are flagged as local-only:

W nvlink_transport.cpp:347] Memory region 0x4216d660 is not allocated by cuMemCreate, but it can be used as local buffer
W nvlink_transport.cpp:347] Memory region 0x4216d770 is not allocated by cuMemCreate, but it can be used as local buffer
W nvlink_transport.cpp:347] Memory region 0x4216d880 is not allocated by cuMemCreate, but it can be used as local buffer
W nvlink_transport.cpp:347] Memory region 0x4216d990 is not allocated by cuMemCreate, but it can be used as local buffer
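
The "not allocated by cuMemCreate" classification can be reproduced with a small driver-API probe (a sketch of the idea, not the transport's actual check): cuMemRetainAllocationHandle() only succeeds for VMM allocations created through cuMemCreate/cuMemMap, so a plain cudaMalloc pointer yields no handle that could be exported over the fabric:

// vmm_probe.cpp — sketch: why a cudaMalloc buffer ends up local-only.
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    void *gdr_buffer = nullptr;
    cudaMalloc(&gdr_buffer, 1 << 20);  // plain runtime allocation, as in mooncake_ep_buffer.cpp:47

    CUmemGenericAllocationHandle handle;
    CUresult rc = cuMemRetainAllocationHandle(&handle, gdr_buffer);
    if (rc != CUDA_SUCCESS) {
        // Expected for cudaMalloc: no VMM allocation handle exists, so there
        // is nothing to export as CU_MEM_HANDLE_TYPE_FABRIC — hence the
        // "can be used as local buffer" warning above.
        std::printf("no allocation handle (rc=%d): local-only buffer\n", rc);
    } else {
        cuMemRelease(handle);
    }
    cudaFree(gdr_buffer);
    return 0;
}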

Step 3 — Remote nodes attempt cross-node NVLink access to these buffers. Since there is no fabric handle, the address lookup fails:

E nvlink_transport.cpp:497] Requested address 0x4216d998 to 0x4216d99c not found!
E nvlink_transport.cpp:497] Requested address 0x3c0d6fe8 to 0x3c0d6fec not found!
E nvlink_transport.cpp:497] Requested address 0x4216d99c to 0x4216d9a0 not found!
E nvlink_transport.cpp:497] Requested address 0x4216d994 to 0x4216d998 not found!
E nvlink_transport.cpp:497] Requested address 0x3c0d6fec to 0x3c0d6ff0 not found!
E nvlink_transport.cpp:497] Requested address 0x1519299c to 0x151929a0 not found!
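
For context, cross-node NVLink access requires the owner to export a CUmemFabricHandle and each peer to import and map it; without a fabric-backed allocation there is nothing to import, so the peer's address table has no entry covering these ranges. A sketch of the peer-side mapping (illustrative names, not Mooncake's actual code; assumes the handle arrived over the metadata channel):

#include <cuda.h>
#include <cstddef>

// Illustrative sketch: map a remote fabric allocation from the
// CUmemFabricHandle its owner exported and sent out-of-band.
CUdeviceptr map_remote_fabric_buffer(CUmemFabricHandle fabric_handle,
                                     size_t size, int local_device) {
    CUmemGenericAllocationHandle handle;
    cuMemImportFromShareableHandle(&handle, &fabric_handle,
                                   CU_MEM_HANDLE_TYPE_FABRIC);

    CUdeviceptr va = 0;
    cuMemAddressReserve(&va, size, /*alignment=*/0, /*addr=*/0, /*flags=*/0);
    cuMemMap(va, size, /*offset=*/0, handle, /*flags=*/0);

    CUmemAccessDesc access = {};
    access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access.location.id = local_device;
    access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemSetAccess(va, size, &access, 1);
    return va;  // reads/writes through va traverse the NVLink fabric
}

Because gdr_buffer was never created through this path, these lookups fail and the transfer can never be posted.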

Step 4 — The process hangs indefinitely. All 16 TP ranks are stuck in a collective, waiting for a P2P handshake that never completes. After these six error lines at T+34s, the log goes completely silent; the pod ran for 3h52m with no progress and no crash before being manually deleted.


Root Cause

The Kubernetes ComputeDomain (NVLink fabric) is present and functional. The NVLink transport's own internal buffers (via allocatePinnedLocalMemory() in nvlink_transport.cpp:516-598) correctly use cuMemCreate(CU_MEM_HANDLE_TYPE_FABRIC).

However, mooncake EP's main communication buffer gdr_buffer in mooncake_ep_buffer.cpp:47 is allocated with plain cudaMalloc:

CUDA_CHECK(cudaMalloc(&gdr_buffer, num_ep_buffer_bytes));  // no fabric handle

This buffer is then registered for RDMA via ibv_reg_mr(), which only works when InfiniBand devices are present. When 0 HCAs are found and the NVLink transport is selected instead, the cudaMalloc'd gdr_buffer cannot participate in cross-node NVLink fabric transfers because it has no CU_MEM_HANDLE_TYPE_FABRIC handle. Remote nodes cannot resolve its address, so the handshake hangs.


Why InfiniBand works fine

On H100 clusters with InfiniBand (e.g., EOS/DGX), mooncake EP works correctly because:

  1. ibv_reg_mr() successfully registers the cudaMalloc buffer with the IB driver
  2. RDMA can transfer any GPU-pinned memory regardless of how it was allocated
  3. NVLink transport is never selected (HCAs found)

The MNNVL/NVLink path is only taken when HCAs = 0.
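
Concretely, the registration that succeeds on the IB path looks like this (a sketch under a standard GPUDirect RDMA setup with nvidia-peermem loaded; names are illustrative):

#include <infiniband/verbs.h>
#include <cuda_runtime.h>
#include <cstddef>

// Illustrative sketch: with GPUDirect RDMA, the HCA can register device
// memory directly, so a plain cudaMalloc buffer is sufficient — the NIC,
// not the CUDA VMM fabric, moves the bytes across nodes.
ibv_mr *register_gdr_buffer(ibv_pd *pd, size_t num_ep_buffer_bytes) {
    void *gdr_buffer = nullptr;
    cudaMalloc(&gdr_buffer, num_ep_buffer_bytes);
    return ibv_reg_mr(pd, gdr_buffer, num_ep_buffer_bytes,
                      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
                          IBV_ACCESS_REMOTE_WRITE);
}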


Proposed Fix

In mooncake_ep_buffer.cpp, when NVLink fabric is available (detectable via supportFabricMem() from nvlink_transport.cpp), allocate gdr_buffer using cuMemCreate(CU_MEM_HANDLE_TYPE_FABRIC) instead of cudaMalloc, so it gets a fabric handle and can be accessed cross-node.

This is analogous to what allocatePinnedLocalMemory() already does in nvlink_transport.cpp.
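
A minimal sketch of the fabric-aware allocation (illustrative, modeled on the allocatePinnedLocalMemory() pattern; device_id is an assumed parameter, and error handling is elided):

#include <cuda.h>
#include <cstddef>

// Illustrative sketch, not a tested patch: fabric-capable replacement for
// the cudaMalloc at mooncake_ep_buffer.cpp:47. 'device_id' (the GPU owning
// the buffer) is an assumption.
void *alloc_fabric_gdr_buffer(int device_id, size_t num_ep_buffer_bytes) {
    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = device_id;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;  // exportable over MNNVL

    // VMM allocations must be a multiple of the minimum granularity.
    size_t gran = 0;
    cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
    size_t padded = (num_ep_buffer_bytes + gran - 1) / gran * gran;

    CUmemGenericAllocationHandle handle;
    cuMemCreate(&handle, padded, &prop, /*flags=*/0);  // carries a fabric handle

    CUdeviceptr va = 0;
    cuMemAddressReserve(&va, padded, gran, /*addr=*/0, /*flags=*/0);
    cuMemMap(va, padded, /*offset=*/0, handle, /*flags=*/0);

    CUmemAccessDesc access = {};
    access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access.location.id = device_id;
    access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemSetAccess(va, padded, &access, 1);

    return reinterpret_cast<void *>(va);  // register this with the NVLink transport
}

When supportFabricMem() is false (i.e., HCAs are present), the existing cudaMalloc + ibv_reg_mr path can remain unchanged.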


Related Issues

This appears to be the same underlying gap: mooncake EP does not yet allocate its communication buffers with fabric handles for MNNVL use cases.
