
[CUDA] Reuse blocks with record_stream during CUDA Graph capture in the CUDACachingAllocator#158352

Closed
eee4017 wants to merge 16 commits into pytorch:main from eee4017:remove-cudagraph-defer-reclaiming

Conversation

Collaborator

@eee4017 eee4017 commented Jul 15, 2025

Introduction

During CUDA Graph capture, the CUDA caching allocator currently defers reclaiming blocks until capture ends. This is because CUDA forbids querying events recorded during capture (the operations are not actually executed at capture time), so the allocator cannot use its normal event-based logic. However, capture records a DAG of work (we call it the capturing graph). We can use the capturing graph to determine when a block's old lifetime is fully ordered before all future work, and safely reuse the block within the same capture.

This PR adds an experimental flag graph_capture_record_stream_reuse: True|False (default: False). When enabled, the allocator inserts lightweight free markers and uses capture ordering to decide if a freed block is safe to reuse during capture. If the proof cannot be established, we fall back to the existing post-capture path.

Terms

  • Free marker: A capture-legal no-op (created with cudaGraphAddEmptyNode) inserted after the last captured use of the block on each stream that used it.
  • Terminal: The set of the latest operations of a stream (or of the capturing graph). Any newly captured op on that stream will attach after all nodes in this set. For a stream currently capturing, it is the set of nodes returned in dependencies_out by cudaStreamGetCaptureInfo.

When can we reuse a block during capture?

Strong Rule (Graph-Wide Safety)

This rule provides a universal guarantee that a block is safe for reuse by any stream in the graph.

A block is safe to reuse if every free marker is a predecessor of every terminal of all active streams in the graph.

Why it's safe:

This rule establishes a strict global ordering. Since any new operation on any stream must be appended after that stream's terminals, this condition guarantees that the block's new lifetime begins only after its old lifetime has completely ended everywhere. This prevents lifetime overlaps when the graph is replayed, ensuring correctness.

Per-stream Rule (A Practical Optimization)

The strong rule, while safe, is often unnecessarily restrictive. The DeviceCachingAllocator introduces a crucial constraint that allows for a simpler check.

In DeviceCachingAllocator, get_free_block only returns blocks whose block->stream == p.stream(). In other words, we never reuse a block on a stream different from the allocation stream. This means we don't need to verify safety across the entire graph. We only need to confirm that the block is safe to reuse from the perspective of its own allocation stream.

Reuse a block for allocations on stream S if every free marker is a predecessor of every node in the terminal set of S.

In short, a block is considered reusable on stream S as long as all markers marking it "free" are guaranteed to complete before any new work that might need it on stream S begins.

Implementation

  • On free(block) during capture
    • For each stream in block->stream_uses and the allocation stream, insert a free marker (empty node) and make it that stream’s tail.
    • If we cannot place markers for all such streams (for example, a stream is not in capture), defer to the post-capture path.
    • Otherwise, store the marker handles and keep the block in the capture-private structures.
  • On allocate(stream) during capture (attempt per-stream reclaim)
    • Query the allocation stream S’s terminal via cudaStreamGetCaptureInfo.
    • For each deferred block, check whether it was allocated on this stream and whether each of its free markers is a predecessor of the terminal.
      • If yes, hand the block to S for immediate reuse within the same capture.
      • If no, keep it deferred; it will be reconsidered as capture progresses and S’s terminal advances.
  • On capture end
    • Any still-deferred blocks follow the existing post-capture reclamation (event insertion/polling). External behavior remains unchanged if we cannot prove safety during capture.

Examples (2 streams)

(figure: capture-graph diagrams for the five cases below)
  • Case 0 — Unsafe
    The two frees are not ordered with respect to each other. For stream 1, the other stream’s free marker does not precede this stream’s terminal, so the per-stream condition fails.
    Counterexample intuition for the unsafe setups: imagine f2(x) runs for a long time. If DeviceCachingAllocator reused block x on a stream whose terminal is not ordered after the free markers, the new lifetime could overlap the old one on replay, risking use-after-free or data corruption. The per-stream rule prevents exactly this.
  • Case 1 — Reusable on stream 1
    Stream 1’s terminal is after both frees, so every free marker precedes stream 1’s terminal. The block is reusable for allocations on stream 1.
  • Case 2 — Not reusable on stream 2, but this cannot occur in DeviceCachingAllocator
    This depicts reusing the block on stream 2 while stream 1’s free is not yet ordered before stream 2’s terminal. Though the block is not safe to reuse on stream 2, DeviceCachingAllocator will not choose that block for stream 2 anyway: get_free_block rejects blocks whose stream != p.stream(). So this case is unreachable.
  • Case 3 — Safe (strong rule holds)
    In this scenario, the terminal nodes of all streams are positioned after the block's free markers, satisfying the strong rule. This guarantees the block is safe for reuse by any stream in the capturing graph. However, since DeviceCachingAllocator only reuses a block on its original allocation stream, verifying this strong condition is unnecessary. We only need to ensure the per-stream rule is met for the specific stream requesting the block.
  • Case 4 — Freeing after a join
    See the note below.

Edge Case: Freeing after a join

Our current dependency tracking has a limitation in scenarios where a block is freed after a stream join; see @galv's comments in this thread.

In case 4 there is a missed opportunity. Because the block's last use is not explicitly marked, we cannot tell that it may have occurred much earlier, long before the join; we must therefore wait for the subsequent join before the block can be reused.

Thanks

Thanks to @galv for his great idea around graph parsing and empty nodes.

cc @ptrblck @msaroufim @eqy @jerryzh168 @mcarilli @ezyang @eellison @penguinwu @BoyuanFeng

@eee4017 eee4017 requested review from eqy and syed-ahmed as code owners July 15, 2025 16:31

pytorch-bot bot commented Jul 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158352

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3755c3e with merge base 480c739:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Collaborator Author

eee4017 commented Jul 15, 2025

@pytorchbot label "topic: not user facing" "module: cuda" "open source"

@pytorch-bot pytorch-bot bot added module: cuda Related to torch.cuda, and CUDA support in general topic: not user facing topic category labels Jul 15, 2025
@eee4017 eee4017 changed the title [CUDA] Remove CUDAGraph defer reclaiming to reduce the memory usage capture [CUDA] Reduce memory usage during CUDA Graph capture by removing deferred reclaiming Jul 15, 2025
Collaborator


If this is needed to do the cudaEventQuery, it might hurt debuggability a bit if it means user-disallowed code can slip through here. However this is probably unlikely as I would assume the Python thread would not be making progress while we are here

Collaborator

@eqy eqy Jul 15, 2025


Is this missing additional synchronization?
During capture, we are synchronizing externally, which is OK, but on replay it looks like nothing is checking the equivalent of events completing before allowing free_block()

Collaborator Author


During replay the allocator is inactive: every tensor that the graph touches was allocated before cudaStreamEndCapture.
The replay kernel launches are entirely device‑side; they do not call back into CUDACachingAllocator, so no additional free_block() (or new allocations) are called while the graph is replaying. Therefore the only time we need to cudaEventQuery is at capture‑time. After capture finishes the host reclaims the block once the event is ready; the same block is never freed a second time during replay.

To record “block X is done” on a stream that is being captured (though the stream is not the main capturing stream), we still need to emit a cudaEventRecordExternal node.

I did not find any way to express this dependency without creating a graph node. Those nodes are harmless at replay but are indeed redundant once the block is freed on the host.

We could clean them up post-capture with cudaGraphRemoveDependencies / cudaGraphDestroyNode, but that would add non-trivial graph-editing code to the post-capture stage. For now I chose the simpler approach.

Collaborator

@eqy eqy Jul 16, 2025


I am more concerned about the following toy example sequence of events, where on replay there is a missing dependency, mainly because no synchronization check is done at replay time (there is no eventQuery).

An example for one-side stream recorded + allocating stream
the cudaFree is for illustration purposes only, it could also represent the allocation being improperly reused for another tensor etc., because it is marked as free by the allocator
(diagram)
Here the replay still preserves ordering even if the eventQueries are gone during replay.

I think the proof-of-concept would be to insert a new (non-external) event, when we know the original eventQuery is done (at capture time), and synchronize the allocating stream on it. Note that if there is more than one side-stream recorded, we could record the additional event when the last event is queried-complete at capture time, or when the event is seen as finished from the host.

Collaborator


Thinking a bit more I think the above handles the case where the allocation gets reused on the allocating stream, but I'm not sure if this would handle the allocation being recycled on a side stream...

I am also thinking about how bad it would be to explicitly cudaFree the block for that case, as that would be expensive at capture but cheap on replay...

Collaborator


Leaving the above in-place but I think the description was inaccurate. Here we are only concerned with the lifetime of the allocation being respected by the CUDA runtime during replay, and correctness with regard to stream ordering is the user's responsibility.

The only possible danger that was illustrated above was cudaFree not happening at the correct place, but that is disallowed/not done during capture.

@eqy eqy added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 16, 2025
@eqy
Collaborator

eqy commented Jul 16, 2025

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased remove-cudagraph-defer-reclaiming onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout remove-cudagraph-defer-reclaiming && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the remove-cudagraph-defer-reclaiming branch from baab1b3 to 17cd6d2 Compare July 16, 2025 01:48
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Jul 16, 2025
@eqy eqy added the module: cuda graphs Ability to capture and then replay streams of CUDA kernels label Jul 16, 2025
@ezyang ezyang requested review from eellison and ezyang July 17, 2025 03:00
@ezyang
Contributor

ezyang commented Jul 17, 2025

Some basic education for me: where can I learn how cuda graph external events work exactly?

@ezyang ezyang requested a review from ngimel July 17, 2025 03:15
Collaborator

