Skip to content

Commit 23503e4

Browse files
chinmaydk99AMD
authored and committed
Avoid watchdog polling during graph capture in ROCm
1 parent e274aff commit 23503e4

3 files changed

Lines changed: 117 additions & 6 deletions

File tree

aten/src/ATen/cuda/CUDAGraph.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ static bool _cuda_graphs_debug = false;
2222
static std::mutex _currently_capturing_graphs_mutex;
2323
static ska::flat_hash_map<CaptureId_t, CUDAGraph*> _currently_capturing_graphs;
2424

25+
26+
#if defined(USE_ROCM)
// Reports whether any CUDAGraph capture is currently in progress in this
// process. Reads the same mutex-protected capture map that the capture
// lifecycle bookkeeping (capture_begin/capture_end) maintains.
bool is_graph_capture_active() {
  std::lock_guard<std::mutex> lock(_currently_capturing_graphs_mutex);
  return !_currently_capturing_graphs.empty();
}
#endif // defined(USE_ROCM)
35+
2536
MempoolId_t graph_pool_handle() {
2637
// Sets just the second value, to distinguish it from MempoolId_ts created from
2738
// cudaStreamGetCaptureInfo id_s in capture_begin.
@@ -140,15 +151,18 @@ void CUDAGraph::capture_end() {
140151
TORCH_CHECK(stream.stream() == capture_stream_.stream(),
141152
"Capture must end on the same stream it began on.");
142153

143-
AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
144-
154+
// Capture is over once cudaStreamEndCapture returns (success or failure).
155+
// Clear bookkeeping before propagating the return status so watchdog-side
156+
// checks cannot observe stale "capture active" state on error paths.
157+
cudaError_t endCaptureErr = cudaStreamEndCapture(capture_stream_, &graph_);
145158
{
146159
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
147160
TORCH_CHECK(
148161
_currently_capturing_graphs.count(capture_id_),
149162
"capture_end() called before capture_begin().");
150163
_currently_capturing_graphs.erase(capture_id_);
151164
}
165+
AT_CUDA_CHECK(endCaptureErr);
152166

153167
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
154168
at::getHostAllocator(at::kCUDA)->end_allocate_to_pool(mempool_id_);

aten/src/ATen/cuda/CUDAGraph.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ namespace cuda {
3030
// to CUDAGraph::capture_begin
3131
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
3232

33+
// Returns true if any CUDAGraph capture is currently active in this process.
34+
// Used by ProcessGroupNCCL's ROCm watchdog workaround to avoid calling
35+
// hipEventQuery during active capture on HIP runtimes without the
36+
// event-query capture-mode fix (https://github.com/ROCm/clr/pull/3176).
37+
// Not needed on CUDA/NVIDIA where cross-thread event query does not have this
38+
// restriction.
39+
#if defined(USE_ROCM)
40+
TORCH_CUDA_CPP_API bool is_graph_capture_active();
41+
#endif // defined(USE_ROCM)
42+
3343
struct TORCH_CUDA_CPP_API CUDAGraph {
3444
CUDAGraph(bool keep_graph=false);
3545
~CUDAGraph();

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
#include <utility>
1212

1313
#include <ATen/cuda/CUDAContext.h>
14+
#include <ATen/cuda/CUDAGraph.h>
1415
#include <c10/core/DeviceType.h>
1516
#include <c10/cuda/CUDAAllocatorConfig.h>
17+
#include <c10/cuda/CUDAException.h>
1618
#include <c10/cuda/CUDAGraphsC10Utils.h>
1719
#include <c10/cuda/CUDAGuard.h>
1820
#include <c10/util/Exception.h>
@@ -221,6 +223,65 @@ std::string getExceptionMsgFromExceptionPtr(
221223
}
222224
}
223225

226+
#ifdef USE_ROCM
// Marks that the current thread is inside the watchdog's event-query phase.
// This lets the ROCm workaround apply only to watchdog-side queries while
// leaving user/main-thread `WorkNCCL::isCompleted()` and `wait()` calls
// untouched.
thread_local bool g_in_rocm_watchdog_event_query_context = false;

// RAII guard: sets the watchdog event-query flag for the current thread on
// construction and restores the prior value on destruction (nest-safe).
struct RocmWatchdogEventQueryContextGuard {
  RocmWatchdogEventQueryContextGuard()
      : saved_state_(
            std::exchange(g_in_rocm_watchdog_event_query_context, true)) {}
  ~RocmWatchdogEventQueryContextGuard() {
    g_in_rocm_watchdog_event_query_context = saved_state_;
  }

 private:
  bool saved_state_;
};
#endif // USE_ROCM
246+
247+
#ifdef USE_ROCM
248+
// Watchdog-side cudaEventQuery workaround for HIP runtimes without the
249+
// capture-mode fix.
250+
// TODO: Remove once all supported runtimes include
251+
// https://github.com/ROCm/rocm-systems/pull/3176
252+
bool queryEventWithRocmWatchdogCaptureWorkaround(
253+
const std::shared_ptr<at::cuda::CUDAEvent>& event) {
254+
if (!event->isCreated()) {
255+
return true;
256+
}
257+
258+
// Must unconditionally return false here during watchdog + active capture:
259+
// on affected HIP runtimes, even calling cudaEventQuery from the watchdog
260+
// thread while another thread has GLOBAL capture active can invalidate that
261+
// capture and cause downstream failures. Skip the query entirely and report
262+
// "not complete yet"; the watchdog will re-poll once capture ends. Timeout
263+
// enforcement is also deferred during this window (see the
264+
// is_graph_capture_active() gate in the watchdog loop).
265+
if (g_in_rocm_watchdog_event_query_context &&
266+
at::cuda::is_graph_capture_active()) {
267+
return false;
268+
}
269+
270+
const cudaError_t err =
271+
C10_CUDA_ERROR_HANDLED(cudaEventQuery(event->event()));
272+
if (err == cudaSuccess) {
273+
return true;
274+
} else if (err != cudaErrorNotReady) {
275+
C10_CUDA_CHECK(err);
276+
} else {
277+
// ignore and clear the error if not ready
278+
(void)cudaGetLastError();
279+
}
280+
281+
return false;
282+
}
283+
#endif // USE_ROCM
284+
224285
inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
225286
// parentheses avoid some compiler warnings
226287
static const uint64_t min_version =
@@ -644,7 +705,11 @@ bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const {
644705
return false;
645706
}
646707
// Checking the work's corresponding CUDA event's status
708+
#ifdef USE_ROCM
709+
if (!queryEventWithRocmWatchdogCaptureWorkaround(ncclStartEvent_)) {
710+
#else
647711
if (!ncclStartEvent_->query()) {
712+
#endif
648713
return false;
649714
}
650715
return true;
@@ -657,7 +722,11 @@ bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const {
657722
// hang if another thread is holding the CUDA global context lock. For
658723
// example, when doing a `cudaDeviceSynchronize` or even
659724
// `cudaStreamSynchronize`.
725+
#ifdef USE_ROCM
726+
if (!queryEventWithRocmWatchdogCaptureWorkaround(ncclEndEvent_)) {
727+
#else
660728
if (!ncclEndEvent_->query()) {
729+
#endif
661730
return false;
662731
}
663732
return true;
@@ -2291,9 +2360,23 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
22912360
}
22922361
}
22932362

2294-
// Then check if work has timed out
2295-
// Skip if work has encountered an error
2296-
bool timedout = !work.exception() && work.checkTimeout();
2363+
// Then check if work has timed out.
2364+
// Skip if work has encountered an error.
2365+
2366+
bool timedout = false;
2367+
#ifdef USE_ROCM
2368+
// On ROCm, watchdog event queries may be intentionally skipped during
2369+
// active graph capture to avoid HIP runtime capture invalidation.
2370+
// In that window, timeout checks can report false positives for
2371+
// otherwise-complete work, so we defer timeout enforcement.
2372+
// TODO: Remove once all supported HIP runtimes include:
2373+
// https://github.com/ROCm/clr/pull/3176
2374+
if (!at::cuda::is_graph_capture_active()) {
2375+
timedout = !work.exception() && work.checkTimeout();
2376+
}
2377+
#else
2378+
timedout = !work.exception() && work.checkTimeout();
2379+
#endif
22972380

22982381
// Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is
22992382
// turned on; otherwise, run() is no-op)
@@ -2358,7 +2441,11 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
23582441
// allow watchdog to do an event query on a side thread
23592442
at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index());
23602443
at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal};
2361-
2444+
#ifdef USE_ROCM
2445+
// Mark this thread/scope as watchdog event-query context so the ROCm
2446+
// workaround applies only here (not to main-thread wait()/isCompleted()).
2447+
RocmWatchdogEventQueryContextGuard watchdog_event_query_context_guard;
2448+
#endif
23622449
// a work could be started but not completed, so we should not update
23632450
// lastStartedSeq and lastStartedOpName if the work state is checked
23642451
// multiple times after the start

0 commit comments

Comments
 (0)