[v1] Add encoder-only/cross attention support to Triton Attention backend#31406

Merged
Isotr0py merged 30 commits into vllm-project:main from Isotr0py:triton-mha-kernel
Jan 5, 2026

Conversation

@Isotr0py
Member

@Isotr0py Isotr0py commented Dec 27, 2025

Purpose

Motivation

  1. We don't have Whisper support on Turing/Volta GPUs because of FlashAttention's compute-capability limit.
  2. Furthermore, FlexAttention's encoder-only attention is still not ideally fast in FP32.
  3. After the xformers deprecation ([Core] Deprecate xformers #29262), we want to add a Triton MMEncoderAttention backend as a balanced solution between FA and SDPA for incompatible head sizes.

Introduction

Test Plan

Whisper:

VLLM_ATTENTION_BACKEND=TRITON_ATTN python examples/offline_inference/audio_language.py -m whisper --num-prompts 5

Embedding models with sliding window (these tests should use the Triton backend by default now):

pytest -s -v ./tests/models/language/pooling_mteb_test/test_st_projector.py

Test Result

Whisper:

(EngineCore_DP0 pid=3934) INFO 12-28 17:16:27 [cuda.py:315] Using AttentionBackendEnum.TRITON_ATTN backend.
(EngineCore_DP0 pid=3934) WARNING 12-28 17:16:27 [vllm.py:1382] `torch.compile` is turned on, but the model openai/whisper-large-v3-turbo does not support it. Please open an issue on GitHub if you want it to be supported.
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:27 [weight_utils.py:550] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.74it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.74it/s]
(EngineCore_DP0 pid=3934) 
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:28 [default_loader.py:308] Loading weights took 0.68 seconds
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:28 [gpu_model_runner.py:3728] Model loading took 1.5076 GiB memory and 1.143090 seconds
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:29 [gpu_model_runner.py:4541] Encoder cache will be initialized with a budget of 2240 tokens, and profiled with 1 audio items of the maximum feature size.
(EngineCore_DP0 pid=3934) WARNING 12-28 17:16:29 [processing.py:1153] WhisperProcessor did not return `BatchFeature`. Make sure to match the behaviour of `ProcessorMixin` when implementing custom processors.
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:30 [gpu_worker.py:363] Available KV cache memory: 11.64 GiB
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:30 [kv_cache_utils.py:1305] GPU KV cache size: 305,216 tokens
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:30 [kv_cache_utils.py:1310] Maximum concurrency for 448 tokens per request: 227.10x
Capturing CUDA graphs (decode, FULL): 100%|███████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.38it/s]
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:32 [gpu_model_runner.py:4682] Graph capturing finished in 2 secs, took 0.02 GiB
(EngineCore_DP0 pid=3934) INFO 12-28 17:16:32 [core.py:272] init engine (profile, create kv cache, warmup model) took 3.65 seconds
(EngineCore_DP0 pid=3934) WARNING 12-28 17:16:32 [vllm.py:860] No piecewise cudagraph for executing cascade attention. Will fall back to eager execution if a batch runs into cascade attentions
INFO 12-28 17:16:32 [llm.py:344] Supported tasks: ['transcription']
Adding requests:  20%|████████████████                                                                | 1/5 [00:05<00:23,  5.81s/it]WARNING 12-28 17:16:52 [processing.py:1153] WhisperProcessor did not return `BatchFeature`. Make sure to match the behaviour of `ProcessorMixin` when implementing custom processors.
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00,  1.16s/it]
Processed prompts: 100%|████████████████████████| 5/5 [00:02<00:00,  2.39it/s, est. speed input: 2.39 toks/s, output: 112.54 toks/s]
 The first words I spoke in the original phonograph, a little piece of practical poetry. Mary had a little lamb, its streets were quite as snow, and everywhere that Mary went, the lamb was sure to go.
 The first words I spoke in the original phonograph, a little piece of practical poetry. Mary had a little lamb, its streets were quite as slow, and everywhere that Mary went, the lamb was sure to go.
 The first words I spoke in the original phonograph, a little piece of practical poetry. Mary had a little lamb, its streets were quite as slow, and everywhere that Mary went the lamb was sure to go.
 The first words I spoke in the original phonograph: a little piece of practical poetry. Mary had a little lamb, its streets were quite as slow, and everywhere that Mary went the lamb was sure to go.
 The first words I spoke in the original phonograph: a little piece of practical poetry. Mary had a little lamb, its streets was white as snow, and everywhere that Mary went the lamb was sure to go.

Encoder-only models: Tests should still pass with Triton backend


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@mergify mergify bot added the v1 label Dec 27, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for encoder-only and cross-attention to the Triton attention backend. This is achieved by introducing a new prefill attention Triton kernel that can handle non-causal attention, and a new execution path in TritonAttentionImpl for encoder attention types. The changes also include refactoring in several multi-modal models, likely to align with this new attention mechanism.

My review identified two critical issues in the new implementation:

  1. An incorrect masking logic for non-causal sliding window attention in the new Triton kernel, which results in a one-sided window instead of a bidirectional one.
  2. A type mismatch when passing the sliding_window parameter to the new attention function, which would lead to a runtime error.

I have provided suggestions to fix both issues. After these are addressed, the changes look good and provide a valuable extension to the Triton backend.
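For reference, the one-sided vs. bidirectional distinction flagged above can be sketched with a tiny NumPy mask builder (illustrative only; `sliding_window_mask` is not part of the PR):

```python
import numpy as np

def sliding_window_mask(q_len: int, k_len: int, window: int, causal: bool) -> np.ndarray:
    """Reference mask: True where attention is allowed.

    For causal attention the window is one-sided: key j is visible to
    query i when i - window < j <= i. For non-causal (encoder) attention
    the window must be two-sided: |i - j| < window.
    """
    i = np.arange(q_len)[:, None]
    j = np.arange(k_len)[None, :]
    if causal:
        return (j <= i) & (j > i - window)
    return np.abs(i - j) < window

bidir = sliding_window_mask(4, 4, window=2, causal=False)
one_sided = sliding_window_mask(4, 4, window=2, causal=True)
# Query 0 can see key 1 only under the bidirectional mask:
# bidir[0, 1] is True, one_sided[0, 1] is False.
```

Applying the one-sided variant in an encoder kernel silently drops all "future" keys, which is the behavior the review describes.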

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py Isotr0py marked this pull request as ready for review December 28, 2025 17:30
@Isotr0py
Member Author

Also cc @NickLucche and @noooop about Whisper/Encoder-only models respectively.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

https://github.com/vllm-project/vllm/blob/6011b0a5603ef118d07e28f1ff178e53a4611bc4/model_executor/models/ernie45_vl.py#L155-L159
P1: Tensor-parallel vision attention reshapes oversized QKV

When tp_size > 1, QKVParallelLinear emits Q/K/V blocks of length equal to the full projection size, but split_qkv now reshapes them assuming only projection_size / tp_size elements (num_attention_heads_per_partition * hidden_size_per_attention_head). Without the previous gather-and-repartition step, x.view(*new_shape) will raise a size mismatch (or, if forced, misassign heads) on multi-GPU tensor-parallel runs of the vision encoder. The same pattern appears in Glm4vVisionAttention and SiglipAttention, so any vision model using tensor parallelism will fail at runtime.
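The shape mismatch Codex describes can be sketched with illustrative NumPy shapes (all numbers and names here are hypothetical, not the actual vLLM code):

```python
import numpy as np

# With tp_size > 1 the per-rank fused QKV output holds just this rank's shard
# of the heads, so the reshape must use num_heads // tp_size, not num_heads.
seq_len, num_heads, head_dim, tp_size = 8, 16, 64, 2
heads_per_partition = num_heads // tp_size

x = np.random.randn(seq_len, 3 * heads_per_partition * head_dim)  # per-rank QKV
q, k, v = (t.reshape(seq_len, heads_per_partition, head_dim)
           for t in np.split(x, 3, axis=-1))  # correct per-partition views

# Reshaping with the full head count mismatches the rank's element count:
try:
    np.split(x, 3, axis=-1)[0].reshape(seq_len, num_heads, head_dim)
    mismatch = False
except ValueError:
    mismatch = True
```

Whichever side holds the wrong size, the `view`/`reshape` fails (or, if forced, misassigns heads), matching the multi-GPU failure mode described.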

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py
Member Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new Triton kernel for memory-efficient prefill attention, which includes support for sliding window attention. The TritonAttention backend is updated to leverage this new kernel, enabling support for encoder-only and encoder attention types by adding a dedicated _forward_encoder_attention method. A review comment identifies a critical bug in the new Triton kernel's _fwd_kernel function, specifically in the calculation of start_n_limit for the backward sliding window, which is currently incorrect and could lead to skipped key blocks and erroneous attention outputs. A corrected formula for start_n_limit is provided in the review.
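As a rough illustration of the kind of bound under discussion (not the actual kernel code; `key_block_range` and the block sizes are hypothetical), a bidirectional window of size w forces a query block to visit key blocks on both sides, so the lower limit must reach backward far enough before block alignment:

```python
def key_block_range(start_m: int, block_m: int, block_n: int,
                    window: int, k_len: int) -> tuple[int, int]:
    """For query rows [start_m, start_m + block_m), a bidirectional window of
    size `window` needs keys in [start_m - (window - 1),
    start_m + block_m - 1 + (window - 1)]; the lower bound is rounded down to
    a block_n boundary, as a kernel iterating key blocks would."""
    lo = max(0, start_m - (window - 1))
    hi = min(k_len, start_m + block_m + (window - 1))
    return (lo // block_n) * block_n, hi

# Query block at row 64, 32 queries, key blocks of 16, window 8:
lo, hi = key_block_range(64, 32, 16, 8, k_len=256)
```

A limit formula that skips the backward reach would drop the leading key blocks, which is the class of bug the review points at.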

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Collaborator

@NickLucche NickLucche left a comment


can we add this backend to whisper-specific CI tests (test_transcription_validation_whisper.py)?

@Isotr0py
Member Author

can we add this backend to whisper-specific CI tests (test_transcription_validation_whisper.py)?

Sure, I think we can use FP32 to test Whisper; it should use the Triton backend by default now, since FA doesn't support FP32:

@pytest.mark.core_model
@pytest.mark.cpu_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("enforce_eager", [True, False])
@create_new_process_for_each_test("spawn")
def test_models(
    hf_runner,
    vllm_runner,
    model: str,
    dtype: str,
    num_logprobs: int,
    input_audios,
    enforce_eager: bool,
) -> None:
    check_model_available(model)
    if current_platform.is_cpu() and not enforce_eager:
        pytest.skip("Skipping test for CPU with non-eager mode")
    run_test(
        hf_runner,
        vllm_runner,
        input_audios,
        model,
        dtype=dtype,
        max_model_len=448,
        max_tokens=200,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
        enforce_eager=enforce_eager,
    )

@Isotr0py Isotr0py marked this pull request as draft December 29, 2025 13:40
@Isotr0py Isotr0py marked this pull request as ready for review December 31, 2025 15:01
@noooop
Collaborator

noooop commented Jan 4, 2026

The performance of TRITON_ATTN looks good.

https://github.com/noooop/snippet/tree/main/benchmarks/triton_attention

[benchmark plot]

X-axis: throughput (requests/s)
Y-axis: latency per step (ms), logarithmic scale
Curves toward the lower right are better ↘

@tjtanaa
Collaborator

tjtanaa commented Jan 5, 2026

The performance of TRITON_ATTN looks good.

https://github.com/noooop/snippet/tree/main/benchmarks/triton_attention

X-axis: throughput (requests/s); Y-axis: latency per step (ms), logarithmic scale; curves toward the lower right are better ↘

How about the performance of bfloat16?

@noooop
Collaborator

noooop commented Jan 5, 2026

How about the performance of bfloat16?

[benchmark plot]

bfloat16 is slightly faster than float16, but the trend is the same.

Collaborator

@NickLucche NickLucche left a comment


@Isotr0py This LGTM from the Whisper side, both for accuracy and latency at fp16 (I don't really have an fp32 comparison to run for enc-dec models).
Looking forward to the MMEncoderAttention backend to get a few more meaningful data points in benchmarks.

Thanks for your work!

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py Isotr0py added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 5, 2026
@Isotr0py Isotr0py merged commit 6aa5b18 into vllm-project:main Jan 5, 2026
51 of 52 checks passed
@Isotr0py Isotr0py deleted the triton-mha-kernel branch January 5, 2026 16:00
@noooop
Collaborator

noooop commented Jan 6, 2026

https://buildkite.com/vllm/ci/builds/45546/steps/canvas

The failure in today's full CI run, the daily Language Models Test (Extended Pooling), was caused by this PR.

Running the Language Models Test (Extended Pooling) on this PR confirmed the issue.

https://buildkite.com/vllm/ci/builds/45475/steps/canvas?sid=019b8e61-59ae-4d96-ad0a-b100f167c05c


The failure below is not related to this PR:
models/language/pooling_mteb_test/test_nemotron.py::test_embed_models_mteb[model_info0] in Language Models Test (MTEB) is a bit flaky lately.

LucasWilkinson pushed a commit to neuralmagic/vllm that referenced this pull request Jan 6, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…kend (vllm-project#31406)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
## Summary

Cherry-pick upstream bug fixes for RHAIIS 3.3.1 onto `rhai/0.13.0`. All
fixes are from upstream vLLM `main` and address critical bugs affecting
RHAIIS 3.3.0. Other releases (3.2.2, EAx) will be done separately.

**Jira Epic:**
[INFERENG-4743](https://issues.redhat.com/browse/INFERENG-4743)

## Cherry-picked commits (chronological order)

| # | Upstream PR | Jira | Summary |
|---|-------------|------|---------|
| 1 | [vllm-project#30550](vllm-project#30550) | [INFERENG-5106](https://issues.redhat.com/browse/INFERENG-5106) | Support using chat template as custom score template for reranking models |
| 2 | [vllm-project#31406](vllm-project#31406) | [INFERENG-4800](https://issues.redhat.com/browse/INFERENG-4800) | Add encoder-only/cross attention support to Triton Attention backend |
| 3 | [vllm-project#34243](vllm-project#34243) | [INFERENG-4746](https://issues.redhat.com/browse/INFERENG-4746) | Fix Llama-4 attn quantization by correctly permuting scales for rope (int8, fp8) |
| 4 | [vllm-project#34454](vllm-project#34454) | [INFERENG-5032](https://issues.redhat.com/browse/INFERENG-5032) | Fix structured output in multi-turn GPT-OSS (content:null with json_object) |
| 5 | [vllm-project#34507](vllm-project#34507) | [INFERENG-5038](https://issues.redhat.com/browse/INFERENG-5038) | Fix fused MoE int32 overflow in stride*offset for large models |
| 6 | [vllm-project#35085](vllm-project#35085) | [INFERENG-5028](https://issues.redhat.com/browse/INFERENG-5028) | Gracefully disable AllReduceFusionPass on GPUs without multicast support |
| 7 | [vllm-project#35456](vllm-project#35456) | [INFERENG-5035](https://issues.redhat.com/browse/INFERENG-5035) | Replace assert with ValueError for response_format validation (completions) |
| 8 | [vllm-project#35510](vllm-project#35510) | [INFERENG-5035](https://issues.redhat.com/browse/INFERENG-5035) | Add response_format validation to chat completions endpoint |


## Conflict resolutions

<details>
<summary><b>#1 — llama-nemotron-embed / score-template support
(vllm-project#30550)</b>: Clean cherry-pick, no conflicts</summary>

Applied cleanly onto `rhai/0.13.0`.
</details>

<details>
<summary><b>#2 — Triton Attention (vllm-project#31406)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly onto `rhai/0.13.0`.
</details>

<details>
<summary><b>#3 — Llama-4 attn quant (vllm-project#34243)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly. 4 intermediate upstream commits touch `llama4.py` but
the fix targets a self-contained block.
</details>

<details>
<summary><b>#4 — GPT-OSS multi-turn (vllm-project#34454)</b>: Clean cherry-pick, no conflicts</summary>

Applied cleanly despite 3 intermediate upstream commits that refactored
imports in `gptoss_reasoning_parser.py`. The fix logic (adding
`eom_token_id` early-exit check in `is_reasoning_end`) was independent
of the import changes.
</details>

<details>
<summary><b>#5 — Fused MoE int32 overflow (vllm-project#34507)</b>: Conflicts in 2 files</summary>

**`vllm/model_executor/layers/fused_moe/fused_moe.py`**: ~30
intermediate upstream commits refactored `fused_moe_kernel` with
conditional `naive_block_assignment` logic that doesn't exist in
`rhai/0.13.0`. Resolved by keeping our simpler code and applying only
the int64 cast fix:
- `fused_moe_kernel_gptq_awq`: added `.to(tl.int64)` to `tl.load()`
result
- `fused_moe_kernel`: added `offs_token = offs_token.to(tl.int64)`
before `token_mask`

**`tests/kernels/moe/test_moe.py`**: Upstream test changes depend on
`make_dummy_moe_config()` from intermediate refactors. Resolved by
keeping our existing test code (no test changes).
</details>
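The overflow the int64 cast guards against can be reproduced in a few lines of NumPy (values are illustrative, not taken from a real model):

```python
import numpy as np

# An int32 row stride times an int32 token offset wraps past 2**31 - 1,
# yielding a negative (invalid) pointer offset inside the kernel.
stride = np.array([200_000], dtype=np.int32)      # hypothetical row stride
offs_token = np.array([20_000], dtype=np.int32)   # token index in a large batch

addr32 = int((stride * offs_token)[0])                    # wraps negative
addr64 = int((stride.astype(np.int64) * offs_token)[0])   # correct product
```

Casting either operand to int64 before the multiply, as the fix does with `.to(tl.int64)`, keeps the product in range.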

<details>
<summary><b>#6 — AllReduceFusionPass multicast (vllm-project#35085)</b>: Conflict due to file rename + API change</summary>

Upstream moved `collective_fusion.py` →
`compilation/passes/fusion/allreduce_rms_fusion.py` and changed the API
from `trtllm_create_ipc_workspace_for_all_reduce_fusion()` to
`create_allreduce_fusion_workspace()`. Resolved by applying the
try/except wrapper around our existing
`trtllm_create_ipc_workspace_for_all_reduce_fusion()` call in
`collective_fusion.py`. The error handling logic (catching RuntimeError
with "multicast" in message, logging warning, returning early) is
identical to upstream.
</details>

<details>
<summary><b>#7 — response_format validation for completions (vllm-project#35456)</b>: Conflict due to file restructuring</summary>

Upstream split `protocol.py` into `completion/protocol.py` and
`chat_completion/protocol.py`. Our branch still has the monolithic
`protocol.py`. Resolved by:
- Removing the non-existent
`vllm/entrypoints/openai/completion/protocol.py`
- Manually adding `validate_response_format` model_validator to
`CompletionRequest` in our `protocol.py`
- Using `ValueError` instead of upstream's `VLLMValidationError` (which
doesn't exist in our branch; `ValueError` is already handled as 400 Bad
Request in `serving_engine.py`)
- Test additions from upstream applied cleanly to
`test_completion_error.py`
</details>

<details>
<summary><b>#8 — response_format validation for chat completions (vllm-project#35510)</b>: Conflict due to file restructuring</summary>

Same file restructuring issue as #7. Resolved by:
- Removing the non-existent
`vllm/entrypoints/openai/chat_completion/protocol.py`
- Manually adding `validate_response_format` model_validator to
`ChatCompletionRequest` in our `protocol.py`
- Only accepting the `test_json_schema_response_format_missing_schema`
test from the conflict (discarding ~140 lines of intermediate upstream
tests that reference non-existent paths in our branch)
</details>
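A minimal sketch of the kind of validator described above, assuming pydantic v2 (field and class names here are illustrative, not vLLM's actual protocol definitions):

```python
from typing import Any, Optional

from pydantic import BaseModel, model_validator

class ResponseFormat(BaseModel):
    type: str
    json_schema: Optional[dict[str, Any]] = None

class ChatCompletionRequest(BaseModel):
    response_format: Optional[ResponseFormat] = None

    @model_validator(mode="after")
    def validate_response_format(self) -> "ChatCompletionRequest":
        rf = self.response_format
        if rf is not None and rf.type == "json_schema" and rf.json_schema is None:
            # Raising ValueError during validation surfaces as 400 Bad Request
            # rather than a 500 from a failed assert deeper in the stack.
            raise ValueError("json_schema must be provided when type is 'json_schema'")
        return self
```

With this in place, `{"response_format": {"type": "json_schema"}}` is rejected at request-parsing time instead of crashing the serving path.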

## Test plan

- [ ] Verify `llama-nemotron-embed-1b-v2` works correctly with the
backported score-template / bidirectional model support
- [ ] Verify Llama-4 quantized model loads correctly with int8/fp8
attention quantization
- [ ] Verify GPT-OSS multi-turn chat with `json_object` response_format
returns valid content
- [ ] Verify large MoE models (e.g. Qwen3.5-397B) don't crash with int32
overflow
- [ ] Verify MoE model loading on H200 GPUs (without multicast)
gracefully falls back
- [ ] Verify `response_format: {type: "json_schema"}` without
`json_schema` field returns 400 (not 500) for both `/v1/completions` and
`/v1/chat/completions`
- [ ] Verify encoder models (e.g. Whisper) work with Triton attention
backend on ROCm


[INFERENG-4743]:
https://redhat.atlassian.net/browse/INFERENG-4743?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-4800]:
https://redhat.atlassian.net/browse/INFERENG-4800?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-4746]:
https://redhat.atlassian.net/browse/INFERENG-4746?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-5032]:
https://redhat.atlassian.net/browse/INFERENG-5032?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-5038]:
https://redhat.atlassian.net/browse/INFERENG-5038?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ

[INFERENG-5106]:
https://redhat.atlassian.net/browse/INFERENG-5106?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ

Labels

multi-modality (Related to multi-modality, #4194), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1
