Skip to content

Fix NVLink IPC offset corruption for sub-allocated GPU tensors#1622

Merged
alogfans merged 3 commits intokvcache-ai:mainfrom
ishandhanani:idhanani/fix-nvlink-ipc-offset
Mar 12, 2026
Merged

Fix NVLink IPC offset corruption for sub-allocated GPU tensors#1622
alogfans merged 3 commits intokvcache-ai:mainfrom
ishandhanani:idhanani/fix-nvlink-ipc-offset

Conversation

@ishandhanani
Copy link
Copy Markdown
Contributor

@ishandhanani ishandhanani commented Mar 6, 2026

Problem

registerLocalMemory() in IntraNodeNvlinkTransport stores the caller-provided pointer as the buffer base address. When frameworks like PyTorch sub-allocate tensors within larger cudaMalloc segments (via caching allocators), this pointer is not the cudaMalloc base. Since cudaIpcGetMemHandle always returns a handle for the entire cudaMalloc allocation, relocateSharedMemoryAddress() computes the wrong offset on the target side -- the sub-allocation offset within the segment is lost.

This corrupts all small tensors (<1MB) that share a cudaMalloc segment. Large tensors (>1MB) are unaffected because PyTorch gives them their own segment.

Fix

Use cuMemGetAddressRange() to resolve the true cudaMalloc base address and allocation size before registration. Register at segment granularity and skip duplicates when multiple tensors share the same segment.

Testing

Reproduced with a two-process test (seed on GPU 0, target on GPU 1) transferring tensors of varying sizes (64B, 16KB, 128KB, 512KB, 4MB). Before fix: 4/6 tensors corrupted. After fix: 6/6 correct.

Also validated end-to-end with SGLang remote instance weight loading (Qwen3-235B-A22B-FP8, TP=4, NVLink transport).

Note

Claude code helped me find this bug while working on sharding weights for transfer in sgl-project/sglang#19983. But I reviewed the fix and have tried it and it works

…nsport

When framework caching allocators (PyTorch, etc.) sub-allocate tensors
within larger cudaMalloc segments, cudaIpcGetMemHandle returns a handle
for the entire segment, not the sub-allocation. The existing code stored
the sub-allocation address as the buffer base, causing relocateShared-
MemoryAddress to compute an incorrect offset on the target side. Small
tensors (<1MB) sharing a segment would read from the wrong location.

Fix: use cuMemGetAddressRange() to resolve the true cudaMalloc base
address before registration. Register at segment granularity and skip
duplicate registrations when multiple tensors share the same segment.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical issue in the NVLink IPC transport mechanism where sub-allocated GPU tensors, particularly those smaller than 1MB, were experiencing data corruption due to incorrect memory registration. The core problem stemmed from registering sub-pointers instead of the underlying cudaMalloc base addresses. The solution involves accurately identifying and registering the true cudaMalloc segments, thereby ensuring proper offset calculation and preventing data integrity issues during inter-process communication.

Highlights

  • Problem Identification: Identified and addressed a bug where registerLocalMemory in IntraNodeNvlinkTransport incorrectly used sub-allocated tensor pointers instead of the true cudaMalloc base addresses, leading to IPC offset corruption for small tensors (<1MB) when frameworks like PyTorch sub-allocate memory.
  • Resolution Strategy: Implemented a fix to use cuMemGetAddressRange() to resolve the actual cudaMalloc base address and allocation size. Memory is now registered at the segment granularity, and duplicate registrations for the same cudaMalloc segment are skipped.
  • Impact and Validation: The fix prevents corruption of sub-allocated tensors, ensuring correct data transfer. Validation included a two-process test showing 6/6 tensors correct after the fix (from 4/6 corrupted) and end-to-end validation with SGLang remote instance weight loading.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • mooncake-transfer-engine/include/transport/intranode_nvlink_transport/intranode_nvlink_transport.h
    • Included <unordered_set> for efficient tracking of registered memory blocks.
    • Added registered_base_addrs_ as a private member to store uint64_t representations of CUDA memory base pointers.
  • mooncake-transfer-engine/src/transport/intranode_nvlink_transport/intranode_nvlink_transport.cpp
    • Included <cuda.h> to access cuMemGetAddressRange.
    • In registerLocalMemory, implemented cuMemGetAddressRange to determine the actual cudaMalloc base address and allocation size for a given pointer.
    • Added a check using registered_base_addrs_ to prevent re-registering the same cudaMalloc segment, returning 0 if already registered.
    • Changed cudaIpcGetMemHandle to operate on the cudaMalloc base address instead of the sub-allocated pointer.
    • Updated the BufferDesc to store the cudaMalloc base address and its full allocation size.
    • In unregisterLocalMemory, added logic to remove the addr from registered_base_addrs_ within a mutex lock.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively addresses the NVLink IPC offset corruption by resolving and registering the true base address of CUDA memory segments, with the registerLocalMemory implementation appearing sound. However, a critical vulnerability exists in the unregistration logic: unregisterLocalMemory fails to resolve the base address of sub-allocated tensors, leading to resource leaks and preventing correct re-registration. Additionally, a race condition was identified in the unregistration flow, which could lead to an inconsistent state if metadata registration fails. These issues require immediate attention to ensure the robustness and correctness of the transport engine.

