
[CudaIpc 2/3]: Ipc handle exchange #3910

Merged
samnordmann merged 27 commits into main from ipc_handle_infra
Apr 14, 2025
Conversation


@samnordmann samnordmann commented Feb 17, 2025

On top of

Prerequisite to:

What

  • Set up the infrastructure needed for ipc handle exchange and caching
  • Add an Expr node hir::ShareMemHandles to represent this op. We cannot embed the op in the Send/Recv semantics because we need to group the handle exchange between matching sends and recvs to avoid deadlocks.

How

Most of the implementation is in multidevice/ipc_handle.cpp

  • Define the class IpcHandle representing the IPC handle that is exchanged. This class is supplemented with a semaphore, which is a local CUDA buffer allocated on the exporter's device.
  • Define IpcHandleCache, which handles exchanging and caching the IPC handles. Caching is keyed on a combination of runtime and symbolic ingredients: (runtime peer, at::Tensor, Expr*). This caching allows an arbitrary number of p2p comms between pairs of ranks.
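The caching scheme described above can be sketched with stand-in types (a minimal sketch; the names FakeExpr, CacheKey, and HandleCacheSketch are hypothetical, not nvFuser types, and std::string stands in for the cached P2pIpcHandle):

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <tuple>

// Stand-in for the symbolic P2PCommunication* component of the key.
struct FakeExpr {};

// Key = (peer rank evaluated at runtime, buffer base pointer, symbolic expr).
using CacheKey = std::tuple<int64_t, const void*, const FakeExpr*>;

class HandleCacheSketch {
 public:
  // Returns nullptr on a miss, mirroring IpcHandleCache::find.
  const std::string* find(const CacheKey& key) const {
    auto it = handles_.find(key);
    return it == handles_.end() ? nullptr : &it->second;
  }
  void insert(const CacheKey& key, std::string handle) {
    handles_.emplace(key, std::move(handle));
  }

 private:
  std::map<CacheKey, std::string> handles_;
};
```

Because the key mixes runtime values (peer, buffer) with the symbolic op, the same P2PCommunication* node re-evaluated with a different peer or buffer misses the cache and triggers a fresh exchange.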


github-actions bot commented Feb 17, 2025

Review updated until commit 3957f14

Description

  • Added infrastructure for IPC handle exchange and caching.

  • Introduced ShareMemHandles Expr node for representing handle exchange.

  • Implemented IpcHandle and IpcHandleCache for managing CUDA IPC handles.

  • Updated HostIrEvaluator to handle ShareMemHandles and P2PCommunication with CUDA backend.


Changes walkthrough 📝

Relevant files

Enhancement (14 files):
  • executor.cpp (+46/-16): Added handling for ShareMemHandles and CUDA P2PCommunication.
  • host_ir.cpp (+28/-0): Introduced ShareMemHandles Expr node.
  • communicator.cpp (+2/-13): Updated CommunicatorBackend to include CUDA.
  • cuda_p2p.cpp (+70/-0): Added CUDA P2P communication functions.
  • ipc_handle.cpp (+160/-0): Implemented IpcHandle and IpcHandleCache for IPC management.
  • fusion_kernel_runtime.cpp (+1/-1): Updated HostIrEvaluator instantiation.
  • dispatch.h (+2/-1): Added ShareMemHandles to dispatch.
  • driver_api.h (+3/-1): Added CUDA driver API functions.
  • executor.h (+3/-0): Added ShareMemHandles handling in HostIrEvaluator.
  • host_ir.h (+25/-0): Introduced ShareMemHandles class.
  • communicator.h (+4/-3): Added TCPStore access and updated CommunicatorBackend.
  • cuda_p2p.h (+22/-0): Added CUDA P2P communication declarations.
  • ipc_handle.h (+161/-0): Introduced IpcHandle and IpcHandleCache classes.
  • multidevice.h (+3/-0): Updated CommunicatorBackend to include CUDA.

Tests (2 files):
  • test_multidevice_communications.cpp (+71/-0): Added test for CUDA P2P communication.
  • test_multidevice_host_ir.cpp (+71/-0): Added test for ShareMemHandles.

Configuration changes (3 files):
  • .gitmodules (+0/-3): Removed gloo submodule.
  • CMakeLists.txt (+2/-1): Added new source files and removed gloo include.
  • gloo (+0/-1): Removed gloo submodule.

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Error Handling

Ensure that all CUDA API calls have proper error handling and that the error messages are informative.

  size_t psize = 0;
  NVFUSER_CUDA_SAFE_CALL(cuMemGetAddressRange(
      (CUdeviceptr*)&base_address_, &psize, (CUdeviceptr)ptr_));
  offset_from_base_address_ = static_cast<int64_t>(
      static_cast<uint8_t*>(ptr_) - static_cast<uint8_t*>(base_address_));
  NVFUSER_CUDA_RT_SAFE_CALL(
      cudaIpcGetMemHandle(&ipc_handle_, tensor.data_ptr()));
  NVFUSER_CUDA_RT_SAFE_CALL(
      cudaMalloc((void**)&semaphore_, sizeof(IpcSemaphore)));
  static_assert(
      sizeof(IpcSemaphore) == sizeof(int),
      "IpcSemaphore must be same size as int");
  NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(
      (void*)semaphore_, (int)IpcSemaphore::kReady, sizeof(IpcSemaphore)));
  NVFUSER_CUDA_RT_SAFE_CALL(
      cudaIpcGetMemHandle(&semaphore_ipc_handle_, semaphore_));
}

IpcHandle::IpcHandle(std::vector<uint8_t> data) {
  const IpcHandle& imported_buffer = fromBytes<IpcHandle>(data);

  offset_from_base_address_ = imported_buffer.offset_from_base_address_;
  ipc_handle_ = imported_buffer.ipc_handle_;
  semaphore_ipc_handle_ = imported_buffer.semaphore_ipc_handle_;
  rank_ = imported_buffer.rank_;

  NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
      &base_address_, ipc_handle_, cudaIpcMemLazyEnablePeerAccess));
  ptr_ = (void*)((uint8_t*)base_address_ + offset_from_base_address_);

  NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
      (void**)&semaphore_,
      semaphore_ipc_handle_,
      cudaIpcMemLazyEnablePeerAccess));
}

IpcHandle::~IpcHandle() {
  if (rank_ == Communicator::getInstance().deviceId()) {
    NVFUSER_CUDA_RT_SAFE_CALL(cudaFree((void*)semaphore_));
  } else {
    NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(base_address_));
    NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle((void*)semaphore_));
  }
}

// retrieves a key for the TCP store corresponding to a `communication` and the
// exporter `rank`
std::string IpcHandleCache::getTcpStoreKey(
    P2PCommunication* communication,
    int64_t rank) const {
  const int64_t my_rank = Communicator::getInstance().deviceId();
  const int64_t peer =
      expr_evaluator_.evaluate(communication->peer()).as<int64_t>();
  const int64_t src =
      communication->type() == P2PCommunicationType::SEND ? my_rank : peer;
  const int64_t dst =
      communication->type() == P2PCommunicationType::SEND ? peer : my_rank;

  return "nvfuser_ipc_handle_info_P2PComm_dst=" + std::to_string(dst) +
      "_src=" + std::to_string(src) + "_rank=" + std::to_string(rank);
}

void IpcHandleCache::exchangeHandles(
    const std::vector<P2PCommunication*>& communications) {
#ifdef NVFUSER_DISTRIBUTED
  Communicator* communicator = &Communicator::getInstance();
  const int64_t my_rank = communicator->deviceId();

  std::vector<P2PCommunication*> non_cached_communications;
  for (auto communication : communications) {
    NVF_ERROR(
        expr_evaluator_.evaluate(communication->peer()).as<int64_t>() !=
            my_rank,
        "send to self not supported");
    if (find(communication) != nullptr) {
      continue;
    }
    non_cached_communications.push_back(communication);
  }

  // put memhandles to TCP store
  std::unordered_map<P2PCommunication*, std::unique_ptr<IpcHandle>>
      local_ipc_handles;
  auto store = communicator->getTcpStore();
  for (P2PCommunication* communication : non_cached_communications) {
    at::Tensor tensor =
        expr_evaluator_.evaluate(communication->buffer()).as<at::Tensor>();
    NVF_ERROR(
        tensor.is_contiguous(), "IpcHandle only supports contiguous tensors");
    auto buffer_handle = std::make_unique<IpcHandle>(tensor);
    auto key = getTcpStoreKey(communication, my_rank);
    // TODO: use multiSet
    store->set(key, toBytes(*buffer_handle));
    local_ipc_handles.emplace(communication, std::move(buffer_handle));
  }

  // barrier to ensure all ranks have pushed their memhandles to the store
  // TODO: precisely select what ranks need to wait on that barrier.
  communicator->barrier();

  // get memhandles from TCP store
  for (P2PCommunication* communication : non_cached_communications) {
    const int64_t peer =
        expr_evaluator_.evaluate(communication->peer()).as<int64_t>();
    std::string key = getTcpStoreKey(communication, peer);
    NVF_ERROR(
        store->check({key}),
        "key ",
        key,
        " not found in store at rank ",
        my_rank);
    // TODO: use multiGet
    auto peer_ipc_handle = std::make_unique<IpcHandle>(store->get(key));
    store->deleteKey(key);
    auto& local_ipc_handle = local_ipc_handles.at(communication);

    auto ipc_handles = std::make_unique<P2pIpcHandle>(
        std::move(local_ipc_handle), std::move(peer_ipc_handle));

    insert(communication, std::move(ipc_handles));
  }
#else // NVFUSER_DISTRIBUTED
  NVF_ERROR(false, "NVFUSER_DISTRIBUTED is not defined");
#endif // NVFUSER_DISTRIBUTED
}
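As a concrete check of the key format built by getTcpStoreKey above, here is a sketch reproducing the string construction (the function name tcpStoreKeySketch is hypothetical; the logic mirrors the quoted code):

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Reproduces the key construction from getTcpStoreKey for illustration.
std::string tcpStoreKeySketch(
    bool is_send,
    int64_t my_rank,
    int64_t peer,
    int64_t exporter_rank) {
  const int64_t src = is_send ? my_rank : peer;
  const int64_t dst = is_send ? peer : my_rank;
  return "nvfuser_ipc_handle_info_P2PComm_dst=" + std::to_string(dst) +
      "_src=" + std::to_string(src) + "_rank=" + std::to_string(exporter_rank);
}
```

Note that a SEND on rank 0 and the matching RECV on rank 1 compute the same (src, dst) pair, so when the receiver looks up exporter rank 0 it lands on exactly the key the sender published.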
Memory Management

Verify that all allocated memory is properly freed and that there are no memory leaks.

(Same code snippet as shown under "Error Handling" above.)
Performance Considerations

Evaluate the performance impact of using cudaIpcGetMemHandle and cudaIpcOpenMemHandle in a distributed setting and consider optimizing for latency and throughput.

(Same code snippet as shown under "Error Handling" above.)

@samnordmann samnordmann changed the title from "Ipc handle infra" to "[CudaIpc 2/3]: Ipc handle exchange" on Feb 17, 2025
@samnordmann
Collaborator Author

!test


@wujingyue wujingyue left a comment


To help me understand this better, is it possible to add some unit tests to this PR? I tend to review a PR starting from tests unless it's trivial.

Base automatically changed from add_backend_type_to_p2p_comm to main February 21, 2025 13:04
@samnordmann
Collaborator Author

To help me understand this better, is it possible to add some unit tests to this PR? I tend to review a PR starting from tests unless it's trivial.

you can look at the test in #3911
Lmk if you'd still want to add a test to the present PR. But I'm afraid such a test would be quite artificial.

@samnordmann
Collaborator Author

To help me understand this better, is it possible to add some unit tests to this PR? I tend to review a PR starting from tests unless it's trivial.

you can look at the test in #3911. Lmk if you'd still want to add a test to the present PR. But I'm afraid such a test would be quite artificial.

I finally added a test

@samnordmann
Collaborator Author

!test

CMakeLists.txt Outdated
${LIBCUPTI}
${TORCH_LIBRARIES}
dl
cuda
Collaborator Author


needed for cuMemGetAddressRange, otherwise we get an "invalid context" error. But according to #3907 there is clearly an issue with how we load the driver API in general, so I'm linking to it directly in the meantime

if (isOptionEnabled(EnableOption::HostIrLowering)) {
hie_ = std::make_unique<hir::HostIrEvaluator>(
hir::HostIrEvaluator(std::move(hic)));
hie_ = std::make_unique<hir::HostIrEvaluator>(std::move(hic));
Collaborator Author


cc @nsarka
My compiler complains otherwise

// all ranks set `send_tensor`
send_tensor.copy_(generate_tensor(repetition, my_rank));
torch::cuda::synchronize();
communicator_->barrier();
Collaborator


Why is this necessary?


@samnordmann samnordmann Feb 28, 2025


I had to move it to after the cudaMemcpy, because it is a "put" algorithm: we need to wait until the sender side has finished writing before validating.

If instead we wrote a "get" algorithm, the barrier would have been in the right place, and the justification would have been: we need to wait until the send side has finished setting up its buffer before reading.

In any case we need a synchronization across iterations.

Collaborator


I thought the semaphore is supposed to take care of this sender/receiver synchronization and therefore avoids the barrier. Am I missing something?

Collaborator Author


I thought the semaphore is supposed to take care of this sender/receiver synchronization and therefore avoids the barrier. Am I missing something?

You are right, semaphores are used for synchronization in the p2p primitives implemented in the next PR. But here we have a standalone unit test sharing the IPC handles, as per your request, so some synchronization needs to be added.

}

private:
using KeyType = std::tuple<int64_t, at::Tensor, P2PCommunication*>;

@wujingyue wujingyue Feb 28, 2025


I'm quite unsure about the key type. My impression from #3912 has been that IPC handle is registered per tensor not per communication. For example, when rank 0 sends the same buffer to rank 1's buffer1 and buffer2, can't rank 0 use the same IPC handle? (Of course, it would have to maintain some ref counts so the buffer is not prematurely deallocated before both reads are done)


@samnordmann samnordmann Feb 28, 2025


You also need semaphores for the synchronization, one per communication, thus this key type


@samnordmann samnordmann Feb 28, 2025


Besides the semaphores, one could think of reusing the same buffer's IPC handles for several P2Ps using the same buffer. But I don't think that is a good idea. It would only save us some cudaIpcGetMemHandle calls, which is a really minor improvement, and wouldn't save us all the rest: the semaphore, setting/getting the TCP store, the barrier, etc.

And we still need to be consistent with the symmetry assumption. If rank 0 sends the same buffer to rank 1 and rank 2, we need two exchanges (0/1 and 0/2), and therefore a way of not hitting the cache even though the buffer is the same. The same goes for ranks 0 and 1 involved in two p2p communications, e.g., rank 0 sends buffer a to rank 1's buffer b and rank 0 sends buffer a to rank 1's buffer c.

Does it make sense?



I'm pretty sure what you have here is a valid solution. Lots of my questions come from the fact that I'm new to CUDA IPC and I'm trying to figure out the first principles. Therefore, in addition to one solution, I'm trying to understand the design space and why certain solutions are preferred.

Re: semaphores. Yes, I understood that one semaphore per P2P communication is a valid solution. I also imagine a semaphore can be extended to deal with multiple senders or receivers, because there are "counting semaphores" which hold a count. Do people consider this for implementing collectives like allgather and allreduce?

Re: symmetry assumption. I'm sure you mentioned this somewhere and I forgot what it is and what it buys us. Do you have a reference?



I'm pretty sure what you have here is a valid solution. Lots of my questions come from the fact that I'm new to CUDA IPC and I'm trying to figure out the first principles. Therefore, in addition to one solution, I'm trying to understand the design space and why certain solutions are preferred.

No problem, I'm also happy to brainstorm ! :)

Re: semaphores. Yes, I understood that one semaphore per P2P communication is a valid solution. I also imagine a semaphore can be extended to deal with multiple senders or receivers, because there are "counting semaphores" which hold a count. Do people consider this for implementing collectives like allgather and allreduce?

"Counters" are definitely possible, but have two problems:

  • cuStream ops do not have a read-and-write primitive, so we cannot easily implement incrementing a counter
  • There are even fewer "atomic add" operations, so we would need to handle race conditions between ranks anyway.

Besides this, I can't see an immediate benefit from using fewer semaphore allocations, other than a slightly reduced memory footprint.

Re: symmetry assumption. I'm sure you mentioned this somewhere and I forgot what it is and what it buys us. Do you have a reference?

I don't have a reference, I am using my own definition here. Btw, this assumption is implicit in this PR and we need to make it explicit and more robust in the future. What I mean in this PR by the symmetry assumption is: "if we hit the cache locally, we hit the cache globally". IOW, we only check the local cache to decide whether we can use the IpcHandle, without any further interprocess synchronization.

This assumption is not necessarily true if, e.g., across iterations, one rank changes the buffer and the other doesn't.
However, this assumption can be enforced in nvFuser, at least for internal buffers that we allocate ourselves (we will probably need to add a new allocation attribute for that).


@wujingyue wujingyue left a comment


LGTM! While I still have several questions on the big picture that we can keep discussing in the PR, I believe it tries to implement fundamental building blocks that will enable IPC inside nvFuser.


const ExpressionEvaluator& expr_evaluator_;
std::unordered_map<KeyType, std::unique_ptr<P2pIpcHandle>, KeyHash, KeyEqual>
handles_;


A non-blocking question: Apparently, this is not deallocated until the end of the program so all handles and buffers live until the end. In the future, do we plan to have HostIrEvaluator::handle(P2PCommunication*) to stream-wait for the semaphore and a later Deallocate IR to deallocate the buffer? (That would make sense to me.)


@samnordmann samnordmann Apr 10, 2025


I missed this comment earlier, sorry

all handles and buffers live until the end

Regarding Ipc handles and semaphore: yes, their lifetime is owned by the P2pIpcHandle object.

Regarding data buffers, IIUC, in the current version I think the answer is no. Indeed, the buffer's lifetime is managed by the pytorch at::Tensor, and the IPC handles do not hold any reference to the at::Tensor.
However, we might want to fix this to ensure that the buffer is still alive before we close the handle, since the doc says:

Calling cudaFree on an exported memory region before calling cudaIpcCloseMemHandle in the importing context will result in undefined behavior.

Letting the IpcHandle hold the at::Tensor should fix this

a later Deallocate IR to deallocate the buffer

yes, that's also what I have in mind. The Deallocate IR should take care of cleaning up not only the data buffer but also any potential Ipc Handle pointing to it.
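The fix discussed here, letting the handle keep the buffer alive, can be sketched with a shared_ptr standing in for the at::Tensor's refcounted storage (hypothetical stand-in types; the real code would hold the at::Tensor itself):

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Stand-in for a device buffer whose lifetime is refcounted,
// like at::Tensor's storage.
using Buffer = std::vector<unsigned char>;

class IpcHandleSketch {
 public:
  explicit IpcHandleSketch(std::shared_ptr<Buffer> buffer)
      : buffer_(std::move(buffer)) {}
  // While the handle lives, the exported buffer cannot be freed, so the
  // "cudaFree before cudaIpcCloseMemHandle" undefined behavior is avoided.

 private:
  std::shared_ptr<Buffer> buffer_;
};
```

The destructor ordering then guarantees close-before-free: the storage is released only after the last handle referencing it is destroyed.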



Letting the IpcHandle hold the at::Tensor should fix this

I added that

P2PCommunicationType::SEND,
send_tv,
IrBuilder::create<Val>(send_peer),
CommunicatorBackend::kNccl);


Is this supposed to be the CUDA backend?



not yet -- it will be in the next PR.
In this test the backend is meaningless, since we don't use HostIrEvaluator; we just need to create the Expr* to feed IpcHandleCache::exchangeHandles

// to be exported by batch (thus the function taking a vector of
// P2PCommunication*) to improve performance and to avoid creating deadlocks
// when imports and exports order differ across ranks.
void exchangeHandles(const std::vector<P2PCommunication*>& communications);


While I understood how it's used in the unit test, it's unclear when/how we will call this in practice. But I'm happy to defer this to following PRs.

When a P2PCommunication is in a for loop, it's no longer possible to exchangeHandles at the beginning of the program because peer and tensor depend on the loop index. Therefore, do you expect this to be called at the beginning of each control flow scope and with all P2PCommunications that are guaranteed to happen in that scope?



Hopefully it will become clearer in #3911

When a P2PCommunication is in a for loop, it's no longer possible to exchangeHandles at the beginning of the program because peer and tensor depend on the loop index.

It is not a problem; actually, we precisely intend to use exchangeHandles inside a for loop. That is the reason why the caching is not only based on the symbolic P2PCommunication*, but also on runtime evaluation (at::Tensor buffer, int64_t peer), which is evaluated at each for-loop iteration.

Therefore, do you expect this to be called at the beginning of each control flow scope and with all P2PCommunications that are guaranteed to happen in that scope?

No, each iteration can trigger a new exchange if the cache is missed.
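The batched put / barrier / get structure of exchangeHandles, which avoids the deadlock mentioned in the header comment when ranks order their communications differently, can be sketched with an in-process map standing in for the TCP store (a simplification: the real code runs these loops once per rank, in separate processes, with a real barrier between the phases):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Stand-in for the TCP store shared by all ranks.
using Store = std::map<std::string, std::string>;

// Phase 1: a rank puts ALL of its handles before reading any,
// so the later gets cannot deadlock even if two ranks order their
// matching communications differently.
void putAll(
    Store& store,
    const std::vector<std::pair<std::string, std::string>>& handles) {
  for (const auto& [key, handle] : handles) {
    store[key] = handle;
  }
}

// (a barrier between the two phases in the real code)

// Phase 2: the rank reads its peers' handles and deletes the keys.
std::vector<std::string> getAll(
    Store& store,
    const std::vector<std::string>& keys) {
  std::vector<std::string> out;
  for (const auto& key : keys) {
    out.push_back(store.at(key));
    store.erase(key);
  }
  return out;
}
```

Because every put happens before any get, the read order is free to differ from the write order without anyone blocking on a missing key.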


@samnordmann
Collaborator Author

merging is blocked by the lazy-loading issue, but this time with cuMemGetAddressRange

On top of:
- #3910
- #3909
- #3908

Pending on issue:
- #3907
@samnordmann
Collaborator Author

!test

1 similar comment

@samnordmann
Collaborator Author

!build

(several more "!test" comments from samnordmann omitted)
@samnordmann samnordmann merged commit c5636f8 into main Apr 14, 2025
50 of 51 checks passed
@samnordmann samnordmann deleted the ipc_handle_infra branch April 14, 2025 12:26