@ngimel ngimel left a comment


Before making changes to very sensitive parts of the codebase, please take time to familiarize yourself with why things were done the way they were done, and what fundamental reasons are there for cudaEventQuery to be uncapturable.

@galv
Collaborator

galv commented Jul 17, 2025

Some basic education for me: where can I learn how cuda graph external events work exactly?

@ezyang Events are not themselves inherently "external". You just pass an "external" flag to cudaEventRecordWithFlags and cudaStreamWaitEvent during stream capture to a cuda graph to indicate that these calls should be turned into "event record" and "event wait" nodes. That's literally all that they mean.

Meanwhile, cuda event records and waits without the external flags (i.e., the default flags) have the semantics described here: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cross-stream-dependencies-and-events Put another way, "internal" (note that this isn't official terminology) events are the way that you can have "forks" and "joins" in your cuda graph. Without them, every cuda graph would be a single chain, with at most one edge exiting each node.

You can review my PR #155372 to see how these API's are used in practice, since I realize that they are very rare.

@galv
Collaborator

galv commented Jul 17, 2025

If I understand correctly, what you are trying to tackle is the problem that memory allocations that are put into a block's stream_uses cannot be recycled. AFAIK, the only way stream_uses ever gets populated is via CUDACachingAllocator::recordStream(), which is pretty rare. Can you mention concrete examples when this is used?

Overall, though, there are a lot of fishy things in here. You should not be using those external record events as far as I know. Is anyone even waiting on them or querying them?

My guess is that you want something more like what I did in my implementation to make stream capture work with host memory allocations: https://github.com/pytorch/pytorch/pull/146924/files#diff-5d3beb56bf9b6f380f91d5f6f063480ce2e14ca15c415d59d153436018089223R396-R404

There, a similar API calling CachingHostAllocator_recordEvent() is used to mark every "usage" of a particular block. I use empty nodes as "sentinels" to mark these usages. Then at each free() call, I do the following check:

If (1) the reference count of a cudaHostAlloc()-created block has gone to 0 (tracked via the allocated field in HostBlock), and (2) there is a path from every empty node created by a call to record_event() to the current node in stream capture, then this block can be reused, and is therefore moved to the free list.

I think this is an appropriate course of action, and to be honest, I was of the impression that CUDACachingAllocator already did something like that (I understand ~50% of it), but I guess I may be wrong. But anyway please speak to me internally for more help on this. Test cases that motivate this attempt are the most helpful thing you can start with.

@eee4017
Collaborator Author

eee4017 commented Jul 17, 2025

Hi, thanks for the comments @ngimel and @galv.

I think this example helps illustrate why we might want to attempt reclaiming blocks during capture time. Please take a look at this gist. It shows a case where some blocks could be reclaimed, but currently are not during the capture stage.

(figure: memory usage during capture, from the gist)

@eee4017 eee4017 marked this pull request as draft July 17, 2025 08:34
@eee4017
Collaborator Author

eee4017 commented Jul 17, 2025

Thanks everyone for your comments. I believe this will require some fundamental changes. I’ll let you know once it’s ready for review.

@eee4017 eee4017 force-pushed the remove-cudagraph-defer-reclaiming branch from df9c507 to 3755c3e Compare September 4, 2025 09:15
@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/nightly Trigger all jobs we run nightly (nightly.yml) ciflow/rocm Trigger "default" config CI on ROCm labels Sep 4, 2025
@eee4017 eee4017 added ciflow/trunk Trigger trunk jobs on your pull request ciflow/nightly Trigger all jobs we run nightly (nightly.yml) ciflow/rocm Trigger "default" config CI on ROCm labels Sep 4, 2025
@pytorch-bot

This comment was marked as outdated.


@pytorch-bot pytorch-bot bot removed ciflow/nightly Trigger all jobs we run nightly (nightly.yml) ciflow/trunk Trigger trunk jobs on your pull request labels Sep 4, 2025

@pytorch-bot pytorch-bot bot removed the ciflow/rocm Trigger "default" config CI on ROCm label Sep 4, 2025
@jeffdaily jeffdaily added ciflow/trunk Trigger trunk jobs on your pull request ciflow/nightly Trigger all jobs we run nightly (nightly.yml) ciflow/rocm Trigger "default" config CI on ROCm labels Sep 4, 2025
@jeffdaily
Collaborator

I approved CI and added the labels back but they didn't trigger CI flows?

@jeffdaily jeffdaily removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/nightly Trigger all jobs we run nightly (nightly.yml) ciflow/rocm Trigger "default" config CI on ROCm labels Sep 4, 2025
@eee4017
Collaborator Author

eee4017 commented Sep 4, 2025

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here


Labels

ci-no-td, ciflow/nightly, ciflow/rocm, ciflow/trunk, Merged, module: cuda graphs, module: cuda, open source, Reverted, topic: not user facing, triaged
