Fix for warning as default stream was used in enqueueV3#3191
Fix for warning as default stream was used in enqueueV3#3191keehyuna merged 2 commits intopytorch:mainfrom
Conversation
|
@keehyuna there is some code related to cudagraphs, can you check how it handles streams and perhaps write some docs on how the runtime is suppose to do streams in all cases? |
|
keehyuna
left a comment
There was a problem hiding this comment.
This is when problem happend. capture stream is changed to default and non default stream.

This is when torch.cuda.set_stream() is used. Non default stream is used for cuda graph/enqueueV3(). But stream is not restored after Forward()
.
This is proposed fix to keep side stream to cuda graph or enqueueV3()

| if (need_cudagraphs_record) { | ||
| // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph | ||
| c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream; | ||
| compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id); |
There was a problem hiding this comment.
This was reverted to fix below assert() from torch. We don't share memory across captures, I think we can use internally created pool.
https://pytorch.org/docs/stable/notes/cuda.html#graph-memory-management
File "/root/trt/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 274, in forward
outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_ops.py", line 1113, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: it->second->use_count > 0 INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":2056, please report a bug to PyTorch
| self.shape_key = new_shape_key | ||
| self.cudagraph.reset() # type: ignore | ||
| if self.cudagraph: | ||
| self.cudagraph.reset() |
There was a problem hiding this comment.
self.cudagraph can be None when torch_compile backend is used.
self.cudagraph is initialized when cudagraphs mode is enabled. But this init was called at compile()
https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py#L144-L145
|
@lanluo-nvidia please cherry-pick |
Description
torch.cuda.current_stream()/c10::cuda::getCurrentCUDAStream() always returns default stream and it leads running enqueueV3() with default stream.
torch.cuda.set_stream/c10::cuda::setCurrentCUDAStream is required to set current stream when new stream is acquired from pool
Fixes #3190
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: