Skip to content

Fix Cuda Ipc Tuto#4251

Merged
samnordmann merged 12 commits into main from
cuda_ipc_tuto
Apr 16, 2025
Merged

Fix Cuda Ipc Tuto#4251
samnordmann merged 12 commits into main from
cuda_ipc_tuto

Conversation

@samnordmann
Copy link
Collaborator

Fix #3912 after it has been reverted by #4248

@samnordmann
Copy link
Collaborator Author

!test

@github-actions
Copy link

github-actions bot commented Apr 15, 2025

Review updated until commit 5390dbd

Description

  • Added CUDA IPC tests for multi-device communication

  • Guarded TCPStore method with NVFUSER_DISTRIBUTED

  • Skipped tests for single device

  • Removed explicit linking with CUDA


Changes walkthrough 📝

Relevant files
Tests
test_multidevice_ipc.cpp
Added CUDA IPC tests                                                                         

tests/cpp/test_multidevice_ipc.cpp

  • Added new tests for CUDA IPC memory handle operations
  • Included necessary headers and namespace
  • Implemented toBytes and fromBytes utility functions
  • Added tests for pointer arithmetic on sender and receiver sides
  • +218/-0 
    Configuration changes
    CMakeLists.txt
    Update CMakeLists.txt for new test                                             

    CMakeLists.txt

    • Added test_multidevice_ipc.cpp to the list of test sources
    +1/-0     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Pointer Arithmetic

    The PR introduces tests for CUDA IPC with pointer arithmetic. Ensure that the pointer arithmetic logic is correct and that the tests accurately reflect the expected behavior on both the sender and receiver sides.

    communicator_->barrier();
    
    // Import Ipc Handle
    auto peer_ipc_handle = fromBytes<cudaIpcMemHandle_t>(
        store->get("ipc_handle_" + std::to_string(peer_rank)));
    int64_t* peer_d_ptr;
    NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
        (void**)&peer_d_ptr, peer_ipc_handle, cudaIpcMemLazyEnablePeerAccess));
    
    // Validate, by reading the second value in the buffer (c.f. the "+1" offset)
    int64_t peer_value;
    NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
        &peer_value, peer_d_ptr + 1, kBufferSize / 2, cudaMemcpyDeviceToHost));
    EXPECT_EQ(2 * peer_rank + 1, peer_value);
    
    Error Handling

    Verify that the tests handle potential errors gracefully, such as when cudaIpcGetMemHandle or cudaIpcOpenMemHandle fail.

    cudaIpcMemHandle_t ipc_handle;
    NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcGetMemHandle(&ipc_handle, d_ptr));
    // As a convenience, we use the TCP store to exchange out-of-band the IPC
    // handle as raw data
    auto store = communicator_->getTcpStore();
    store->set("ipc_handle_" + std::to_string(rank), toBytes(ipc_handle));
    
    // Wait for all ranks to finish exporting the IPC handle
    communicator_->barrier();
    
    // Import Ipc Handle
    auto peer_ipc_handle = fromBytes<cudaIpcMemHandle_t>(
        store->get("ipc_handle_" + std::to_string((rank + 1) % num_devices)));
    void* peer_d_ptr;
    NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
        &peer_d_ptr, peer_ipc_handle, cudaIpcMemLazyEnablePeerAccess));
    
    Test Coverage

    Ensure that the tests cover all relevant scenarios, including edge cases like the smallest and largest possible buffer sizes, and verify that the tests are comprehensive enough to catch regressions.

    // Basic CUDA IPC round trip: each rank allocates a one-int64_t device
    // buffer holding its own rank id, exports it with cudaIpcGetMemHandle,
    // publishes the raw handle bytes through the TCP store, then imports the
    // next rank's buffer and verifies its content. Skipped when running on a
    // single device or when NVFUSER_DISTRIBUTED is not defined.
    TEST_F(IpcTest, IpcMemHandle) {
      if (communicator_->size() == 1) {
        GTEST_SKIP() << "Skipping test for single device";
      }
    #ifdef NVFUSER_DISTRIBUTED
      // Allocate and setup GPU buffers
      constexpr size_t kBufferSize = sizeof(int64_t);
      const int64_t num_devices = communicator_->size();
      const int64_t rank = communicator_->deviceId();
    
      // Bind this process to its own device before any CUDA calls below.
      NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(rank));
    
      void* d_ptr;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_ptr, kBufferSize));
      const int64_t value = rank;
      NVFUSER_CUDA_RT_SAFE_CALL(
          cudaMemcpy(d_ptr, &value, kBufferSize, cudaMemcpyHostToDevice));
    
      // Export Ipc Handle
      cudaIpcMemHandle_t ipc_handle;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcGetMemHandle(&ipc_handle, d_ptr));
      // As a convenience, we use the TCP store to exchange out-of-band the IPC
      // handle as raw data
      auto store = communicator_->getTcpStore();
      store->set("ipc_handle_" + std::to_string(rank), toBytes(ipc_handle));
    
      // Wait for all ranks to finish exporting the IPC handle
      communicator_->barrier();
    
      // Import Ipc Handle
      auto peer_ipc_handle = fromBytes<cudaIpcMemHandle_t>(
          store->get("ipc_handle_" + std::to_string((rank + 1) % num_devices)));
      void* peer_d_ptr;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
          &peer_d_ptr, peer_ipc_handle, cudaIpcMemLazyEnablePeerAccess));
    
      // Validate: the peer wrote its own rank id into its buffer.
      int64_t peer_value;
      NVFUSER_CUDA_RT_SAFE_CALL(
          cudaMemcpy(&peer_value, peer_d_ptr, kBufferSize, cudaMemcpyDeviceToHost));
      EXPECT_EQ((value + 1) % num_devices, peer_value);
    
      // Clean up: close our imported mapping first, then barrier so every
      // rank's exported buffer is unmapped everywhere before its owner frees
      // it. Without the barrier, a fast rank could cudaFree memory a peer is
      // still reading through an open IPC mapping.
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr));
      communicator_->barrier();
      NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr));
    #else // NVFUSER_DISTRIBUTED
      GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined";
    #endif // NVFUSER_DISTRIBUTED
    }
    
    // Shows that an imported IPC pointer behaves like a regular device pointer
    // on the importer side: offsetting it with ordinary pointer arithmetic
    // addresses the expected element of the peer's buffer. Skipped when
    // running on a single device or when NVFUSER_DISTRIBUTED is not defined.
    TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtReceiver) {
      if (communicator_->size() == 1) {
        GTEST_SKIP() << "Skipping test for single device";
      }
    #ifdef NVFUSER_DISTRIBUTED
      // TL;DR: We can do pointer arithmetic on the importer side. IOW, the pointer
      // can be used as a regular pointer on the importer side.
    
      // Allocate GPU memory. Set up a buffer with two int values.
      constexpr size_t kBufferSize = 2 * sizeof(int64_t);
      const int64_t num_devices = communicator_->size();
      const int64_t rank = communicator_->deviceId();
      const int64_t peer_rank = (rank + 1) % num_devices;
    
      // Bind this process to its own device before any CUDA calls below.
      NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(rank));
    
      void* d_ptr;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_ptr, kBufferSize));
    
      // Set up the buffer: {2 * rank, 2 * rank + 1} so each slot is unique.
      std::vector<int64_t> values;
      values.push_back(2 * rank);
      values.push_back(2 * rank + 1);
      NVFUSER_CUDA_RT_SAFE_CALL(
          cudaMemcpy(d_ptr, values.data(), kBufferSize, cudaMemcpyHostToDevice));
    
      // Export Ipc Handle
      cudaIpcMemHandle_t ipc_handle;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcGetMemHandle(&ipc_handle, d_ptr));
      auto store = communicator_->getTcpStore();
      store->set("ipc_handle_" + std::to_string(rank), toBytes(ipc_handle));
    
      // Wait for all ranks to finish exporting the IPC handle
      communicator_->barrier();
    
      // Import Ipc Handle
      auto peer_ipc_handle = fromBytes<cudaIpcMemHandle_t>(
          store->get("ipc_handle_" + std::to_string(peer_rank)));
      int64_t* peer_d_ptr;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
          (void**)&peer_d_ptr, peer_ipc_handle, cudaIpcMemLazyEnablePeerAccess));
    
      // Validate, by reading the second value in the buffer (c.f. the "+1" offset)
      int64_t peer_value;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
          &peer_value, peer_d_ptr + 1, kBufferSize / 2, cudaMemcpyDeviceToHost));
      EXPECT_EQ(2 * peer_rank + 1, peer_value);
    
      // Clean up: close our imported mapping first, then barrier so every
      // rank's exported buffer is unmapped everywhere before its owner frees
      // it. Without the barrier, a fast rank could cudaFree memory a peer is
      // still reading through an open IPC mapping.
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr));
      communicator_->barrier();
      NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr));
    #else // NVFUSER_DISTRIBUTED
      GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined";
    #endif // NVFUSER_DISTRIBUTED
    }
    
    // Shows that pointer arithmetic does NOT survive export: the handle
    // obtained from cudaIpcGetMemHandle(d_ptr + 1) still maps the base of the
    // underlying allocation on the importer side, so the importer reads the
    // first element, not the second. Skipped when running on a single device
    // or when NVFUSER_DISTRIBUTED is not defined.
    TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtSender) {
      if (communicator_->size() == 1) {
        GTEST_SKIP() << "Skipping test for single device";
      }
    #ifdef NVFUSER_DISTRIBUTED
      // TL;DR: We CANNOT do pointer arithmetic on the exporter side! The IPC handle
      // points to the beginning of the allocated buffer.
    
      // Allocate GPU memory. Set up a buffer with two int values.
      constexpr size_t kBufferSize = 2 * sizeof(int64_t);
      const int64_t num_devices = communicator_->size();
      const int64_t rank = communicator_->deviceId();
      const int64_t peer_rank = (rank + 1) % num_devices;
    
      // Bind this process to its own device before any CUDA calls below.
      NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(rank));
    
      int64_t* d_ptr;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_ptr, kBufferSize));
    
      // Set up the buffer: {2 * rank, 2 * rank + 1} so each slot is unique.
      std::vector<int64_t> values;
      values.push_back(2 * rank);
      values.push_back(2 * rank + 1);
      NVFUSER_CUDA_RT_SAFE_CALL(
          cudaMemcpy(d_ptr, values.data(), kBufferSize, cudaMemcpyHostToDevice));
    
      // Export Ipc Handle, deliberately from an OFFSET pointer (d_ptr + 1).
      cudaIpcMemHandle_t ipc_handle;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcGetMemHandle(&ipc_handle, d_ptr + 1));
      auto store = communicator_->getTcpStore();
      store->set("ipc_handle_" + std::to_string(rank), toBytes(ipc_handle));
    
      // Wait for all ranks to finish exporting the IPC handle
      communicator_->barrier();
    
      // Import Ipc Handle
      auto peer_ipc_handle = fromBytes<cudaIpcMemHandle_t>(
          store->get("ipc_handle_" + std::to_string(peer_rank)));
      int64_t* peer_d_ptr;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
          (void**)&peer_d_ptr, peer_ipc_handle, cudaIpcMemLazyEnablePeerAccess));
    
      // Validate, noticing that the pointer is not offset by 1, contrarily to the
      // offset used in the exporter side.
      int64_t peer_value;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
          &peer_value, peer_d_ptr, kBufferSize / 2, cudaMemcpyDeviceToHost));
      EXPECT_EQ(
          2 * peer_rank,
          peer_value); // and not 2 * peer_rank + 1 as could be expected!
    
      // Clean up: close our imported mapping first, then barrier so every
      // rank's exported buffer is unmapped everywhere before its owner frees
      // it. Without the barrier, a fast rank could cudaFree memory a peer is
      // still reading through an open IPC mapping.
      NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr));
      communicator_->barrier();
      NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr));
    #else // NVFUSER_DISTRIBUTED
      GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined";
    #endif // NVFUSER_DISTRIBUTED
    }

    @samnordmann samnordmann requested a review from wujingyue April 15, 2025 14:12
    @naoyam
    Copy link
    Collaborator

    naoyam commented Apr 15, 2025

    Looks like some of the IPC tests are failing.

    @samnordmann
    Copy link
    Collaborator Author

    Looks like some of the IPC tests are failing.

    Oh, I see, thanks! We need to skip the test for single device. I'm pushing a patch.

    @samnordmann
    Copy link
    Collaborator Author

    !test

    @wujingyue wujingyue changed the base branch from main to wjy/rollback April 15, 2025 20:42
    @wujingyue wujingyue changed the base branch from wjy/rollback to wjy/rollforward April 15, 2025 20:44
    @wujingyue
    Copy link
    Collaborator

    LGTM! Thanks for the fix. Next time, you can get a faster review by rebasing this PR on the revert of the rollback and then changing the base of this PR to that. This would show the diff between version 1 and version 2.

    For the record, the main diff from the first version is:

    $ git diff 8f11fb5d40145eeeec103d19d03dff93288cbc22..cuda_ipc_tuto
    <lots of noise>
    
    diff --git a/tests/cpp/test_multidevice_ipc.cpp b/tests/cpp/test_multidevice_ipc.cpp
    index 6ac373f9..30daf6db 100644
    --- a/tests/cpp/test_multidevice_ipc.cpp
    +++ b/tests/cpp/test_multidevice_ipc.cpp
    @@ -31,11 +31,17 @@ const T& fromBytes(const std::vector<uint8_t>& bytes) {
     using IpcTest = MultiDeviceTest;
    
     TEST_F(IpcTest, IpcMemHandle) {
    +  if (communicator_->size() == 1) {
    +    GTEST_SKIP() << "Skipping test for single device";
    +  }
     #ifdef NVFUSER_DISTRIBUTED
       // Allocate and setup GPU buffers
       constexpr size_t kBufferSize = sizeof(int64_t);
       const int64_t num_devices = communicator_->size();
       const int64_t rank = communicator_->deviceId();
    +
    +  NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(rank));
    +
       void* d_ptr;
       NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_ptr, kBufferSize));
       const int64_t value = rank;
    @@ -75,6 +81,9 @@ TEST_F(IpcTest, IpcMemHandle) {
     }
    
     TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtReceiver) {
    +  if (communicator_->size() == 1) {
    +    GTEST_SKIP() << "Skipping test for single device";
    +  }
     #ifdef NVFUSER_DISTRIBUTED
       // TL;DR: We can do pointer arithmetic on the importer side. IOW, the pointer
       // can be used as a regular pointer on the importer side.
    @@ -84,6 +93,9 @@ TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtReceiver) {
       const int64_t num_devices = communicator_->size();
       const int64_t rank = communicator_->deviceId();
       const int64_t peer_rank = (rank + 1) % num_devices;
    +
    +  NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(rank));
    +
       void* d_ptr;
       NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_ptr, kBufferSize));
    
    @@ -125,6 +137,9 @@ TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtReceiver) {
     }
    
     TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtSender) {
    +  if (communicator_->size() == 1) {
    +    GTEST_SKIP() << "Skipping test for single device";
    +  }
     #ifdef NVFUSER_DISTRIBUTED
       // TL;DR: We CANNOT do pointer arithmetic on the exporter side! The IPC handle
       // points to the beginning of the allocated buffer.
    @@ -134,6 +149,9 @@ TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtSender) {
       const int64_t num_devices = communicator_->size();
       const int64_t rank = communicator_->deviceId();
       const int64_t peer_rank = (rank + 1) % num_devices;
    +
    +  NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(rank));
    +
       int64_t* d_ptr;
       NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_ptr, kBufferSize));
    

    @wujingyue wujingyue changed the base branch from wjy/rollforward to main April 15, 2025 20:55
    @samnordmann samnordmann merged commit 91b1801 into main Apr 16, 2025
    52 of 53 checks passed
    @samnordmann samnordmann deleted the cuda_ipc_tuto branch April 16, 2025 14:55
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    3 participants