
Fix MNNVL warmup hang: skip warmup when fabric mem is available#1644

Merged
UNIDY2002 merged 3 commits into kvcache-ai:main from he-yufeng:fix/mnnvl-warmup-hang
Mar 14, 2026

Conversation

@he-yufeng
Contributor

Summary

Fixes #1639. Follow-up to #1629 (EP buffer fix).

On GB200 MNNVL clusters, ConnectionContext warmup buffers are allocated from the CPU heap (new int32_t[]). The NVLink transport can only access cuMemCreate(CU_MEM_HANDLE_TYPE_FABRIC) memory cross-node, so remote writes to these heap buffers silently fail and waitUntilAllConnected() blocks forever.

Fix

When supportFabricMem() is true (MNNVL cluster), skip the warmup handshake entirely and transition directly to CONNECTED after the store key exchange + openSegment(). The fabric topology already guarantees connectivity between all peers in a ComputeDomain, so the warmup write is redundant.

This is Option A from the issue — simpler than allocating warmup buffers with cuMemCreate(FABRIC) (Option B), and avoids the complexity of fabric handle exchange for what is essentially a connectivity check that the fabric infrastructure already guarantees.

Changes

  • connection_poller.cpp: Add supportFabricMem() (same check as nvlink_transport.cpp and mooncake_ep_buffer.cpp)
  • connection_poller.cpp: Skip buffer allocation and warmup write when skip_warmup_ is true, go straight to CONNECTED
  • connection_poller.h: Add skip_warmup_ member
  • setup.py: Link the CUDA driver library for cuDeviceGetAttribute

Testing

Needs verification on an MNNVL cluster (GB200 NVSwitch domain). IB clusters are unaffected — supportFabricMem() returns false, so the warmup path is unchanged.

cc @tzulingk @UNIDY2002

On GB200 MNNVL clusters, the warmup handshake in ConnectionContext
allocates send/recv buffers from the CPU heap. The NVLink transport
can only access cuMemCreate(CU_MEM_HANDLE_TYPE_FABRIC) memory
cross-node, so remote writes to these heap buffers silently fail
and the state machine retries forever.

Since MNNVL fabric guarantees connectivity between all peers in a
ComputeDomain, we can safely skip the warmup write entirely when
supportFabricMem() is true. The store key exchange alone is
sufficient proof of peer reachability.

Changes:
- Add supportFabricMem() to connection_poller.cpp (same check as
  nvlink_transport.cpp and mooncake_ep_buffer.cpp)
- Skip warmup buffer allocation and warmup write when on MNNVL,
  transition directly to CONNECTED after opening the segment
- Guard destructor against null warmup buffers
- Link CUDA driver library in setup.py for cuDeviceGetAttribute

Fixes kvcache-ai#1639
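The "guard destructor against null warmup buffers" item above can be sketched as a minimal, self-contained model. The member names follow the diff (warmup_send_region_, warmup_recv_region_); the class around them is reduced to the allocation/teardown path and is illustrative, not the actual implementation:

```cpp
#include <cassert>
#include <cstdint>

// Minimal model: when warmup is skipped (MNNVL), both regions stay nullptr,
// and the destructor must tolerate that.
struct WarmupBuffers {
    int32_t* warmup_send_region_ = nullptr;
    int32_t* warmup_recv_region_ = nullptr;

    explicit WarmupBuffers(bool skip_warmup) {
        if (!skip_warmup) {
            warmup_send_region_ = new int32_t[8]{};
            warmup_recv_region_ = new int32_t[8]{};
        }
    }

    ~WarmupBuffers() {
        // delete[] on a null pointer is well-defined and does nothing, so
        // the skip-warmup path needs no extra branch here.
        delete[] warmup_send_region_;
        delete[] warmup_recv_region_;
    }
};
```

Because `delete[]` is a no-op on null, the guard costs nothing on the normal path.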
@codecov-commenter

⚠️ Please install the Codecov app to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

✅ All modified and coverable lines are covered by tests.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a deadlock issue encountered on MNNVL clusters during the ConnectionContext warmup phase. It optimizes the connection process by bypassing the warmup handshake when fabric memory support is detected, thus preventing failed remote writes to CPU heap buffers. Additionally, it ensures the CUDA driver library is linked for proper device attribute retrieval.

Highlights

  • MNNVL Warmup Fix: This PR addresses a hang on GB200 MNNVL clusters by skipping the ConnectionContext warmup handshake when fabric memory is available. This avoids issues with remote writes to CPU heap buffers, which are not fabric-accessible.
  • Fabric Topology Guarantee: The fix leverages the fabric topology, which already guarantees connectivity between all peers in a ComputeDomain, making the warmup write redundant.
  • CUDA Driver Library Linking: The PR also links the CUDA driver library in setup.py to enable cuDeviceGetAttribute.


Changelog
  • connection_poller.cpp
    • Added supportFabricMem() function to check for fabric memory support.
    • Modified ConnectionContext constructor to skip warmup when skip_warmup_ is true.
    • Modified pollPeer() to skip warmup write when skip_warmup_ is true.
  • connection_poller.h
    • Added skip_warmup_ member to ConnectionContext class.
  • setup.py
    • Linked cuda driver library for cuDeviceGetAttribute.

@UNIDY2002
Collaborator

The CI environment seems to have issues with CUDA linking. You might want to take a look at how it's done in #1629's mooncake-ep/setup.py and follow the same approach.



Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request aims to resolve a hang issue on MNNVL clusters by skipping the warmup handshake when fabric memory is available. However, a critical vulnerability has been identified: a null pointer dereference in the CONNECTED state handling within pollPeer. This occurs because the logic fails to account for null pointers in the connection failure recovery path, leading to an application crash if a peer disconnects when warmup is skipped. Additionally, the implementation could be improved by addressing code duplication in supportFabricMem and pollPeer, and by adding error checking for a CUDA driver call.

// remote NVLink writes to them will fail. The fabric topology already
// guarantees connectivity, so we skip the warmup handshake entirely.
warmup_send_region_ = nullptr;
warmup_recv_region_ = nullptr;
Contributor


Severity: high (security)

A null pointer dereference occurs here when skip_warmup_ is true. In the constructor, warmup_recv_region_ is set to nullptr if skip_warmup_ is enabled. However, the CONNECTED state logic in pollPeer (lines 304-306) attempts to reset this region by dereferencing the pointer without a null check. If a connection failure is detected (e.g., a peer disconnects), the poller thread will crash with a segmentation fault, leading to a Denial of Service.

To fix this, ensure that the warmup region is only reset if it was actually allocated.
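A minimal sketch of the suggested guard, assuming the member names from the diff (skip_warmup_, warmup_recv_region_); the surrounding class and the reset logic are simplified to the point the review is about:

```cpp
#include <cassert>
#include <cstdint>

// Simplified model of the poller's per-connection context.
struct ConnectionContext {
    bool skip_warmup_;
    int32_t* warmup_recv_region_;

    explicit ConnectionContext(bool skip)
        : skip_warmup_(skip),
          warmup_recv_region_(skip ? nullptr : new int32_t[4]{}) {}
    ~ConnectionContext() { delete[] warmup_recv_region_; }

    // CONNECTED-state failure recovery: reset the warmup region only if it
    // was actually allocated, avoiding the null dereference described above.
    void resetWarmupRegionOnFailure() {
        if (warmup_recv_region_ != nullptr) {
            warmup_recv_region_[0] = 0;
        }
        // When skip_warmup_ is set there is nothing to reset.
    }
};
```

With the null check, a peer-disconnect event on the skip-warmup path becomes a no-op instead of a segfault in the poller thread.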

Comment on lines +22 to +36
static bool supportFabricMem() {
    if (getenv("MC_USE_NVLINK_IPC")) return false;

    int num_devices = 0;
    cudaError_t err = cudaGetDeviceCount(&num_devices);
    if (err != cudaSuccess || num_devices == 0) return false;

    for (int dev = 0; dev < num_devices; ++dev) {
        int supported = 0;
        cuDeviceGetAttribute(
            &supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev);
        if (!supported) return false;
    }
    return true;
}
Contributor


Severity: medium

This function is a good addition. I have two suggestions for improvement:

  1. Code Duplication: As noted in the PR description, this function is duplicated across a few files. To improve long-term maintainability, it would be ideal to move this to a shared utility file (e.g., cuda_utils.h/.cpp) to avoid having to update it in multiple places in the future.

  2. Error Handling: The return value from cuDeviceGetAttribute is not checked. It's good practice to check for errors from CUDA driver API calls to ensure robustness.

Here's a suggestion that adds the error check:

static bool supportFabricMem() {
    if (getenv("MC_USE_NVLINK_IPC")) return false;

    int num_devices = 0;
    cudaError_t err = cudaGetDeviceCount(&num_devices);
    if (err != cudaSuccess || num_devices == 0) return false;

    for (int dev = 0; dev < num_devices; ++dev) {
        int supported = 0;
        CUresult cu_err = cuDeviceGetAttribute(
            &supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev);
        if (cu_err != CUDA_SUCCESS) {
            // Consider logging a warning here for easier debugging.
            return false;
        }
        if (!supported) return false;
    }
    return true;
}

Comment on lines +176 to +188
if (skip_warmup_) {
    // MNNVL: fabric guarantees connectivity, skip warmup write
    // since CPU heap buffers aren't fabric-accessible anyway.
    meta_->peerConnected[pollingRank] = true;
    global_peerConnected_[globalPollingRank] = true;
    peerState.state = PeerConnectionState::CONNECTED;
    {
        std::lock_guard<std::mutex> lock(backend_wakeup_mutex_);
        totalConnectedPeers_.fetch_add(1,
                                       std::memory_order_release);
        if (isAllPeerConnected()) backend_wakeup_cv_.notify_all();
    }
} else if (pollingRank <= rank_) {
Contributor


Severity: medium

This block of code to mark a peer as connected is also present in the state handlers for WAITING_WARMUP_TRANSFER (lines 218-231) and WAITING_PEER_WARMUP (lines 249-258). To reduce code duplication and improve readability, consider extracting this logic into a private helper method within the ConnectionContext class, for example void markPeerAsConnected(int pollingRank, uint64_t globalPollingRank).
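The suggested extraction could look roughly like this. It is a self-contained model, not the actual class: member names follow the diff (meta_->peerConnected, global_peerConnected_, totalConnectedPeers_, backend_wakeup_mutex_, backend_wakeup_cv_), but the surrounding state is reduced to what the helper touches, and Meta is a hypothetical stand-in for the shared metadata:

```cpp
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <vector>

// Hypothetical stand-in for the poller's shared metadata.
struct Meta {
    std::vector<bool> peerConnected;
};

class ConnectionContext {
public:
    explicit ConnectionContext(int numPeers)
        : meta_{std::vector<bool>(numPeers, false)},
          global_peerConnected_(numPeers, false),
          num_peers_(numPeers) {}

    // Suggested helper: one place that marks a peer CONNECTED, bumps the
    // counter, and wakes the backend once everyone is connected. The three
    // state handlers would all call this instead of duplicating the block.
    void markPeerAsConnected(int pollingRank, uint64_t globalPollingRank) {
        meta_.peerConnected[pollingRank] = true;
        global_peerConnected_[globalPollingRank] = true;
        std::lock_guard<std::mutex> lock(backend_wakeup_mutex_);
        totalConnectedPeers_.fetch_add(1, std::memory_order_release);
        if (isAllPeerConnected()) backend_wakeup_cv_.notify_all();
    }

    bool isAllPeerConnected() const {
        return totalConnectedPeers_.load(std::memory_order_acquire) ==
               static_cast<uint64_t>(num_peers_);
    }

private:
    Meta meta_;
    std::vector<bool> global_peerConnected_;
    std::atomic<uint64_t> totalConnectedPeers_{0};
    std::mutex backend_wakeup_mutex_;
    std::condition_variable backend_wakeup_cv_;
    int num_peers_;
};
```

The per-state differences (e.g. setting peerState.state) can stay at the call sites; only the shared bookkeeping moves into the helper.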

Use CUDA_HOME/lib64/stubs/libcuda.so stub path, same approach as
mooncake-ep/setup.py. The CI environment doesn't have libcuda.so
in the standard library search path.
@he-yufeng
Contributor Author

Good catch, thanks! Fixed in 2761535 — now using the same CUDA_HOME/lib64/stubs/libcuda.so approach as mooncake-ep/setup.py.

Collaborator

@UNIDY2002 left a comment


Thank you very much for your work!

@UNIDY2002 UNIDY2002 merged commit 1aeeffc into kvcache-ai:main Mar 14, 2026
16 of 17 checks passed
whn09 pushed a commit to whn09/Mooncake that referenced this pull request Apr 4, 2026
Fix MNNVL warmup hang: skip warmup when fabric mem is available (kvcache-ai#1644)


Development

Successfully merging this pull request may close these issues.

[Bug] MNNVL hang after #1629: mooncake-pg warmup buffers (CPU heap) are not NVLink fabric-accessible
