
[ROCm] Avoid watchdog event queries during graph capture#176251

Closed
chinmaydk99 wants to merge 1 commit into pytorch:main from chinmaydk99:ck-fix-15310

Conversation

@chinmaydk99
Contributor

@chinmaydk99 chinmaydk99 commented Mar 3, 2026

This PR introduces a workaround for the HIP runtime bug (#177309) where hipEventQuery from a non-capturing thread invalidates graph captures on other threads, even in THREAD_LOCAL mode (ROCm/rocm-systems#3176). The NCCL/RCCL watchdog's polling queries hit this.

Code Changes

ProcessGroupNCCL.cpp

  • queryEventWithRocmWatchdogCaptureWorkaround() wraps the CUDAEvent::query() logic:
    • If the watchdog calls while a capture is active, it skips the query and returns false (not ready)
    • Otherwise it queries normally, but catches hipErrorCapturedEvent / hipErrorStreamCaptureUnsupported from the watchdog and maps them to "not ready" to cover race conditions
  • RocmWatchdogEventQueryContextGuard is a thread-local guard set in runLoop(), so the skip path only activates on the watchdog thread; main-thread wait()/isCompleted() are unchanged
  • Timeout checks are gated on !is_graph_capture_active() to avoid false positives while queries are skipped
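The query path above can be sketched in plain, HIP-free C++. Everything here is a mocked stand-in for illustration (the enum values, the `raw_query` callback, and the flag names are assumptions; the real code wraps `hipEventQuery` via `CUDAEvent::query()`):

```cpp
#include <atomic>
#include <cassert>
#include <functional>
#include <stdexcept>

// Mocked stand-ins for the HIP error codes involved (illustrative only).
enum class HipError { Success, ErrorNotReady, ErrorCapturedEvent, ErrorStreamCaptureUnsupported };

// Thread-local flag set by the watchdog loop, mirroring the role of
// RocmWatchdogEventQueryContextGuard set in runLoop().
thread_local bool in_watchdog_thread = false;

struct WatchdogQueryGuard {
  WatchdogQueryGuard() { in_watchdog_thread = true; }
  ~WatchdogQueryGuard() { in_watchdog_thread = false; }
};

// Stand-in for the ROCm-only is_graph_capture_active() check.
std::atomic<bool> capture_active{false};

// Mirrors queryEventWithRocmWatchdogCaptureWorkaround(); `raw_query`
// stands in for the underlying hipEventQuery call.
bool query_event(const std::function<HipError()>& raw_query) {
  if (in_watchdog_thread && capture_active.load()) {
    // Never touch the event while a capture may be in flight: report "not ready".
    return false;
  }
  HipError err = raw_query();
  if (err == HipError::Success) return true;
  if (in_watchdog_thread &&
      (err == HipError::ErrorCapturedEvent ||
       err == HipError::ErrorStreamCaptureUnsupported)) {
    // Race: a capture started between the check above and the query itself,
    // so map the capture-related error to "not ready" instead of failing.
    return false;
  }
  if (err == HipError::ErrorNotReady) return false;
  throw std::runtime_error("unexpected HIP error");
}
```

Because the guard is thread-local, a main-thread caller without the guard takes the normal query path and still sees real errors, which matches the claim that wait()/isCompleted() are unchanged.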

CUDAGraph.cpp/h

  • is_graph_capture_active() reads the existing _currently_capturing_graphs map under its mutex
  • capture_end() erases the map entry before AT_CUDA_CHECK so the watchdog never sees stale state on error paths
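The bookkeeping side can be sketched with standard-library pieces only. Names loosely follow the PR; the map's key/value types and the `capture_begin` helper are assumptions for the sake of a self-contained example:

```cpp
#include <cassert>
#include <mutex>
#include <stdexcept>
#include <unordered_map>

std::mutex capture_mutex;
std::unordered_map<unsigned long long, int> currently_capturing_graphs;

// Mirrors is_graph_capture_active(): true while any capture is registered.
bool is_graph_capture_active() {
  std::lock_guard<std::mutex> lock(capture_mutex);
  return !currently_capturing_graphs.empty();
}

// Hypothetical helper: register a capture in the bookkeeping map.
void capture_begin(unsigned long long capture_id) {
  std::lock_guard<std::mutex> lock(capture_mutex);
  currently_capturing_graphs[capture_id] = 0;
}

// Sketch of the capture_end() ordering fix: erase the bookkeeping entry
// *before* the error check (the stand-in for AT_CUDA_CHECK), so a failing
// end-capture call can never leave the watchdog seeing stale "capture
// active" state.
void capture_end(unsigned long long capture_id, bool end_call_fails) {
  {
    std::lock_guard<std::mutex> lock(capture_mutex);
    currently_capturing_graphs.erase(capture_id);
  }
  if (end_call_fails) {
    throw std::runtime_error("stream capture end failed");
  }
}
```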

Everything is guarded by #ifdef USE_ROCM, with a TODO to remove it once the HIP runtime fix ships.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang

@pytorch-bot

pytorch-bot bot commented Mar 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176251

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending, 1 Unrelated Failure

As of commit 23503e4 with merge base e274aff:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: rocm AMD GPU support for Pytorch release notes: distributed (c10d) release notes category ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Mar 3, 2026
@jithunnair-amd jithunnair-amd added ciflow/periodic-rocm-mi355 Trigger "distributed" config CI on ROCm MI350/MI355 ciflow/periodic-rocm-mi300 Trigger "distributed" config CI on ROCm MI300/MI325 labels Mar 3, 2026
@pragupta pragupta added the ciflow/periodic-rocm-mi200 Trigger "distributed" config CI on ROCm MI200 label Mar 3, 2026
@pragupta pragupta self-requested a review March 3, 2026 17:49
pragupta
pragupta previously approved these changes Mar 3, 2026
@pragupta pragupta added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 3, 2026
@chinmaydk99 chinmaydk99 marked this pull request as ready for review March 5, 2026 01:42
jeffdaily
jeffdaily previously approved these changes Mar 6, 2026
@jeffdaily
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@ngimel
Collaborator

ngimel commented Mar 10, 2026

What is going on with these PRs? Watchdog event queries were already happening in thread-local mode, so there should be no additional fixes needed. Under what circumstances would putting the watchdog in thread-local mode not be enough?
cc @eqy

@chinmaydk99
Contributor Author

What is going on with these PRs? Watchdog event queries were already happening in thread-local mode, so there should be no additional fixes needed. Under what circumstances would putting the watchdog in thread-local mode not be enough? cc @eqy

You're right that the watchdog already sets cudaStreamCaptureModeThreadLocal before event queries. On CUDA/NVIDIA this is sufficient.

The issue is that current HIP runtimes don't honor THREAD_LOCAL mode for hipEventQuery. Even with the mode switch, hipEventQuery still checks the global capture list and returns hipErrorStreamCaptureUnsupported if another thread has GLOBAL capture active, which invalidates the session. This is a known HIP runtime bug.

So the existing CUDAStreamCaptureModeGuard is the correct design; it just isn't effective on current HIP runtimes. This PR is a temporary workaround until the fixed runtime is the baseline, at which point the guards can be removed.
Happy to work with @eqy to validate alternatives.
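For readers unfamiliar with the guard being discussed: the capture-error-mode swap can be mocked in plain C++ as below. This is a simplified stand-in only; the real guard exchanges the mode with the runtime via cudaThreadExchangeStreamCaptureMode (or its HIP port), and the enum name here is illustrative:

```cpp
#include <cassert>

// Mocked per-thread capture error mode; mirrors the runtime's
// thread-local setting that cudaThreadExchangeStreamCaptureMode swaps.
enum class CaptureMode { Global, ThreadLocal, Relaxed };
thread_local CaptureMode capture_mode = CaptureMode::Global;

// RAII guard in the spirit of CUDAStreamCaptureModeGuard: set a mode for
// the current scope (e.g. ThreadLocal around watchdog event queries),
// restore the previous mode on exit.
struct StreamCaptureModeGuard {
  CaptureMode saved;
  explicit StreamCaptureModeGuard(CaptureMode m) : saved(capture_mode) {
    capture_mode = m;
  }
  ~StreamCaptureModeGuard() { capture_mode = saved; }
};
```

The bug under discussion is precisely that on affected HIP runtimes this per-thread mode is not honored by hipEventQuery, so the guard alone does not protect a capture running on another thread.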

@jeffdaily
Collaborator

@pytorchbot revert -c ghfirst -m "needs additional reviews, blocking revert of another PR"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
…re (#176251)"

This reverts commit a346446.

Reverted #176251 on behalf of https://github.com/jeffdaily due to needs additional reviews, blocking revert of another PR ([comment](#176251 (comment)))
@pytorchmergebot
Collaborator

@chinmaydk99 your PR has been successfully reverted.

@chinmaydk99 chinmaydk99 force-pushed the ck-fix-15310 branch 2 times, most recently from 33bd997 to a3fdfc8 Compare March 12, 2026 02:34
@chinmaydk99 chinmaydk99 requested a review from ngimel March 12, 2026 02:34
return true;
}

// Must unconditionally return false here during watchdog + active capture:
Collaborator

So... the semantics of CUDA stream capture are a bit confusing, so I wrote up more here: https://github.com/pytorch/pytorch/pull/140979/changes#diff-39e542d87359e7d5381d036cbcea9ec759fbe469578bcdc5693ce6cfab7f1a54R518-R547 (I never merged that change because I realized it wasn't necessary, so that comment now sits in obscurity). Please read it over to make sure we are on the same page.

It is quite likely that setting the current thread's capture error mode to cudaStreamCaptureModeThreadLocal or cudaStreamCaptureModeRelaxed would fix your issue. To be honest, I would need to read through this code in detail to figure out which one is right.

However, you write "on affected HIP runtimes" here. Is the purpose of all of your shenanigans here essentially that the current implementation in ROCm is incorrect?

Contributor Author

@chinmaydk99 chinmaydk99 Mar 12, 2026

Yes, this change is essentially a workaround for a HIP runtime bug on affected ROCm versions. We already set the watchdog thread to cudaStreamCaptureModeThreadLocal. Semantically that should be enough (and if it were fully honored, this workaround wouldn’t be needed).

The issue is that on affected HIP runtimes, watchdog-side hipEventQuery can still hit capture restrictions during cross-thread capture windows (#177309). So, this patch is a workaround for that runtime behavior, not a new intended semantic model. Once supported ROCm versions all include that fix, we can remove it.

Collaborator

@galv is our response satisfactory here?



#if defined(USE_ROCM)
// Returns true when at least one CUDAGraph capture is currently active in this
Collaborator

The sort of obvious problem here is that CUDA graphs (or HIP graphs?) can exist outside of PyTorch. The source of truth for "is a capture currently happening?" would have to be CUDA or HIP. This is the primary objection I have to this change on code-composability grounds. Maybe you might think that this is a theoretical issue, but people have asked me about, e.g., mixing JAX and PyTorch in the same process before.

Contributor Author

Valid point, and I took a thorough look at the HIP API surface for a runtime-level source of truth here, but couldn’t find a process-wide capture query API, only stream-scoped capture-status APIs. So we can’t cleanly replace this with a universal runtime check in this path.

Given that limitation, this PR is a best-effort workaround for a real HIP runtime bug. It’s not ideal, but it addresses the immediate failure mode for our primary PyTorch-only usage while we wait for the runtime fix to become baseline.

Also, this is intended to be temporary: once the relevant HIP runtime fix is guaranteed across supported ROCm versions, the workaround in this patch is no longer needed. That’s why I added explicit TODOs to remove it once we can rely on fixed runtimes.

Collaborator

@galv is our response satisfactory here?

Collaborator

Do you have a timeline for when this bug is going to be fixed in ROCm? I feel if it's a short-term mitigation then it's much easier to justify some temporary code composability violations.

Collaborator

Hi @ngimel , yes, we expect this to be fixed in the next rocm release. We have internal tickets to track this fix for next rocm release.

Collaborator

Next ROCm release is tentatively planned 3/26, and we're pushing to get this fix as part of that delivery. So this is on AMD's radar to fix urgently. IMHO this PR is a short-term fix for sure.

Collaborator

As a short-term fix I'm ok with this, longer term I would much prefer to not see is_capture_in_progress queries in ProcessGroupNCCL

@jeffdaily
Collaborator

@pytorchbot merge -f "target-determination job got stuck? all other trunk is passing, and 1 known flaky and other fails outside of this PR"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 27, 2026
Pull Request resolved: pytorch#176251
Approved by: https://github.com/jeffdaily, https://github.com/ngimel

(cherry picked from commit 5ae3a6f)
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…rch#176251)

On HIP runtimes, `hipEventQuery` from the NCCL watchdog thread while another thread has GLOBAL stream capture active poisons the capture session (`hipErrorStreamCaptureUnsupported` → `hipErrorStreamCaptureInvalidated`).
This was observed as intermittent failures in FSDP + CUDA graph tests, confirmed by `TORCH_NCCL_BLOCKING_WAIT=1` making it pass (disables async watchdog).

Fix: add ROCm-only `is_graph_capture_active()` using the existing capture bookkeeping map, then use it in `ProcessGroupNCCL` to skip `hipEventQuery` and defer timeout checks while any capture is active. Also fix `capture_end()` ordering so bookkeeping cleanup happens before error propagation.

Pull Request resolved: pytorch#176251
Approved by: https://github.com/pragupta, https://github.com/jeffdaily
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
chinmaydk99 added a commit to chinmaydk99/pytorch that referenced this pull request Mar 31, 2026

Labels

  • ci-no-td (Do not run TD on this PR)
  • ciflow/inductor-rocm-mi300 (Trigger "inductor" config CI on ROCm MI300/MI325)
  • ciflow/periodic-rocm-mi355 (Trigger "distributed" config CI on ROCm MI350/MI355)
  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: rocm (AMD GPU support for PyTorch)
  • open source
  • release notes: distributed (c10d) (release notes category)
  • Reverted
