Avoid recompilation for integer number inputs #132849
🐛 Describe the bug
I am trying to apply torch.compile to vLLM, and I have run into an issue where a recompilation happens on every next-token run. Thanks to help from @anijain2305, I have narrowed the Dynamo guard failure down to the following:

The max_decode_seq_len (defined here: https://github.com/vllm-project/vllm/blob/7b261092de3b008f7a2c218e338f4f8a025c93ee/vllm/attention/ops/paged_attn.py#L23) increases with the number of generated tokens and is used as a shape to create tensors inside the kernel, as in: https://github.com/intel/intel-extension-for-pytorch/blob/0a9762f68bbd83e35fdec7d8ac5b358313d4ebc2/csrc/cpu/aten/kernels/PagedAttentionKrnl.cpp#L216-L221
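To illustrate why a growing integer input forces a recompile on every step, here is a minimal pure-Python sketch of guard-based integer specialization. This is not Dynamo's actual implementation and all names are illustrative: the point is only that when the exact int value is baked into the guard key (as happens when a Python int feeds a tensor shape), every new max_decode_seq_len misses the guard cache and triggers a fresh compile.

```python
# Illustrative simulation of Dynamo-style int specialization (not real Dynamo code).
compiled_cache = {}   # guard key -> "compiled" artifact
compile_count = 0

def specialize_and_run(fn, x, max_decode_seq_len):
    """Guard on the exact integer value, the way Dynamo specializes Python
    ints that flow into tensor shapes. A new value -> guard miss -> recompile."""
    global compile_count
    guard_key = (fn.__name__, max_decode_seq_len)  # int baked into the guard
    if guard_key not in compiled_cache:
        compile_count += 1                         # simulated recompilation
        compiled_cache[guard_key] = lambda inp: fn(inp, max_decode_seq_len)
    return compiled_cache[guard_key](x)

def kernel(x, max_len):
    # Stands in for the paged-attention kernel that allocates a buffer
    # of shape [max_decode_seq_len].
    return [0.0] * max_len

# Each decoding step grows max_decode_seq_len by one -> one recompile per step.
for step in range(5):
    specialize_and_run(kernel, [1.0], 8 + step)

print(compile_count)  # → 5, one compile per distinct max_decode_seq_len
```

This mirrors the observed behavior: five decode steps with five distinct lengths produce five "compiles" instead of one.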
Error logs
https://gist.github.com/leslie-fang-intel/9baea13ac0204f85540952725b0c3060#file-error-log
Minified repro
Create a small testcase to simulate this issue: https://gist.github.com/leslie-fang-intel/9baea13ac0204f85540952725b0c3060
Versions
[conda] blas 1.0 mkl
[conda] intel-extension-for-pytorch 2.5.0+git25acc5f dev_0 <develop>
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-include 2024.0.0 pypi_0 pypi
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl-static 2024.0.0 pypi_0 pypi
[conda] mkl_fft 1.3.8 py310h5eee18b_0
[conda] mkl_random 1.2.4 py310hdb19cb5_0
[conda] numpy 1.26.0 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] torch 2.5.0a0+gitae708e9 dev_0 <develop>
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torchao 0.3.1 dev_0 <develop>
[conda] torchaudio 2.2.0a0+17a7081 pypi_0 pypi
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchmetrics 1.3.0.post0 pypi_0 pypi
[conda] torchtune 0.0.1 pypi_0 pypi
[conda] torchvision 0.19.0a0+5181a85 pypi_0 pypi
[conda] triton-nightly 3.0.0.post20240716052845 pypi_0 pypi