
Commit 4115354

eellison authored and pytorchmergebot committed
Use wait stream instead of synchronize() in cudagraph warmup (#117578)
Fix for #113895.

There are three phases to cudagraph trees: warmup, recording, and execution. On recording and execution we execute under the current stream. In warmup we execute under a side stream, which we also use for cudagraph recording so as to reuse memory. After executing on the side stream we need to sync the current stream to the side stream. Previously there was a `torch.cuda.synchronize()` but not a `torch.cuda.current_stream().wait_stream(stream)`. This PR removes the global sync and adds the `wait_stream`. I have confirmed that it fixes #113895.

It's not entirely clear to me why `torch.cuda.synchronize()` would be insufficient; I would have thought the global sync would encompass the stream-to-stream sync. However, we do have a number of [instances](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/compile_fx.py#L748-L749) throughout the code base where we do a stream-to-stream sync after the global sync, so clearly I am missing something here. In any case, the stream-to-stream sync is better for perf than a global synchronize.

Pull Request resolved: #117578
Approved by: https://github.com/zdevito
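Below is a minimal, self-contained sketch (not the PR's code) of the pattern this change adopts. The helper name `run_on_side_stream` and the toy workload are illustrative assumptions, but the `wait_stream` calls mirror what warmup now does:

```python
import torch

def run_on_side_stream(fn, *args):
    # Illustrative helper (not from this PR): run `fn` on a side stream,
    # then order the current stream after it.
    side_stream = torch.cuda.Stream()
    # The side stream must first wait for any pending work on the current
    # stream that produced the inputs.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        out = fn(*args)
    # Stream-to-stream sync: kernels later launched on the current stream
    # are ordered after the side-stream work, without stalling the whole
    # device the way torch.cuda.synchronize() would.
    torch.cuda.current_stream().wait_stream(side_stream)
    return out

# Usage: the result is safe to consume on the current stream afterwards.
x = torch.rand([1024, 1024], device="cuda")
y = run_on_side_stream(torch.matmul, x, x)
z = y + 1  # launched on the current stream, ordered after the matmul
```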
1 parent 560213d commit 4115354

2 files changed

Lines changed: 22 additions & 3 deletions


test/inductor/test_cudagraph_trees.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -708,6 +708,26 @@ def foo(args):
         self.assertEqual(node.cached_tensor_outputs, [None])
         self.assertEqual(node.unaliased_in_all_paths, [False])
 
+    def test_warmup_stream_sync(self):
+        def foo(args):
+            x = args[0]
+            args.clear()
+            x_orig = x
+            for _ in range(100):
+                x = x @ x
+            return (x,)
+
+        inp = torch.rand([4096, 4096], device="cuda")
+        ref = foo([inp])[0]
+        torch.cuda.synchronize()
+
+        user_stream = torch.cuda.Stream()
+        with torch.cuda.stream(user_stream):
+            foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
+            out = foo_cg([inp])[0]
+            y = out + 1
+            self.assertEqual(y, ref + 1)
+
     def test_unaligned_static_parameter(self):
         def gen_inp():
             inp = torch.ones([20], device="cuda")
```

torch/_inductor/cudagraph_trees.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -518,6 +518,8 @@ def _use_cuda_memory_pool_manager(device, mem_pool, stream):
             torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
             torch._C._cuda_releasePool(device, mem_pool)
 
+    torch.cuda.current_stream().wait_stream(stream)
+
 
 def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
     if not isinstance(t, torch.Tensor):
@@ -610,9 +612,6 @@ def get_non_cudagraph_inps():
         ), get_history_recording():
             out = self.wrapped_function.model(new_inputs)
 
-        # sync up stream used in `_use_cuda_memory_pool_manager` - TODO - wait stream instead ?
-        torch.cuda.synchronize()
-
         assert len(new_inputs) == 0
 
         # sdpa returns cpu tensors when not recording cuda graph
```
