1111#include < utility>
1212
1313#include < ATen/cuda/CUDAContext.h>
14+ #include < ATen/cuda/CUDAGraph.h>
1415#include < c10/core/DeviceType.h>
1516#include < c10/cuda/CUDAAllocatorConfig.h>
17+ #include < c10/cuda/CUDAException.h>
1618#include < c10/cuda/CUDAGraphsC10Utils.h>
1719#include < c10/cuda/CUDAGuard.h>
1820#include < c10/util/Exception.h>
@@ -221,6 +223,65 @@ std::string getExceptionMsgFromExceptionPtr(
221223 }
222224}
223225
226+ #ifdef USE_ROCM
// Per-thread flag that is true only while the current thread is inside the
// watchdog's event-query phase. It lets the ROCm workaround apply solely to
// watchdog-side queries, leaving user/main-thread `WorkNCCL::isCompleted()`
// and `wait()` calls with their existing behavior.
thread_local bool g_in_rocm_watchdog_event_query_context = false;

// Scoped marker for the watchdog event-query phase.
//
// Construction saves the calling thread's current flag value and sets the flag
// to true; destruction writes the saved value back, so nested guards on the
// same thread unwind correctly.
struct RocmWatchdogEventQueryContextGuard {
  RocmWatchdogEventQueryContextGuard()
      : previous_(g_in_rocm_watchdog_event_query_context) {
    g_in_rocm_watchdog_event_query_context = true;
  }
  ~RocmWatchdogEventQueryContextGuard() {
    g_in_rocm_watchdog_event_query_context = previous_;
  }

  // Not copyable (and therefore not movable): a duplicated guard would run the
  // restoring destructor a second time with a stale `previous_`, possibly
  // resetting the flag while another guard on this thread is still alive.
  RocmWatchdogEventQueryContextGuard(
      const RocmWatchdogEventQueryContextGuard&) = delete;
  RocmWatchdogEventQueryContextGuard& operator=(
      const RocmWatchdogEventQueryContextGuard&) = delete;

 private:
  // Flag value observed at construction; written back on destruction.
  bool previous_;
};
245+ #endif // USE_ROCM
246+
247+ #ifdef USE_ROCM
248+ // Watchdog-side cudaEventQuery workaround for HIP runtimes without the
249+ // capture-mode fix.
250+ // TODO: Remove once all supported runtimes include
251+ // https://github.com/ROCm/rocm-systems/pull/3176
252+ bool queryEventWithRocmWatchdogCaptureWorkaround (
253+ const std::shared_ptr<at::cuda::CUDAEvent>& event) {
254+ if (!event->isCreated ()) {
255+ return true ;
256+ }
257+
258+ // Must unconditionally return false here during watchdog + active capture:
259+ // on affected HIP runtimes, even calling cudaEventQuery from the watchdog
260+ // thread while another thread has GLOBAL capture active can invalidate that
261+ // capture and cause downstream failures. Skip the query entirely and report
262+ // "not complete yet"; the watchdog will re-poll once capture ends. Timeout
263+ // enforcement is also deferred during this window (see the
264+ // is_graph_capture_active() gate in the watchdog loop).
265+ if (g_in_rocm_watchdog_event_query_context &&
266+ at::cuda::is_graph_capture_active ()) {
267+ return false ;
268+ }
269+
270+ const cudaError_t err =
271+ C10_CUDA_ERROR_HANDLED (cudaEventQuery (event->event ()));
272+ if (err == cudaSuccess) {
273+ return true ;
274+ } else if (err != cudaErrorNotReady) {
275+ C10_CUDA_CHECK (err);
276+ } else {
277+ // ignore and clear the error if not ready
278+ (void )cudaGetLastError ();
279+ }
280+
281+ return false ;
282+ }
283+ #endif // USE_ROCM
284+
224285inline void errorIfCapturingNonCapturableNCCL (c10::cuda::CaptureStatus status) {
225286 // parentheses avoid some compiler warnings
226287 static const uint64_t min_version =
@@ -644,7 +705,11 @@ bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const {
644705 return false ;
645706 }
646707 // Checking the work's corresponding CUDA event's status
708+ #ifdef USE_ROCM
709+ if (!queryEventWithRocmWatchdogCaptureWorkaround (ncclStartEvent_)) {
710+ #else
647711 if (!ncclStartEvent_->query ()) {
712+ #endif
648713 return false ;
649714 }
650715 return true ;
@@ -657,7 +722,11 @@ bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const {
657722 // hang if another thread is holding the CUDA global context lock. For
658723 // example, when doing a `cudaDeviceSynchronize` or even
659724 // `cudaStreamSynchronize`.
725+ #ifdef USE_ROCM
726+ if (!queryEventWithRocmWatchdogCaptureWorkaround (ncclEndEvent_)) {
727+ #else
660728 if (!ncclEndEvent_->query ()) {
729+ #endif
661730 return false ;
662731 }
663732 return true ;
@@ -2291,9 +2360,23 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
22912360 }
22922361 }
22932362
2294- // Then check if work has timed out
2295- // Skip if work has encountered an error
2296- bool timedout = !work.exception () && work.checkTimeout ();
2363+ // Then check if work has timed out.
2364+ // Skip if work has encountered an error.
2365+
2366+ bool timedout = false ;
2367+ #ifdef USE_ROCM
2368+ // On ROCm, watchdog event queries may be intentionally skipped during
2369+ // active graph capture to avoid HIP runtime capture invalidation.
2370+ // In that window, timeout checks can report false positives for
2371+ // otherwise-complete work, so we defer timeout enforcement.
2372+ // TODO: Remove once all supported HIP runtimes include:
2373+ // https://github.com/ROCm/clr/pull/3176
2374+ if (!at::cuda::is_graph_capture_active ()) {
2375+ timedout = !work.exception () && work.checkTimeout ();
2376+ }
2377+ #else
2378+ timedout = !work.exception () && work.checkTimeout ();
2379+ #endif
22972380
22982381 // Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is
22992382 // turned on; otherwise, run() is no-op)
@@ -2358,7 +2441,11 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
23582441 // allow watchdog to do an event query on a side thread
23592442 at::cuda::CUDAGuard device_guard (work.ncclEndEvent_ ->device_index ());
23602443 at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal};
2361-
2444+ #ifdef USE_ROCM
2445+ // Mark this thread/scope as watchdog event-query context so the ROCm
2446+ // workaround applies only here (not to main-thread wait()/isCompleted()).
2447+ RocmWatchdogEventQueryContextGuard watchdog_event_query_context_guard;
2448+ #endif
23622449 // a work could be started but not completed, so we should not update
23632450 // lastStartedSeq and lastStartedOpName if the work state is checked
23642451 // multiple times after the start
0 commit comments