
Test that FSDP2 works with cuda graphs.#171835

Closed
galv wants to merge 3 commits intopytorch:mainfrom
galv:fsdp2-cuda-graph

Conversation

@galv
Collaborator

@galv galv commented Jan 6, 2026

I initially wrote in #164264 that a wait_stream() call was missing to put a stream into stream-capture mode, but surprisingly the problem has been fixed since I filed that issue. After a brief search I was not able to locate the exact commit that coincidentally made the fix. Since #167507, CachingHostAllocator supports memory allocation during stream capture, so the purpose of this PR is simply to make sure that support does not regress.

An important detail is that we need to make sure the CUDA graph still overlaps the all-gather and reduce-scatter streams with the computation stream. To check for that, I applied this patch:

```diff
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index c0831d87d7c..c0fecdf787d 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -1681,8 +1681,8 @@ class TestFullyShardCudaGraph(FSDPTest):
         device = torch.device(device_type.type, self.rank)
         torch.manual_seed(42)
         model = nn.Sequential(
-            nn.Linear(8, 8, bias=False),
-            nn.Linear(8, 8, bias=False),
+            nn.Linear(4096, 4096, bias=False),
+            nn.Linear(4096, 4096, bias=False),
         ).to(device)
         for param in model.parameters():
             dist.broadcast(param, src=0)
@@ -1694,7 +1694,7 @@ class TestFullyShardCudaGraph(FSDPTest):

         # warmup
         with torch.cuda.stream(stream):
-            input_tensor = torch.randn(4, 8, device=device)
+            input_tensor = torch.randn(4, 4096, device=device)
             output = model(input_tensor)
             output.sum().backward()
             model.zero_grad(set_to_none=True)
@@ -1711,7 +1711,7 @@ class TestFullyShardCudaGraph(FSDPTest):
             ]

         # equivalence check
-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(stream), torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
             for _ in range(2):
                 replay_input = torch.randn(4, 8, device=device)
                 ref_output = model(replay_input)
@@ -1726,6 +1726,8 @@ class TestFullyShardCudaGraph(FSDPTest):
                 for graph_grad, ref_grad in zip(static_output_grads, ref_grads):
                     self.assertTrue(torch.equal(graph_grad, ref_grad))
                 model.zero_grad(set_to_none=True)
+                prof.step()
+        prof.export_chrome_trace(f"two_layer_fully_shard_cudagraph_{self.rank}.json")

 if __name__ == "__main__":
```

I then inspected the JSON file manually to check for overlap.
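The overlap check above was done by eye in the trace viewer, but it can also be sketched mechanically: a chrome trace is JSON whose `traceEvents` carry per-stream timestamps, so overlap is just interval intersection across two streams. The field layout below (`tid` as stream id, `ts`/`dur` in microseconds, `cat == "kernel"`) is an illustrative simplification, not the exact schema torch.profiler emits:

```python
import json
import os
import tempfile

def streams_overlap(trace_path, stream_a, stream_b):
    """Return True if any kernel on stream_a overlaps in time with
    any kernel on stream_b in a chrome-trace-style JSON file."""
    with open(trace_path) as f:
        events = json.load(f)["traceEvents"]

    def kernels(stream):
        # (start, end) intervals for kernel events on one stream
        return [(e["ts"], e["ts"] + e["dur"])
                for e in events
                if e.get("cat") == "kernel" and e.get("tid") == stream]

    for s0, e0 in kernels(stream_a):
        for s1, e1 in kernels(stream_b):
            if max(s0, s1) < min(e0, e1):  # intervals intersect
                return True
    return False

# Tiny synthetic trace: an all-gather on stream 20 runs while a
# matmul executes on stream 7 (timestamps in microseconds).
trace = {"traceEvents": [
    {"cat": "kernel", "name": "ncclAllGather", "tid": 20, "ts": 100, "dur": 50},
    {"cat": "kernel", "name": "gemm",          "tid": 7,  "ts": 120, "dur": 40},
]}
path = os.path.join(tempfile.mkdtemp(), "trace.json")
with open(path, "w") as f:
    json.dump(trace, f)

print(streams_overlap(path, 20, 7))   # the two kernels overlap in [120, 150)
```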

Fixes #164264

cc @mcarilli @ezyang @eellison @penguinwu @BoyuanFeng

@pytorch-bot

pytorch-bot bot commented Jan 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/171835

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit c88f2b1 with merge base 68370db:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jan 6, 2026
@linux-foundation-easycla

linux-foundation-easycla bot commented Jan 6, 2026

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: ezyang / name: Edward Z. Yang (c88f2b1)
  • ✅ login: galv / name: Daniel Galvez (0452df8, 8724462)
  • ✅ login: Skylion007 / name: Aaron Gokaslan (8724462)

@galv galv added the module: cuda graphs Ability to capture and then replay streams of CUDA kernels label Jan 6, 2026
Accidentally skipped the test. `python test/distributed/_composable/fsdp/test_fully_shard_training.py TestFullyShardCudaGraph.test_two_layer_fully_shard_cudagraph` ignores the unittest.skipIf decorator!

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Signed-off-by: Edward Yang <ezyang@meta.com>
@ezyang
Contributor

ezyang commented Jan 7, 2026

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 7, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

```python
            ]

        static_input.copy_(replay_input)
        graph.replay()
```
Collaborator

lol if you attempted to do this for real you would be debugging for a long time why gradients are not accumulated, but for tests this is good enough

Collaborator Author

For what it's worth, the idiom of calling model.zero_grad(set_to_none=True) before stream capture comes from the original PyTorch blog post on CUDA Graphs, so it wouldn't surprise me if most code out in the wild does this.

Very unfortunate for anyone who might be using cuda graphs this way and trying to do gradient accumulation over multiple minibatches 😬
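The pitfall can be modeled without a GPU: whichever AccumulateGrad path is live at capture time (assignment when .grad is None, addition when a grad buffer already exists) is what gets baked into the graph and repeated on every replay. A torch-free sketch, with a hypothetical CapturedBackward class standing in for the captured graph:

```python
class CapturedBackward:
    """Hypothetical stand-in for a captured backward pass. It only
    models one distinction: whether the grad buffer existed when the
    graph was captured."""

    def __init__(self, grad_was_none_at_capture):
        # zero_grad(set_to_none=True) before capture -> grad is None,
        # so the captured op is an overwrite, not an accumulation.
        self.overwrite = grad_was_none_at_capture
        self.static_grad = 0.0  # stands in for the static grad buffer

    def replay(self, new_grad):
        if self.overwrite:
            self.static_grad = new_grad    # captured as "grad = g"
        else:
            self.static_grad += new_grad   # captured as "grad += g"

# Grad was None at capture: two replays do NOT accumulate.
g = CapturedBackward(grad_was_none_at_capture=True)
g.replay(1.0)
g.replay(1.0)
print(g.static_grad)  # 1.0

# Grad buffer materialized (e.g. zeros) before capture: replays accumulate.
g = CapturedBackward(grad_was_none_at_capture=False)
g.replay(1.0)
g.replay(1.0)
print(g.static_grad)  # 2.0
```

So anyone doing gradient accumulation over multiple replays would silently keep only the last minibatch's gradients.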

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-aarch64 / linux-jammy-aarch64-py3.10 / test (openreg, 1, 1, lf.linux.arm64.m8g.4xlarge)

Details for Dev Infra team Raised by workflow job

@msaroufim
Member

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: linux-aarch64 / linux-jammy-aarch64-py3.10 / test (openreg, 1, 1, lf.linux.arm64.m8g.4xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
Pull Request resolved: pytorch#171835
Approved by: https://github.com/ezyang, https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Jan 12, 2026
@weifengpy
Contributor

@galv Thanks so much for the change! This means a lot for keeping FSDP2 relevant in the era of Grace CPU.

@weifengpy weifengpy added the release notes: distributed (fsdp2) release notes category label Jan 26, 2026

Labels

  • ciflow/trunk — Trigger trunk jobs on your pull request
  • Merged
  • module: cuda graphs — Ability to capture and then replay streams of CUDA kernels
  • open source
  • release notes: distributed (fsdp2) — release notes category
  • topic: not user facing — topic category

Development

Successfully merging this pull request may close these issues.

Cuda graph support for FSDP2 is lacking.

10 participants