Commit ad4e00a

[torchAPI] Inherit cuda stream from torch (#618)
When hidet ran through `torch.compile`, it did not inherit the current CUDA stream from torch. This commit fixes that.

Fixes issue #563.

Tested with the main branch of vllm cloned on 28 November, using the options `--model meta-llama/Llama-3.1-8B-Instruct --max_model_len 4000`.

Output before the fix:
```
about!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
```
Output after the fix, as expected:
```
about a cat named Whiskers
Once upon a time, in a cozy little house on a quiet street, there lived a beautiful and elegant cat named Whiskers. Whiskers was a stunning feline with shimmering grey fur, bright green eyes, and a delicate pink nose. She was a gentle soul, with a soft and soothing purr that could calm even the most anxious of hearts.
Whiskers lived with her loving owner, a kind old lady named Mrs. Jenkins, who adored her dearly. Mrs. Jenkins had rescued Whiskers from a shelter when she was just a tiny kitten, and the two had been inseparable ever since. Whiskers spent her days lounging in the sunbeams that streamed through the windows, chasing the occasional fly, and purring contentedly as Mrs. Jenkins stroked her soft fur.
But Whiskers was more than just a lazy cat. She had a secret life, one that few people knew about. You
```
1 parent ffdbde4 commit ad4e00a

1 file changed

Lines changed: 4 additions & 1 deletion

python/hidet/graph/frontend/torch/dynamo_backends.py

@@ -164,8 +164,11 @@ def __call__(self, *args):
             else:
                 # ignore constant
                 pass
-
+        # Inherited cuda stream from torch
+        runtime_api.set_current_stream(torch.cuda.current_stream().cuda_stream)
+        # Prepare inputs
         tensor_args = preprocess_inputs(tensor_args)
+        # Run graph/model
         outputs = self.cgraph.run_async(tensor_args, output_to_torch_tensor=True)
         outputs: Sequence[torch.Tensor] = [
             tensor.torch() if isinstance(tensor, hidet.Tensor) else tensor for tensor in outputs
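
The diff sets the stream inside `__call__`, so the runtime picks up whatever stream torch considers current at each invocation. The sketch below illustrates that pattern with pure-Python stand-ins. `FakeHostFramework`, `FakeRuntime`, and `CompiledCallable` are hypothetical names invented for this illustration; they are not hidet or torch APIs, and the integer "streams" only model the bookkeeping, not real CUDA behavior.

```python
# Stand-ins only: none of these classes exist in hidet or torch.


class FakeHostFramework:
    """Plays the role of torch: tracks the stream the caller currently wants."""

    def __init__(self):
        self.current_stream = 0  # 0 stands for the default stream


class FakeRuntime:
    """Plays the role of the backend runtime, which keeps its own stream state."""

    def __init__(self):
        self.current_stream = 0
        self.launch_log = []

    def set_current_stream(self, stream):
        self.current_stream = stream

    def launch(self, kernel_name):
        # Record which stream the work was submitted to.
        self.launch_log.append((kernel_name, self.current_stream))


class CompiledCallable:
    """Hypothetical wrapper around a compiled graph."""

    def __init__(self, host, runtime, inherit_stream):
        self.host = host
        self.runtime = runtime
        self.inherit_stream = inherit_stream

    def __call__(self):
        if self.inherit_stream:
            # Mirrors the fix: sync the runtime's stream on every call.
            self.runtime.set_current_stream(self.host.current_stream)
        self.runtime.launch("graph")


host = FakeHostFramework()
buggy_rt, fixed_rt = FakeRuntime(), FakeRuntime()
buggy = CompiledCallable(host, buggy_rt, inherit_stream=False)
fixed = CompiledCallable(host, fixed_rt, inherit_stream=True)

host.current_stream = 7  # the caller switches to a non-default stream
buggy()
fixed()

print(buggy_rt.launch_log)  # work submitted to the stale default stream
print(fixed_rt.launch_log)  # work submitted to the caller's stream
```

In this model the wrapper without the per-call sync keeps submitting to stream 0 after the caller has moved to stream 7. On real hardware, work on different CUDA streams is not ordered unless explicitly synchronized, which is one plausible explanation for the garbage output quoted in the commit message, although the commit itself does not state the mechanism.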
