Skip to content

[Refactor][MLA]: Expose mla to torch.compile#39346

Open
morrison-turnansky wants to merge 26 commits into
vllm-project:mainfrom
morrison-turnansky:issue-34823-mla-custom-op-unwrap-unoptimized
Open

[Refactor][MLA]: Expose mla to torch.compile#39346
morrison-turnansky wants to merge 26 commits into
vllm-project:mainfrom
morrison-turnansky:issue-34823-mla-custom-op-unwrap-unoptimized

Conversation

@morrison-turnansky

@morrison-turnansky morrison-turnansky commented Apr 8, 2026

Copy link
Copy Markdown
Contributor

Expose MLA to torch compile by splitting up custom op.

Purpose

Fixes #26516

Test Plan

I didn't see a great way to test this besides an end to end test, which I compared against the un-exposed path i.e. the current default implementation. Locally I ran in eager, and stock torch compile as well with exact matches, but I didn't want to put a lot of long tests in the repo.

Test Result

Tests are passing. I also did some local bench marking. Initial results showed a slowdown, which is most likely due to the additional graph breaks necessitated by splitting the custom op. If there is interest we can expose slightly fewer operations to remove some of the graph breaks and close the performance gap.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the VLLM_MLA_EXPOSED_SPLIT feature to expose MLA prefill/decode batch splitting to torch.compile for improved fusion. It introduces several custom operators and modifies the compilation configuration and partition rules to support data-dependent batch sizes. Review feedback identifies a potential crash in the piecewise backend during AOT compilation for empty subgraphs, a missing check for null output_shape in the MLA forward pass, and a suggestion to use torch.cat for better optimization during tensor concatenation.

Comment on lines +361 to +362
assert self.graph is not None, "Eager fallback requires FX graph."
return self.graph(*args)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This assertion and subsequent call to self.graph will cause a crash when the model is loaded from an AOT compilation cache (where self.graph is None) and encounters an empty split subgraph (e.g., an all-prefill or all-decode microbatch). In AOT mode, the backend should handle shape 0 by either having a pre-compiled runnable or by returning appropriate empty/zero tensors without falling back to the FX graph module.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that 0 is not one of standard compile sizes to AOT compile on. If 0 is not included then the AOT graph, then we would get a crash later anyways. Failing fast, seems like a safer option as opposed to forcing users to do something non standard.

Comment thread vllm/model_executor/layers/attention/mla_attention.py
Comment thread vllm/model_executor/layers/attention/mla_attention.py Outdated
@mergify

mergify Bot commented Apr 8, 2026

Copy link
Copy Markdown
Contributor

Hi @morrison-turnansky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@ProExpertProg ProExpertProg added the verified Run pre-commit for new contributors without triggering other tests label Apr 8, 2026
@mergify

mergify Bot commented Apr 10, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @morrison-turnansky.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 10, 2026
@morrison-turnansky

morrison-turnansky commented Apr 10, 2026

Copy link
Copy Markdown
Contributor Author

current performance:

vllm serve deepseek-ai/DeepSeek-V2-Lite
--trust-remote-code
-cc.mode=VLLM_COMPILE
-cc.cudagraph_mode=FULL_AND_PIECEWISE
-cc.use_inductor_graph_partition=true
--profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'

vllm bench serve --backend vllm --model deepseek-ai/DeepSeek-V2-Lite --dataset-name sharegpt --num-prompts 128 --profile --dataset-path /workspaces/vllm-dev/vllm/sharegpt.json

profiler_out_0.txt

============ Serving Benchmark Result ============

Successful requests: 128
Failed requests: 0
Benchmark duration (s): 9.83
Total input tokens: 31879
Total generated tokens: 27379
Request throughput (req/s): 13.02
Output token throughput (tok/s): 2784.11
Peak output token throughput (tok/s): 6243.00
Peak concurrent requests: 128.00
Total token throughput (tok/s): 6025.82
---------------Time to First Token----------------
Mean TTFT (ms): 543.37
Median TTFT (ms): 625.09
P99 TTFT (ms): 645.95
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 17.00
Median TPOT (ms): 12.66
P99 TPOT (ms): 86.88
---------------Inter-token Latency----------------
Mean ITL (ms): 11.92
Median ITL (ms): 11.58
P99 ITL (ms): 23.41

==================================================

@morrison-turnansky

morrison-turnansky commented Apr 10, 2026

Copy link
Copy Markdown
Contributor Author

pr performance:

for repro:
VLLM_MLA_EXPOSED_SPLIT=1
vllm serve deepseek-ai/DeepSeek-V2-Lite
--trust-remote-code
-cc.mode=VLLM_COMPILE
-cc.cudagraph_mode=FULL_AND_PIECEWISE
-cc.use_inductor_graph_partition=true
--profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'

vllm bench serve --backend vllm --model deepseek-ai/DeepSeek-V2-Lite --dataset-name sharegpt --num-prompts 128 --profile --dataset-path /workspaces/vllm-dev/vllm/sharegpt.json

profiler_out_0.txt

============ Serving Benchmark Result ============

Successful requests: 128
Failed requests: 0
Benchmark duration (s): 15.27
Total input tokens: 31879
Total generated tokens: 26575
Request throughput (req/s): 8.38
Output token throughput (tok/s): 1740.69
Peak output token throughput (tok/s): 6028.00
Peak concurrent requests: 128.00
Total token throughput (tok/s): 3828.79
---------------Time to First Token----------------
Mean TTFT (ms): 3135.06
Median TTFT (ms): 3175.92
P99 TTFT (ms): 3304.33
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 37.26
Median TPOT (ms): 24.28
P99 TPOT (ms): 129.87
---------------Inter-token Latency----------------
Mean ITL (ms): 21.72
Median ITL (ms): 11.85
P99 ITL (ms): 28.04

==================================================

@morrison-turnansky morrison-turnansky force-pushed the issue-34823-mla-custom-op-unwrap-unoptimized branch from 8f98f9e to f8325db Compare April 10, 2026 15:14
@mergify mergify Bot removed the needs-rebase label Apr 10, 2026
@parsshar-RH parsshar-RH force-pushed the issue-34823-mla-custom-op-unwrap-unoptimized branch from 737a685 to 0d7149d Compare April 13, 2026 13:10
@mergify

mergify Bot commented Apr 15, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @morrison-turnansky.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 15, 2026
Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
@morrison-turnansky morrison-turnansky force-pushed the issue-34823-mla-custom-op-unwrap-unoptimized branch from 0d7149d to 682be0f Compare April 15, 2026 19:40
@mergify

mergify Bot commented Apr 15, 2026

Copy link
Copy Markdown
Contributor

Hi @morrison-turnansky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@parsshar-RH parsshar-RH force-pushed the issue-34823-mla-custom-op-unwrap-unoptimized branch from 682be0f to 2d2d000 Compare April 16, 2026 10:29
@mergify mergify Bot removed the needs-rebase label Apr 16, 2026
@mergify

mergify Bot commented Apr 16, 2026

Copy link
Copy Markdown
Contributor

Hi @morrison-turnansky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@morrison-turnansky morrison-turnansky force-pushed the issue-34823-mla-custom-op-unwrap-unoptimized branch from 2d2d000 to d53f813 Compare April 16, 2026 15:48
parsshar-RH and others added 10 commits April 21, 2026 12:07
Signed-off-by: parsshar-RH <parsshar@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
… marked cuda graph unsafe

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
…p_proj functional to eliminate temp buffer

Signed-off-by: parsshar-RH <parsshar@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
…tention_decode

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: parsshar-RH <parsshar@redhat.com>
@parsshar-RH parsshar-RH force-pushed the issue-34823-mla-custom-op-unwrap-unoptimized branch from abb4d92 to 92d2f88 Compare April 21, 2026 12:07
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
@mergify

mergify Bot commented Apr 22, 2026

Copy link
Copy Markdown
Contributor

Hi @morrison-turnansky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: parsshar-RH <parsshar@redhat.com>
@parsshar-RH parsshar-RH force-pushed the issue-34823-mla-custom-op-unwrap-unoptimized branch from fc1955a to db97579 Compare April 27, 2026 11:10
@mergify

mergify Bot commented Apr 27, 2026

Copy link
Copy Markdown
Contributor

Hi @morrison-turnansky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

1 similar comment
@mergify

mergify Bot commented Apr 30, 2026

Copy link
Copy Markdown
Contributor

Hi @morrison-turnansky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@ProExpertProg

ProExpertProg commented May 1, 2026

Copy link
Copy Markdown
Collaborator

Btw I had this idea on how to unify the exposed/non-exposed paths, using the wrap_if_exposed decorator:

def wrap_if_exposed(op_name: str):
    def decorator(func):
        # could optionally register the custom op automatically here as well, 
        # but that might be more effort than is worth

        @wraps(func)
        def wrapper(self: "MLAAttention", *args, **kwargs):
            if not self.exposed:
                return func(self, *args, **kwargs)
            return getattr(torch.ops.vllm, op_name)(self.layer_name, *args, **kwargs)
        return wrapper
    return decorator


class MLAAttention(...):
    def __init__(self):
        self.exposed = ...
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.exposed:
            return self.inner_forward(x)
        return torch.ops.vllm.mla_attention(self.layer_name, x)

    def inner_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Inner forward with decomposed method calls."""
        num_prefills = self.split_batch()

        output = torch.empty_like(x)
        output[num_prefills:] = self.forward_mha(x[num_prefills:])
        output[:num_prefills] = self.forward_mqa(x[:num_prefills])

        return output

    def forward_mqa(self, x: torch.Tensor) -> torch.Tensor:
        ...

    @wrap_if_exposed("mla_forward_mha")
    def forward_mha(self, x: torch.Tensor) -> torch.Tensor:
        ...

@parsshar-RH parsshar-RH force-pushed the issue-34823-mla-custom-op-unwrap-unoptimized branch from e3185dc to db97579 Compare May 4, 2026 15:34
@mergify

mergify Bot commented May 4, 2026

Copy link
Copy Markdown
Contributor

Hi @morrison-turnansky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify

mergify Bot commented May 23, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @morrison-turnansky.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 23, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

needs-rebase v1 verified Run pre-commit for new contributors without triggering other tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Refactor][MLA]: Lift prefill/decode split into compiled region

4 participants