Make CachingHostAllocator work with memory pools.#167507

Closed
galv wants to merge 14 commits into gh/galv/2/base from gh/galv/2/head

Conversation

@galv
Collaborator

@galv galv commented Nov 11, 2025

Stack from ghstack (oldest at bottom):

Both allocation to a cuda graph's private pool via stream capture and
allocation to a memory pool in non-stream-captured code are supported.

In the case of stream capture, we refuse to reuse a host memory block
as soon as record_event() is called on that block. This is to prevent
a stream-captured CUDA kernel from reading different contents from a
memory block than would be read if, counterfactually, that kernel
were running eagerly on a cuda stream.

See
#161583 (comment)
for elaboration.

This is lacking test cases for pageable host memory copies. We must
make sure that record_event() does not fail in that case.
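The reuse policy described above can be sketched as a small, illustrative Python model. The real implementation is C++ inside CachingHostAllocator; all names here are hypothetical stand-ins:

```python
class HostBlock:
    """Toy stand-in for a pinned host memory block."""

    def __init__(self, size):
        self.size = size
        # Set once record_event() is called on this block during stream
        # capture; such a block must never be handed out again, because a
        # captured kernel replaying later could read its new contents.
        self.event_recorded_during_capture = False


class CachingHostAllocatorModel:
    """Illustrative model of the block-reuse policy, not the real allocator."""

    def __init__(self):
        self.free_blocks = []
        self.capturing = False  # stands in for stream-capture state

    def allocate(self, size):
        # Reuse a free block only if no event was recorded on it during
        # stream capture.
        for block in self.free_blocks:
            if block.size >= size and not block.event_recorded_during_capture:
                self.free_blocks.remove(block)
                return block
        return HostBlock(size)

    def record_event(self, block):
        if self.capturing:
            block.event_recorded_during_capture = True

    def free(self, block):
        self.free_blocks.append(block)
```

Under this policy, a block that had record_event() called on it during capture is never recycled, while blocks without recorded events stay reusable as usual.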

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Nov 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167507

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 9 Unrelated Failures

As of commit dfbb053 with merge base 3854d69:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@galv
Collaborator Author

galv commented Nov 17, 2025

I have found an issue while testing this locally. Still not quite there yet.

@galv galv added the release notes: cuda release notes category label Nov 17, 2025
Khanaksahu pushed a commit to Khanaksahu/pytorch that referenced this pull request Nov 17, 2025
Both allocation to a cuda graph's private pool via stream capture and
allocation to a memory pool in non-stream-captured code are supported.

In the case of stream capture, we refuse to reuse a host memory block
as soon as record_event() is called on that block. This is to prevent
a stream-captured CUDA kernel from reading different contents from a
memory block than would be read if, counterfactually, that kernel
were running eagerly on a cuda stream.

See
pytorch/pytorch#161583 (comment)
for elaboration.

This is lacking test cases for pageable host memory copies. We must
make sure that record_event() does not fail in that case.


ghstack-source-id: 7f22464
Pull-Request: pytorch/pytorch#167507
@galv
Collaborator Author

galv commented Nov 17, 2025

The current issue is pageable host memory. It is legal to call t1.copy_(t2, non_blocking=True) where either t1 or t2 is pageable host memory (while the other tensor is device memory).

Without loss of generality, suppose that t1 is a tensor backed by pageable host memory.

If t1 was allocated during stream capture, the user would expect that t1's backing memory will be kept alive (even if t1 itself dies), because that is the semantics with GPU and pinned CPU memory allocations. Meanwhile, if t1 was allocated before stream capture, it is the user's responsibility to keep t1 alive.

We have no way to distinguish whether t1 was allocated before or after stream capture, without doing something clunky like capturing the entire state of the process's memory map when capture_begin() is called. Even then, that does not work since malloc() can reuse previously allocated pages.

Fortunately, what we can do is distinguish with the CUDA APIs whether a given host pointer points to pinned host memory or pageable host memory. In #167508, I can add a warning if a particular cudaMemcpyAsync() happens to read from pageable host memory. This is probably the best I can do without something weird like shadow memory.
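That classify-and-warn idea can be modeled in a few lines of Python. This is illustrative only: `pinned_ranges` and both function names are hypothetical stand-ins for the real CUDA pointer-attribute query, which in the actual code would go through an API like cudaPointerGetAttributes().

```python
# Hypothetical registry of pinned host allocations: list of (base, size).
# In real code, the CUDA runtime answers this question directly.
pinned_ranges = []

def host_ptr_is_pinned(ptr):
    """Model of querying whether a host pointer is pinned or pageable."""
    return any(base <= ptr < base + size for base, size in pinned_ranges)

def maybe_warn_pageable_copy(src_ptr, capturing):
    # During stream capture, an async copy reading pageable host memory is
    # only safe if the user keeps that memory alive themselves, so the best
    # the allocator can do is emit a warning.
    if capturing and not host_ptr_is_pinned(src_ptr):
        return "warning: pageable host memory read during stream capture"
    return None
```

The warning fires only for pageable host pointers while a capture is active; pinned memory, and any copy outside capture, passes silently.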

[ghstack-poisoned]
@galv
Collaborator Author

galv commented Nov 18, 2025

This PR is ready for review.

FYI @eee4017 @ngimel in case you are interested.

Collaborator

@ngimel ngimel left a comment


Added a few comments, lmk if you want to change the pools design

// First, try to allocate from the free list of the chosen pool
auto* block = get_free_block(roundSize, pool);
if (block) {
block->was_allocated_during_stream_capture_ = stream_is_capturing(get_current_stream());
Collaborator


similarly, to protect perf of the common case, first check captures_underway. You may hide it in current_stream_is_capturing() with no args: inside that function, first check that captures_underway is non-empty, and then get the current stream and check its status.

}
}
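The reviewer's suggestion (check the cheap global state before the per-stream query) can be sketched in Python. Names are hypothetical; the real code is C++ inside the allocator:

```python
# Stand-ins for allocator state. captures_underway models the allocator's
# global list of active captures; stream_capture_status models the
# per-stream capture query, which is assumed to be the expensive part.
captures_underway = []
stream_capture_status = {}

def current_stream_is_capturing(current_stream):
    # Fast path: if no captures are underway anywhere, skip the per-stream
    # status query entirely. This is the common, eager-execution case.
    if not captures_underway:
        return False
    # Slow path: only now consult the current stream's capture status.
    return stream_capture_status.get(current_stream, False)
```

This keeps the hot allocation path to a single cheap emptiness check when no graph capture is in flight.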

// TODO: Rethink how this is implemented. Should it take a pool id
Collaborator


Yes! I think it would be valuable to free just the pinned blocks associated with a pool

[ghstack-poisoned]

def test_unpinned_memory_use(self):
# It is allowed to call copy_(non_blocking=True) on pageable
# host memory. TODO: We should test that a warning is emitted
Collaborator Author

@galv galv Dec 2, 2025


Will complete this test in #167508

[ghstack-poisoned]
@galv
Collaborator Author

galv commented Dec 2, 2025

@eee4017 @eqy @syed-ahmed if you want to review this, now is the time to do so. I removed support for private pools outside of stream capture, which drastically reduces the size of the PR. The test cases have also shrunk.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 17, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@galv
Collaborator Author

galv commented Dec 17, 2025

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 12 checks:
- s390x-periodic / linux-manylinux-2_28-py3-cpu-s390x / test (default, 3, 10, linux.s390x)
- s390x-periodic / linux-manylinux-2_28-py3-cpu-s390x / test (default, 4, 10, linux.s390x)
- periodic / linux-jammy-cuda12.4-py3.10-gcc11 / test (legacy_nvidia_driver, 4, 5, lf.linux.g4dn.4xlarge.nvidia.gpu, unstable)
- periodic / linux-jammy-cuda12.4-py3.10-gcc11 / test (legacy_nvidia_driver, 3, 5, lf.linux.g4dn.4xlarge.nvidia.gpu, unstable)
- periodic / linux-jammy-cuda12.4-py3.10-gcc11 / test (legacy_nvidia_driver, 1, 5, lf.linux.g4dn.4xlarge.nvidia.gpu, unstable)
- periodic / linux-jammy-cuda12.4-py3.10-gcc11 / test (legacy_nvidia_driver, 2, 5, lf.linux.g4dn.4xlarge.nvidia.gpu, unstable)
- periodic / linux-jammy-cuda12.4-py3.10-gcc11 / test (legacy_nvidia_driver, 5, 5, lf.linux.g4dn.4xlarge.nvidia.gpu, unstable)
- periodic / linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck / test (default, 2, 8, lf.linux.g5.4xlarge.nvidia.gpu, module:slowgradcheck)
- periodic / linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck / test (default, 1, 8, lf.linux.g5.4xlarge.nvidia.gpu, module:slowgradcheck)
- periodic / linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck / test (default, 7, 8, lf.linux.g5.4xlarge.nvidia.gpu, module:slowgradcheck)
- periodic / linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck / test (default, 6, 8, lf.linux.g5.4xlarge.nvidia.gpu, module:slowgradcheck)
- periodic / linux-jammy-cuda12.8-py3.10-gcc11-debug / test (default, 1, 7, lf.linux.g6.4xlarge.experimental.nvidia.gpu, oncall:debug-build)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@galv galv mentioned this pull request Dec 17, 2025
pytorchmergebot pushed a commit that referenced this pull request Dec 18, 2025
Testing: `pytest -k test_split_with_sizes_copy_out test/test_torch.py`

This is no longer needed after #167507

Fixes #169607

Pull Request resolved: #170710
Approved by: https://github.com/eqy
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
Testing: `pytest -k test_split_with_sizes_copy_out test/test_torch.py`

This is no longer needed after pytorch#167507

Fixes pytorch#169607

Pull Request resolved: pytorch#170710
Approved by: https://github.com/eqy
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
Testing: `pytest -k test_split_with_sizes_copy_out test/test_torch.py`

This is no longer needed after #167507

Fixes #169607

Pull Request resolved: #170710
Approved by: https://github.com/eqy
galv pushed a commit to galv/pytorch that referenced this pull request Jan 6, 2026
I initially wrote in pytorch#164264 that there was a missing wait_stream()
call to put a stream into stream capture mode, but surprisingly since
I made that issue the problem has been fixed. I was not able to locate
the exact commit that coincidentally made that fix after a brief
search. CachingHostAllocator has supported memory allocation during
stream capture since pytorch#167507, so the purpose of this PR is simply
to make sure that the support does not regress.

An important detail is that we need to make sure that cuda graph still
overlaps the all-gather and reduce-scatter streams with computation
streams. To check for that, I applied this patch:

```
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index c0831d8..c0fecdf787d 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -1681,8 +1681,8 @@ class TestFullyShardCudaGraph(FSDPTest):
         device = torch.device(device_type.type, self.rank)
         torch.manual_seed(42)
         model = nn.Sequential(
-            nn.Linear(8, 8, bias=False),
-            nn.Linear(8, 8, bias=False),
+            nn.Linear(4096, 4096, bias=False),
+            nn.Linear(4096, 4096, bias=False),
         ).to(device)
         for param in model.parameters():
             dist.broadcast(param, src=0)
@@ -1694,7 +1694,7 @@ class TestFullyShardCudaGraph(FSDPTest):

         # warmup
         with torch.cuda.stream(stream):
-            input_tensor = torch.randn(4, 8, device=device)
+            input_tensor = torch.randn(4, 4096, device=device)
             output = model(input_tensor)
             output.sum().backward()
             model.zero_grad(set_to_none=True)
@@ -1711,7 +1711,7 @@ class TestFullyShardCudaGraph(FSDPTest):
             ]

         # equivalence check
-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(stream), torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
             for _ in range(2):
                 replay_input = torch.randn(4, 8, device=device)
                 ref_output = model(replay_input)
@@ -1726,6 +1726,8 @@ class TestFullyShardCudaGraph(FSDPTest):
                 for graph_grad, ref_grad in zip(static_output_grads, ref_grads):
                     self.assertTrue(torch.equal(graph_grad, ref_grad))
                 model.zero_grad(set_to_none=True)
+                prof.step()
+        prof.export_chrome_trace(f"two_layer_fully_shard_cudagraph_{self.rank}.json")

 if __name__ == "__main__":
```

I then inspected the json file manually to check for overlap.

Closes issue pytorch#164264
galv added a commit to galv/pytorch that referenced this pull request Jan 6, 2026
I initially wrote in pytorch#164264 that there was a missing wait_stream()
call to put a stream into stream capture mode, but surprisingly since
I made that issue the problem has been fixed. I was not able to locate
the exact commit that coincidentally made that fix after a brief
search. Since CachingHostAllocator supports memory allocation during
stream capture since pytorch#167507, the purpose of this PR is simply to make
sure that the support does not regress.

An important detail is that we need to make sure that cuda graph still
overlaps the all-gather and reduce-scatter streams with computation
streams. To check for that, I applied this patch:

```
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index c0831d8..c0fecdf787d 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -1681,8 +1681,8 @@ class TestFullyShardCudaGraph(FSDPTest):
         device = torch.device(device_type.type, self.rank)
         torch.manual_seed(42)
         model = nn.Sequential(
-            nn.Linear(8, 8, bias=False),
-            nn.Linear(8, 8, bias=False),
+            nn.Linear(4096, 4096, bias=False),
+            nn.Linear(4096, 4096, bias=False),
         ).to(device)
         for param in model.parameters():
             dist.broadcast(param, src=0)
@@ -1694,7 +1694,7 @@ class TestFullyShardCudaGraph(FSDPTest):

         # warmup
         with torch.cuda.stream(stream):
-            input_tensor = torch.randn(4, 8, device=device)
+            input_tensor = torch.randn(4, 4096, device=device)
             output = model(input_tensor)
             output.sum().backward()
             model.zero_grad(set_to_none=True)
@@ -1711,7 +1711,7 @@ class TestFullyShardCudaGraph(FSDPTest):
             ]

         # equivalence check
-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(stream), torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
             for _ in range(2):
                 replay_input = torch.randn(4, 8, device=device)
                 ref_output = model(replay_input)
@@ -1726,6 +1726,8 @@ class TestFullyShardCudaGraph(FSDPTest):
                 for graph_grad, ref_grad in zip(static_output_grads, ref_grads):
                     self.assertTrue(torch.equal(graph_grad, ref_grad))
                 model.zero_grad(set_to_none=True)
+                prof.step()
+        prof.export_chrome_trace(f"two_layer_fully_shard_cudagraph_{self.rank}.json")

 if __name__ == "__main__":
```

I then inspected the json file manually to check for overlap.

Closes issue pytorch#164264
pytorchmergebot pushed a commit that referenced this pull request Jan 7, 2026
I initially wrote in #164264 that there was a missing wait_stream() call to put a stream into stream capture mode, but surprisingly, since I made that issue the problem has been fixed. I was not able to locate the exact commit that coincidentally made that fix after a brief search. CachingHostAllocator has supported memory allocation during stream capture since #167507, so the purpose of this PR is simply to make sure that the support does not regress.

An important detail is that we need to make sure that cuda graph still overlaps the all-gather and reduce-scatter streams with computation streams. To check for that, I applied this patch:

```
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index c0831d8..c0fecdf787d 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -1681,8 +1681,8 @@ class TestFullyShardCudaGraph(FSDPTest):
         device = torch.device(device_type.type, self.rank)
         torch.manual_seed(42)
         model = nn.Sequential(
-            nn.Linear(8, 8, bias=False),
-            nn.Linear(8, 8, bias=False),
+            nn.Linear(4096, 4096, bias=False),
+            nn.Linear(4096, 4096, bias=False),
         ).to(device)
         for param in model.parameters():
             dist.broadcast(param, src=0)
@@ -1694,7 +1694,7 @@ class TestFullyShardCudaGraph(FSDPTest):

         # warmup
         with torch.cuda.stream(stream):
-            input_tensor = torch.randn(4, 8, device=device)
+            input_tensor = torch.randn(4, 4096, device=device)
             output = model(input_tensor)
             output.sum().backward()
             model.zero_grad(set_to_none=True)
@@ -1711,7 +1711,7 @@ class TestFullyShardCudaGraph(FSDPTest):
             ]

         # equivalence check
-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(stream), torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
             for _ in range(2):
                 replay_input = torch.randn(4, 8, device=device)
                 ref_output = model(replay_input)
@@ -1726,6 +1726,8 @@ class TestFullyShardCudaGraph(FSDPTest):
                 for graph_grad, ref_grad in zip(static_output_grads, ref_grads):
                     self.assertTrue(torch.equal(graph_grad, ref_grad))
                 model.zero_grad(set_to_none=True)
+                prof.step()
+        prof.export_chrome_trace(f"two_layer_fully_shard_cudagraph_{self.rank}.json")

 if __name__ == "__main__":
```

I then inspected the json file manually to check for overlap.

Closes issue #164264

Fixes #164264

Pull Request resolved: #171835
Approved by: https://github.com/ezyang, https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Edward Yang <ezyang@meta.com>
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
Both allocation to a cuda graph's private pool via stream capture and
allocation to a memory pool in non-stream-captured code are supported.

In the case of stream capture, we refuse to reuse a host memory block
as soon as record_event() is called on that block. This is to prevent
a stream-captured CUDA kernel from reading different contents from a
memory block than would be read if, counterfactually, that kernel
were running eagerly on a cuda stream.

See
pytorch#161583 (comment)
for elaboration.

This is lacking test cases for pageable host memory copies. We must
make sure that record_event() does not fail in that case.
Pull Request resolved: pytorch#167507
Approved by: https://github.com/eqy, https://github.com/eee4017, https://github.com/ngimel
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
Testing: `pytest -k test_split_with_sizes_copy_out test/test_torch.py`

This is no longer needed after pytorch#167507

Fixes pytorch#169607

Pull Request resolved: pytorch#170710
Approved by: https://github.com/eqy
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
I initially wrote in pytorch#164264 that there was a missing wait_stream() call to put a stream into stream capture mode, but surprisingly, since I made that issue the problem has been fixed. I was not able to locate the exact commit that coincidentally made that fix after a brief search. CachingHostAllocator has supported memory allocation during stream capture since pytorch#167507, so the purpose of this PR is simply to make sure that the support does not regress.

An important detail is that we need to make sure that cuda graph still overlaps the all-gather and reduce-scatter streams with computation streams. To check for that, I applied this patch:

```
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index c0831d8..c0fecdf787d 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -1681,8 +1681,8 @@ class TestFullyShardCudaGraph(FSDPTest):
         device = torch.device(device_type.type, self.rank)
         torch.manual_seed(42)
         model = nn.Sequential(
-            nn.Linear(8, 8, bias=False),
-            nn.Linear(8, 8, bias=False),
+            nn.Linear(4096, 4096, bias=False),
+            nn.Linear(4096, 4096, bias=False),
         ).to(device)
         for param in model.parameters():
             dist.broadcast(param, src=0)
@@ -1694,7 +1694,7 @@ class TestFullyShardCudaGraph(FSDPTest):

         # warmup
         with torch.cuda.stream(stream):
-            input_tensor = torch.randn(4, 8, device=device)
+            input_tensor = torch.randn(4, 4096, device=device)
             output = model(input_tensor)
             output.sum().backward()
             model.zero_grad(set_to_none=True)
@@ -1711,7 +1711,7 @@ class TestFullyShardCudaGraph(FSDPTest):
             ]

         # equivalence check
-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(stream), torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
             for _ in range(2):
                 replay_input = torch.randn(4, 8, device=device)
                 ref_output = model(replay_input)
@@ -1726,6 +1726,8 @@ class TestFullyShardCudaGraph(FSDPTest):
                 for graph_grad, ref_grad in zip(static_output_grads, ref_grads):
                     self.assertTrue(torch.equal(graph_grad, ref_grad))
                 model.zero_grad(set_to_none=True)
+                prof.step()
+        prof.export_chrome_trace(f"two_layer_fully_shard_cudagraph_{self.rank}.json")

 if __name__ == "__main__":
```

I then inspected the json file manually to check for overlap.

Closes issue pytorch#164264

Fixes pytorch#164264

Pull Request resolved: pytorch#171835
Approved by: https://github.com/ezyang, https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Edward Yang <ezyang@meta.com>
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Jan 12, 2026
I initially wrote in pytorch#164264 that there was a missing wait_stream() call to put a stream into stream capture mode, but surprisingly, since I made that issue the problem has been fixed. I was not able to locate the exact commit that coincidentally made that fix after a brief search. CachingHostAllocator has supported memory allocation during stream capture since pytorch#167507, so the purpose of this PR is simply to make sure that the support does not regress.

An important detail is that we need to make sure that cuda graph still overlaps the all-gather and reduce-scatter streams with computation streams. To check for that, I applied this patch:

```
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index c0831d8..c0fecdf787d 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -1681,8 +1681,8 @@ class TestFullyShardCudaGraph(FSDPTest):
         device = torch.device(device_type.type, self.rank)
         torch.manual_seed(42)
         model = nn.Sequential(
-            nn.Linear(8, 8, bias=False),
-            nn.Linear(8, 8, bias=False),
+            nn.Linear(4096, 4096, bias=False),
+            nn.Linear(4096, 4096, bias=False),
         ).to(device)
         for param in model.parameters():
             dist.broadcast(param, src=0)
@@ -1694,7 +1694,7 @@ class TestFullyShardCudaGraph(FSDPTest):

         # warmup
         with torch.cuda.stream(stream):
-            input_tensor = torch.randn(4, 8, device=device)
+            input_tensor = torch.randn(4, 4096, device=device)
             output = model(input_tensor)
             output.sum().backward()
             model.zero_grad(set_to_none=True)
@@ -1711,7 +1711,7 @@ class TestFullyShardCudaGraph(FSDPTest):
             ]

         # equivalence check
-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(stream), torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
             for _ in range(2):
                 replay_input = torch.randn(4, 8, device=device)
                 ref_output = model(replay_input)
@@ -1726,6 +1726,8 @@ class TestFullyShardCudaGraph(FSDPTest):
                 for graph_grad, ref_grad in zip(static_output_grads, ref_grads):
                     self.assertTrue(torch.equal(graph_grad, ref_grad))
                 model.zero_grad(set_to_none=True)
+                prof.step()
+        prof.export_chrome_trace(f"two_layer_fully_shard_cudagraph_{self.rank}.json")

 if __name__ == "__main__":
```

I then inspected the json file manually to check for overlap.

Closes issue pytorch#164264

Fixes pytorch#164264

Pull Request resolved: pytorch#171835
Approved by: https://github.com/ezyang, https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Edward Yang <ezyang@meta.com>
@github-actions github-actions bot deleted the gh/galv/2/head branch January 17, 2026 02:19
SergeyTyshkevich pushed a commit to SergeyTyshkevich/chart2 that referenced this pull request Jan 19, 2026
Only allocation to a cuda graph's private pool via stream capture is
supported.

Allocation to a memory pool in non-stream-captured code is not
supported; there is no obvious use case for it at this time.

In stream capture, we refuse to reuse a host memory block as soon as
record_event() is called on that block. This is to prevent a
stream-captured CUDA kernel from reading different contents from a
memory block than would be read if, counterfactually, that kernel
were running eagerly on a cuda stream.

See
pytorch/pytorch#161583 (comment)
for elaboration.

ghstack-source-id: d1e30fc
Pull-Request: pytorch/pytorch#167507
pytorchmergebot pushed a commit that referenced this pull request Feb 11, 2026
… capturing (#174724)

This matches the behavior of CUDACachingAllocator.cpp.

If a user wants to prevent a pin_memory() call from a data loading thread from disrupting their stream capture, they can use "thread_local" stream capture mode for that capture.

Requested by @ngimel as a follow up to #167507
Pull Request resolved: #174724
Approved by: https://github.com/ngimel

Labels

ciflow/pull ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: cuda release notes category

Projects

Status: Done


8 participants