- unregisterLocalMemory: resolve base address via cuMemGetAddressRange
  before erasing from tracking set and metadata (matches register path)
- registerLocalMemory: insert into registered_base_addrs_ only after
  addLocalMemoryBuffer succeeds to avoid inconsistent state on failure
Comment on lines 18 to 19
#include <cuda.h>
#include "cuda_alike.h"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alogfans can you check this?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is double defined.

@codecov-commenter
Copy link
Copy Markdown

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

Copy link
Copy Markdown
Collaborator

@alogfans alogfans left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

relocateSharedMemoryAddress should be changed accordingly.

@stmatengss
Copy link
Copy Markdown
Collaborator

Intro-node NVLink feature is contributed by @TTThanos . PTAL. Thanks

@stmatengss
Copy link
Copy Markdown
Collaborator

relocateSharedMemoryAddress should be changed accordingly.

True. @ishandhanani, could you also modify relocateSharedMemoryAddress? Alternatively, we can implement it in the next PR.

Comment on lines +303 to +321
// Resolve the true cudaMalloc base address. Framework caching allocators
// (PyTorch, etc.) sub-allocate tensors within larger cudaMalloc segments.
// cudaIpcGetMemHandle always returns a handle for the entire segment, so
// we must register at segment granularity for correct IPC relocation.
CUdeviceptr base_ptr = 0;
size_t alloc_size = 0;
CUresult cu_err =
cuMemGetAddressRange(&base_ptr, &alloc_size, (CUdeviceptr)addr);
if (cu_err != CUDA_SUCCESS) {
LOG(ERROR) << "IntraNodeNvlinkTransport: cuMemGetAddressRange failed "
<< "for addr " << addr << " (error " << cu_err << ")";
return -1;
}

// Skip if this cudaMalloc block is already registered
if (registered_base_addrs_.count((uint64_t)base_ptr)) {
return 0;
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Learn a lot from these codes.

@stmatengss stmatengss requested a review from alogfans March 12, 2026 03:17
@stmatengss stmatengss dismissed alogfans’s stale review March 12, 2026 03:18

I will update relocate API.

@alogfans alogfans merged commit 68c1da7 into kvcache-ai:main Mar 12, 2026
17 checks passed
@TTThanos
Copy link
Copy Markdown
Contributor

Intro-node NVLink feature is contributed by @TTThanos . PTAL. Thanks

Sorry for the late response — I saw this message a bit too late. Thanks for the @alogfans 's contribution! I noticed the issue has already been addressed !

whn09 pushed a commit to whn09/Mooncake that referenced this pull request Apr 4, 2026
…he-ai#1622)

* Fix IPC offset corruption for sub-allocated GPU tensors in NVLink transport

When framework caching allocators (PyTorch, etc.) sub-allocate tensors
within larger cudaMalloc segments, cudaIpcGetMemHandle returns a handle
for the entire segment, not the sub-allocation. The existing code stored
the sub-allocation address as the buffer base, causing relocateShared-
MemoryAddress to compute an incorrect offset on the target side. Small
tensors (<1MB) sharing a segment would read from the wrong location.

Fix: use cuMemGetAddressRange() to resolve the true cudaMalloc base
address before registration. Register at segment granularity and skip
duplicate registrations when multiple tensors share the same segment.

* Address review: fix unregister and insert-before-confirm

- unregisterLocalMemory: resolve base address via cuMemGetAddressRange
  before erasing from tracking set and metadata (matches register path)
- registerLocalMemory: insert into registered_base_addrs_ only after
  addLocalMemoryBuffer succeeds to avoid inconsistent state on failure

* Update mooncake-transfer-engine/src/transport/intranode_nvlink_transport/intranode_nvlink_transport.cpp

---------

Co-authored-by: Ishan Dhanani <ishan@dhanani.dev>
Co-authored-by: Teng Ma <teng-ma@linux.alibaba.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants