[Refactor][MLA]: Expose mla to torch.compile#39346
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements the VLLM_MLA_EXPOSED_SPLIT feature to expose MLA prefill/decode batch splitting to torch.compile for improved fusion. It introduces several custom operators and modifies the compilation configuration and partition rules to support data-dependent batch sizes. Review feedback identifies a potential crash in the piecewise backend during AOT compilation for empty subgraphs, a missing check for null output_shape in the MLA forward pass, and a suggestion to use torch.cat for better optimization during tensor concatenation.
| assert self.graph is not None, "Eager fallback requires FX graph." | ||
| return self.graph(*args) |
There was a problem hiding this comment.
This assertion and subsequent call to self.graph will cause a crash when the model is loaded from an AOT compilation cache (where self.graph is None) and encounters an empty split subgraph (e.g., an all-prefill or all-decode microbatch). In AOT mode, the backend should handle shape 0 by either having a pre-compiled runnable or by returning appropriate empty/zero tensors without falling back to the FX graph module.
There was a problem hiding this comment.
My understanding is that 0 is not one of standard compile sizes to AOT compile on. If 0 is not included then the AOT graph, then we would get a crash later anyways. Failing fast, seems like a safer option as opposed to forcing users to do something non standard.
|
Hi @morrison-turnansky, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
|
This pull request has merge conflicts that must be resolved before it can be |
|
current performance: vllm serve deepseek-ai/DeepSeek-V2-Lite vllm bench serve --backend vllm --model deepseek-ai/DeepSeek-V2-Lite --dataset-name sharegpt --num-prompts 128 --profile --dataset-path /workspaces/vllm-dev/vllm/sharegpt.json ============ Serving Benchmark Result ============ Successful requests: 128 ================================================== |
|
pr performance: for repro: vllm bench serve --backend vllm --model deepseek-ai/DeepSeek-V2-Lite --dataset-name sharegpt --num-prompts 128 --profile --dataset-path /workspaces/vllm-dev/vllm/sharegpt.json ============ Serving Benchmark Result ============ Successful requests: 128 ================================================== |
8f98f9e to
f8325db
Compare
737a685 to
0d7149d
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
0d7149d to
682be0f
Compare
|
Hi @morrison-turnansky, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
682be0f to
2d2d000
Compare
|
Hi @morrison-turnansky, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
2d2d000 to
d53f813
Compare
Signed-off-by: parsshar-RH <parsshar@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
… marked cuda graph unsafe Signed-off-by: morrison-turnansky <mturnans@redhat.com>
…p_proj functional to eliminate temp buffer Signed-off-by: parsshar-RH <parsshar@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
…tention_decode Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: parsshar-RH <parsshar@redhat.com>
abb4d92 to
92d2f88
Compare
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
|
Hi @morrison-turnansky, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
Signed-off-by: parsshar-RH <parsshar@redhat.com>
fc1955a to
db97579
Compare
|
Hi @morrison-turnansky, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
1 similar comment
|
Hi @morrison-turnansky, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
|
Btw I had this idea on how to unify the exposed/non-exposed paths, using the def wrap_if_exposed(op_name: str):
def decorator(func):
# could optionally register the custom op automatically here as well,
# but that might be more effort than is worth
@wraps(func)
def wrapper(self: "MLAAttention", *args, **kwargs):
if not self.exposed:
return func(self, *args, **kwargs)
return getattr(torch.ops.vllm, op_name)(self.layer_name, *args, **kwargs)
return wrapper
return decorator
class MLAAttention(...):
def __init__(self):
self.exposed = ...
...
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.exposed:
return self.inner_forward(x)
return torch.ops.vllm.mla_attention(self.layer_name, x)
def inner_forward(self, x: torch.Tensor) -> torch.Tensor:
"""Inner forward with decomposed method calls."""
num_prefills = self.split_batch()
output = torch.empty_like(x)
output[num_prefills:] = self.forward_mha(x[num_prefills:])
output[:num_prefills] = self.forward_mqa(x[:num_prefills])
return output
def forward_mqa(self, x: torch.Tensor) -> torch.Tensor:
...
@wrap_if_exposed("mla_forward_mha")
def forward_mha(self, x: torch.Tensor) -> torch.Tensor:
... |
e3185dc to
db97579
Compare
|
Hi @morrison-turnansky, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
|
This pull request has merge conflicts that must be resolved before it can be |
Expose MLA to torch compile by splitting up custom op.
Purpose
Fixes #26516
Test Plan
I didn't see a great way to test this besides an end to end test, which I compared against the un-exposed path i.e. the current default implementation. Locally I ran in eager, and stock torch compile as well with exact matches, but I didn't want to put a lot of long tests in the repo.
Test Result
Tests are passing. I also did some local bench marking. Initial results showed a slowdown, which is most likely due to the additional graph breaks necessitated by splitting the custom op. If there is interest we can expose slightly fewer operations to remove some of the graph breaks and close the performance gap.
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.