Skip to content

[HIP] Fix hipEventQuery/hipEventSynchronize stream capture mode restrictions#3176

Merged
JeniferC99 merged 1 commit intorelease/rocm-rel-7.2from
users/agodavar/hipEventQuery-capture-mode-fix
Feb 13, 2026
Merged

[HIP] Fix hipEventQuery/hipEventSynchronize stream capture mode restrictions#3176
JeniferC99 merged 1 commit intorelease/rocm-rel-7.2from
users/agodavar/hipEventQuery-capture-mode-fix

Conversation

@anugodavar
Copy link
Copy Markdown
Contributor

@anugodavar anugodavar commented Feb 10, 2026

Motivation

This fix addresses the issue where hipEventQuery and hipEventSynchronize incorrectly returned hipErrorStreamCaptureUnsupported when called from a thread that had switched to RELAXED or THREAD_LOCAL capture mode, even when another thread had GLOBAL capture active.

Technical Details

Added checkEventCaptureRestrictions() helper function to properly handle stream capture mode restrictions for event operations
RELAXED mode threads can now query/sync events during cross-thread captures
THREAD_LOCAL mode threads skip cross-thread GLOBAL capture checks
Events recorded during capture still correctly return hipErrorCapturedEvent
Added comprehensive unit tests for all capture mode combinations

This fix is critical for PyTorch/RCCL integration where watchdog threads need to call hipEventQuery while another thread has GLOBAL capture active.

JIRA ID

SWDEV-579185

Test Plan

NA

Test Result

NA

Submission Checklist

@anugodavar anugodavar changed the base branch from develop to release/rocm-rel-7.2 February 10, 2026 15:07
@github-actions github-actions bot added documentation Improvements or additions to documentation project: amdsmi labels Feb 10, 2026
@anugodavar anugodavar force-pushed the users/agodavar/hipEventQuery-capture-mode-fix branch from 8e0da11 to dcc6234 Compare February 11, 2026 14:14
@anugodavar anugodavar removed request for a team February 11, 2026 14:15
…ictions

This fix addresses the issue where hipEventQuery and hipEventSynchronize
incorrectly returned hipErrorStreamCaptureUnsupported when called from
a thread that had switched to RELAXED or THREAD_LOCAL capture mode,
even when another thread had GLOBAL capture active.

Key changes:
- Added checkEventCaptureRestrictions() helper function to properly handle
  stream capture mode restrictions for event operations
- RELAXED mode threads can now query/sync events during cross-thread captures
- THREAD_LOCAL mode threads skip cross-thread GLOBAL capture checks
- Events recorded during capture still correctly return hipErrorCapturedEvent
- Added comprehensive unit tests for all capture mode combinations

This fix is critical for PyTorch/RCCL integration where watchdog threads
need to call hipEventQuery while another thread has GLOBAL capture active.
@anugodavar anugodavar force-pushed the users/agodavar/hipEventQuery-capture-mode-fix branch from dcc6234 to d0e85c1 Compare February 11, 2026 14:25
@mangupta mangupta removed project: amdsmi documentation Improvements or additions to documentation labels Feb 12, 2026
@JeniferC99
Copy link
Copy Markdown
Collaborator

manual psdb triggered: http://rocm-ci.amd.com/job/compute-psdb-rel-7.2/350

@JeniferC99 JeniferC99 merged commit 99c3292 into release/rocm-rel-7.2 Feb 13, 2026
8 of 65 checks passed
@JeniferC99 JeniferC99 deleted the users/agodavar/hipEventQuery-capture-mode-fix branch February 13, 2026 16:41
pianpwk pushed a commit to pytorch/pytorch that referenced this pull request Mar 17, 2026
This PR introduces a workaround for the HIP runtime bug (#177309) where `hipEventQuery` from a non-capturing thread invalidates graph captures on other threads, even in `THREAD_LOCAL` mode(ROCm/rocm-systems#3176). The NCCL/RCCL watchdog's polling queries hit this.

### Code Changes

#### `ProcessGroupNCCL.cpp`
- `queryEventWithRocmWatchdogCaptureWorkaround()` wraps `CUDAEvent::query()` logic:
  - Watchdog calling during active capture: skips the query, returns false (not ready)
  - Otherwise queries normally, but catches `hipErrorCapturedEvent` / `hipErrorStreamCaptureUnsupported` from the watchdog and maps them to "not ready" for race conditions
- `RocmWatchdogEventQueryContextGuard` thread-local guard set in `runLoop()` so the skip path only activates on the watchdog — main-thread `wait()`/`isCompleted()` unchanged
- Timeout checks gated on `!is_graph_capture_active()` to avoid false positives while queries are skipped

#### `CUDAGraph.cpp/h`
- `is_graph_capture_active()` reads the existing `_currently_capturing_graphs` map under its mutex
- `capture_end()` erases the map entry before `AT_CUDA_CHECK` so the watchdog never sees stale state on error paths

All `#ifdef USE_ROCM`. TODO to remove once the HIP runtime fix ships.

Pull Request resolved: #176251
Approved by: https://github.com/jeffdaily, https://github.com/ngimel
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 27, 2026
)

This PR introduces a workaround for the HIP runtime bug (pytorch#177309) where `hipEventQuery` from a non-capturing thread invalidates graph captures on other threads, even in `THREAD_LOCAL` mode(ROCm/rocm-systems#3176). The NCCL/RCCL watchdog's polling queries hit this.

- `queryEventWithRocmWatchdogCaptureWorkaround()` wraps `CUDAEvent::query()` logic:
  - Watchdog calling during active capture: skips the query, returns false (not ready)
  - Otherwise queries normally, but catches `hipErrorCapturedEvent` / `hipErrorStreamCaptureUnsupported` from the watchdog and maps them to "not ready" for race conditions
- `RocmWatchdogEventQueryContextGuard` thread-local guard set in `runLoop()` so the skip path only activates on the watchdog — main-thread `wait()`/`isCompleted()` unchanged
- Timeout checks gated on `!is_graph_capture_active()` to avoid false positives while queries are skipped

- `is_graph_capture_active()` reads the existing `_currently_capturing_graphs` map under its mutex
- `capture_end()` erases the map entry before `AT_CUDA_CHECK` so the watchdog never sees stale state on error paths

All `#ifdef USE_ROCM`. TODO to remove once the HIP runtime fix ships.

Pull Request resolved: pytorch#176251
Approved by: https://github.com/jeffdaily, https://github.com/ngimel

(cherry picked from commit 5ae3a6f)
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
)

This PR introduces a workaround for the HIP runtime bug (pytorch#177309) where `hipEventQuery` from a non-capturing thread invalidates graph captures on other threads, even in `THREAD_LOCAL` mode(ROCm/rocm-systems#3176). The NCCL/RCCL watchdog's polling queries hit this.

- `queryEventWithRocmWatchdogCaptureWorkaround()` wraps `CUDAEvent::query()` logic:
  - Watchdog calling during active capture: skips the query, returns false (not ready)
  - Otherwise queries normally, but catches `hipErrorCapturedEvent` / `hipErrorStreamCaptureUnsupported` from the watchdog and maps them to "not ready" for race conditions
- `RocmWatchdogEventQueryContextGuard` thread-local guard set in `runLoop()` so the skip path only activates on the watchdog — main-thread `wait()`/`isCompleted()` unchanged
- Timeout checks gated on `!is_graph_capture_active()` to avoid false positives while queries are skipped

- `is_graph_capture_active()` reads the existing `_currently_capturing_graphs` map under its mutex
- `capture_end()` erases the map entry before `AT_CUDA_CHECK` so the watchdog never sees stale state on error paths

All `#ifdef USE_ROCM`. TODO to remove once the HIP runtime fix ships.

Pull Request resolved: pytorch#176251
Approved by: https://github.com/jeffdaily, https://github.com/ngimel
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
)

This PR introduces a workaround for the HIP runtime bug (pytorch#177309) where `hipEventQuery` from a non-capturing thread invalidates graph captures on other threads, even in `THREAD_LOCAL` mode(ROCm/rocm-systems#3176). The NCCL/RCCL watchdog's polling queries hit this.

### Code Changes

#### `ProcessGroupNCCL.cpp`
- `queryEventWithRocmWatchdogCaptureWorkaround()` wraps `CUDAEvent::query()` logic:
  - Watchdog calling during active capture: skips the query, returns false (not ready)
  - Otherwise queries normally, but catches `hipErrorCapturedEvent` / `hipErrorStreamCaptureUnsupported` from the watchdog and maps them to "not ready" for race conditions
- `RocmWatchdogEventQueryContextGuard` thread-local guard set in `runLoop()` so the skip path only activates on the watchdog — main-thread `wait()`/`isCompleted()` unchanged
- Timeout checks gated on `!is_graph_capture_active()` to avoid false positives while queries are skipped

#### `CUDAGraph.cpp/h`
- `is_graph_capture_active()` reads the existing `_currently_capturing_graphs` map under its mutex
- `capture_end()` erases the map entry before `AT_CUDA_CHECK` so the watchdog never sees stale state on error paths

All `#ifdef USE_ROCM`. TODO to remove once the HIP runtime fix ships.

Pull Request resolved: pytorch#176251
Approved by: https://github.com/jeffdaily, https://github.com/ngimel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants