[V1][Hybrid] Mamba Prefix Caching with align mode #30877

Merged
heheda12345 merged 136 commits into vllm-project:main from peakcrosser7:ups/mamba_prefix_cache_align
Jan 23, 2026
@peakcrosser7 peakcrosser7 commented Dec 17, 2025

Contributor

The cleaned-up version of #29272

Purpose

This PR builds on the design of #28176, adopting the same memory layout as FullAttention while adding support for decode caching and speculative decoding.

The core idea of this Mamba prefix-caching implementation (referred to as LPC) is to cache Mamba states directly through block-aligned scheduling. This approach adds prefix-caching support to Mamba models without modifying the underlying kernel code. The design is also compatible with speculative decoding (MTP/EAGLE), although that path is temporarily disabled in this PR (see the note below).

Currently, this solution supports all Mamba model architectures including GDN, Mamba1, Mamba2, and Short Conv Attention, and has been adapted for relevant Mamba models such as Qwen3-Next-80B-A3B-Instruct and LFM2-700M.

  • Note: Speculative decoding is temporarily disabled in this PR because there are still corner-case bugs when it is used with prefix caching in align mode.

Usage

To enable this feature, start the engine with the --enable-prefix-caching and --mamba-cache-mode align flags.

Design Details

Block-Aligned Scheduling

Following the design in #28176, requests in the prefill phase are scheduled in multiples of block_size. This ensures that each Mamba state can be mapped to a specific block's hash value.

Prefix caching stores variable-length chunk states: the number of tokens (the incremental length) covered by each cached Mamba state may vary, but it is always a multiple of block_size.

Scheduler Logic with Mamba Prefix-Caching Enabled:

  • Decode requests: Scheduling logic remains unchanged
  • Prefill requests:
    • The number of tokens scheduled per step must be an integer multiple of block_size, except for the final chunk of the request.
    • The last prefill chunk is split to align with block_size, ensuring its size is ≤ block_size. This maximizes the length of the prompt that can be cached during the prefill phase.
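The prefill rule above can be sketched as a small chunk-splitting function. This is only an illustration of the rule as described here, not the scheduler code from this PR; `split_prefill` and `token_budget` are hypothetical names, and the sketch assumes a single request with a fixed per-step token budget.

```python
def split_prefill(prompt_len: int, block_size: int, token_budget: int) -> list[int]:
    """Return per-step chunk sizes for one prefill request (hypothetical helper)."""
    if token_budget < block_size:
        raise ValueError("token_budget must cover at least one block")
    chunks = []
    remaining = prompt_len
    while remaining > 0:
        if remaining <= block_size:
            # Final chunk: may be unaligned, but is never larger than one block.
            n = remaining
        else:
            # Largest block-aligned size that fits the budget and still
            # leaves a final chunk of 1..block_size tokens.
            aligned_budget = token_budget // block_size * block_size
            aligned_rest = (remaining - 1) // block_size * block_size
            n = min(aligned_budget, aligned_rest)
        chunks.append(n)
        remaining -= n
    return chunks


print(split_prefill(prompt_len=200, block_size=64, token_budget=128))
```

With these illustrative numbers, every chunk except the last is a multiple of 64, and the last chunk holds the remaining 8 tokens.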

Block Allocation Design

Prefill Stage

During the prefill stage, requests are scheduled at a block-aligned chunk granularity. For a single scheduling step consisting of chunk_len tokens, the system allocates chunk_len // block_size blocks:

  • Mamba State Block: Only the last block in the sequence is physically allocated to store the Mamba state.
  • Placeholder Null-Blocks: The preceding (chunk_len // block_size) - 1 blocks are populated with null-blocks (placeholders).
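As a rough sketch of the allocation rule above, the block table for one block-aligned chunk could be built as follows. `NULL_BLOCK`, `allocate_prefill_blocks`, and the free-list representation are hypothetical stand-ins rather than names from the PR, and the sketch only covers chunks whose length is a positive multiple of block_size.

```python
NULL_BLOCK = None  # hypothetical stand-in for the placeholder null-block


def allocate_prefill_blocks(chunk_len: int, block_size: int, free_block_ids: list[int]) -> list:
    """Return the blocks appended for one block-aligned prefill chunk."""
    if chunk_len <= 0 or chunk_len % block_size != 0:
        raise ValueError("this sketch only handles block-aligned chunks")
    num_blocks = chunk_len // block_size
    # All but the last slot are placeholders that consume no physical block.
    blocks = [NULL_BLOCK] * (num_blocks - 1)
    # Only the last slot gets a physical block, which holds the Mamba state.
    blocks.append(free_block_ids.pop())
    return blocks


free = [7, 8, 9]
print(allocate_prefill_blocks(chunk_len=256, block_size=64, free_block_ids=free))
```

A 256-token chunk with block_size 64 yields four slots, of which only the last one takes a block from the free list.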

Note on Speculative Decoding (SPS): In the prefill stage with SPS enabled, the first step allocates gamma additional speculative blocks, which are reused in subsequent steps.
[Figure: prefill_alloc]

Decode Stage

Since only a small number of tokens are scheduled per step during decoding, the allocation logic is consistent with FullAttention, where blocks are incrementally allocated one by one.
[Figure: decode_alloc]

Prefix Caching Logic

Scheduler-side Logic

The logic is similar to FullAttention prefix caching, with two differences. Only immutable blocks that store Mamba states are cached; null-blocks are excluded. Prefix matching is performed via a reverse hash lookup that requires only a single block to be matched.
[Figure: scheduler_logic]
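One way to read the "reverse hash lookup" is a scan from the last block hash toward the first, stopping at the first hash that has a cached Mamba state. The sketch below assumes that reading; `find_longest_cache_hit` and the dict-based cache are hypothetical and are not the PR's data structures.

```python
def find_longest_cache_hit(block_hashes: list[str], cached_blocks: dict[str, int]):
    """Return (num_hit_blocks, block_id) for the longest cached prefix.

    block_hashes[i] identifies the prefix ending at block i; cached_blocks
    maps such a hash to the block that stores the Mamba state for it.
    """
    # Scan from the longest prefix to the shortest. A single match is enough,
    # because one Mamba state summarizes the whole prefix before it.
    for i in range(len(block_hashes) - 1, -1, -1):
        block_id = cached_blocks.get(block_hashes[i])
        if block_id is not None:
            return i + 1, block_id
    return 0, None


hashes = ["h0", "h1", "h2", "h3"]
cache = {"h1": 11, "h2": 12}  # null-blocks never appear in the cache
print(find_longest_cache_hit(hashes, cache))
```

Unlike FullAttention, earlier blocks of the prefix do not need to be present in the cache for the hit to be usable.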

Worker-side Logic

Prefill Phase:
The Preprocess stage is responsible for copying Mamba states before the model forward:

Condition 1: Copy the Mamba state from the previous step to the current step.
[Figure: worker_prefill_cond1]

Condition 2: Copy the Mamba state from the prefix-cache hit block to the current step.
[Figure: worker_prefill_cond2]
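The two preprocess conditions can be sketched as choosing a source block and copying its state into the block used by the current step. Everything here is hypothetical (`StepInfo`, `preprocess_copy`, states stored as Python lists), and the sketch assumes the two conditions never apply to the same step.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StepInfo:
    """Hypothetical per-request bookkeeping for one prefill step."""
    dst_block: int                          # block used by the current step
    prev_step_block: Optional[int] = None   # Condition 1 source
    cache_hit_block: Optional[int] = None   # Condition 2 source


def preprocess_copy(states: dict, step: StepInfo) -> Optional[int]:
    """Copy the source Mamba state into dst_block; return the source block id."""
    if step.prev_step_block is not None:
        src = step.prev_step_block      # Condition 1: continue from the previous step
    else:
        src = step.cache_hit_block      # Condition 2: resume from a prefix-cache hit
    if src is None or src == step.dst_block:
        return None                     # nothing to copy
    # Copy rather than alias, so the cached source block stays immutable.
    states[step.dst_block] = list(states[src])
    return src


states = {3: [1.0, 2.0]}
print(preprocess_copy(states, StepInfo(dst_block=5, cache_hit_block=3)), states[5])
```

Copying instead of sharing the source keeps the cached block unchanged while the forward pass updates the state in the current block.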

Decode Phase:
Without Speculative Decoding: The logic remains consistent with the standard Prefill Phase.

With Speculative Decoding:
The Preprocess stage copies Mamba states when a new block is allocated:

  • Note: Be aware that the conv state and temporal state may reside in different blocks depending on the num_accepted_tokens.
[Figure: worker_decode_sps_pre]

After receiving the full number of tokens corresponding to the previous block, the Post-process stage copies the Mamba state back to the previous block.
[Figure: worker_decode_sps_post]

Test Plan

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
import time

def main():
    MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct"
    # You can use other Mamba models for testing
    # MODEL = "ibm-granite/granite-4.0-tiny-preview"
    # MODEL = "LiquidAI/LFM2-700M"
    # MODEL = "ai21labs/AI21-Jamba-Reasoning-3B"
    PROMPT_MULTIPLE = 310
    sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
    prefix = ( # examples/offline_inference/prefix_caching.py
        "You are an expert school principal, skilled in effectively managing "
        "faculty and staff. Draft 10-15 questions for a potential first grade "
        "Head Teacher for my K-12, all-girls', independent school that emphasizes "
        "community, joyful discovery, and life-long learning. The candidate is "
        "coming in for a first-round panel interview for a 8th grade Math "
        "teaching role. They have 5 years of previous teaching experience "
        "as an assistant teacher at a co-ed, public school with experience "
        "in middle school math teaching. ")
    prefix2 = ("Based on these information, fulfill "
                "the following paragraph: ")
    prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is"
    print('Prompt length:', len(prompt))
    for APC in [False, True]:
        engine = LLM(
            model=MODEL, enable_prefix_caching=APC, 
            max_num_batched_tokens=8192,
            block_size=64,
            tensor_parallel_size=4,
            gpu_memory_utilization=0.8, 
            disable_log_stats=False,
            mamba_cache_mode="align",
        )
        for i in range(3):
            if i == 0:
                print('Warm-up')
            if i == 1:
                print('Measuring')
                start_time = time.time()
            outputs = engine.generate(prompt, sampling_params)
            print('APC:', APC, i, f"Generated text: {outputs[0].outputs[0].text!r}")
            for m in engine.llm_engine.get_metrics():
                if 'vllm:prefix_cache_hits' in m.name:
                    print(m.name, m.value)
        print('APC:', APC, "loop took --- %s seconds ---" % (time.time() - start_time))
        del engine
        cleanup_dist_env_and_memory()


if __name__ == "__main__":
    main()

Test Result

Warm-up

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]
Adding requests: 100%|██████████| 1/1 [00:00<00:00, 12.78it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:12<00:00, 12.64s/it, est. speed input: 2552.19 toks/s, output: 10.13 toks/s]
APC: False 0 Generated text: " __________, and I am the Head of School at __________. I am thrilled to welcome you to our interview today. Our school is a K-12, all-girls', independent school that emphasizes community, joyful discovery, and lifelong learning. We believe that every girl has the potential to thrive when nurtured in an environment that values curiosity, collaboration, and courage. As we consider candidates for our 8th grade Math teaching role, we are looking for educators who not only have strong content knowledge and pedagogical skills, but who also embody our core values and are excited to contribute to a vibrant, supportive, and girl"
vllm:prefix_cache_hits 0
Measuring

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]
Adding requests: 100%|██████████| 1/1 [00:00<00:00, 14.60it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.88s/it, est. speed input: 17128.75 toks/s, output: 67.97 toks/s]
APC: False 1 Generated text: " __________, and I am the Head of School at __________. I am thrilled to welcome you to our interview today. Our school is a K-12, all-girls', independent school that emphasizes community, joyful discovery, and lifelong learning. We believe that every girl has the potential to thrive when nurtured in an environment that values curiosity, collaboration, and courage. As we consider candidates for our 8th grade Math teaching role, we are looking for educators who not only have strong content knowledge and pedagogical skills, but who also embody our core values and are excited to contribute to a vibrant, supportive, and girl"
vllm:prefix_cache_hits 0

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]
Adding requests: 100%|██████████| 1/1 [00:00<00:00, 14.93it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it, est. speed input: 16923.13 toks/s, output: 67.16 toks/s]
(Worker_TP0 pid=695046) INFO 12-30 18:59:07 [multiproc_executor.py:709] Parent process exited, terminating worker
(Worker_TP3 pid=695049) INFO 12-30 18:59:07 [multiproc_executor.py:709] Parent process exited, terminating worker
(Worker_TP1 pid=695047) INFO 12-30 18:59:07 [multiproc_executor.py:709] Parent process exited, terminating worker
(Worker_TP2 pid=695048) INFO 12-30 18:59:07 [multiproc_executor.py:709] Parent process exited, terminating worker
APC: False 2 Generated text: " __________, and I am the Head of School at __________. I am thrilled to welcome you to our interview today. Our school is a K-12, all-girls', independent school that emphasizes community, joyful discovery, and lifelong learning. We believe that every girl has the potential to thrive when nurtured in an environment that values curiosity, collaboration, and courage. As we consider candidates for our 8th grade Math teaching role, we are looking for educators who not only have strong content knowledge and pedagogical skills, but who also embody our core values and are excited to contribute to a vibrant, supportive, and girl"
vllm:prefix_cache_hits 0
APC: False loop took --- 3.9298272132873535 seconds ---

Warm-up

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]
Adding requests: 100%|██████████| 1/1 [00:00<00:00, 14.51it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:12<00:00, 12.92s/it, est. speed input: 2497.21 toks/s, output: 9.91 toks/s]
APC: True 0 Generated text: " __________, and I am the Head of School at __________. I am thrilled to welcome you to our interview today. Our school is a K-12, all-girls', independent school that emphasizes community, joyful discovery, and lifelong learning. We believe that every girl has the potential to thrive when nurtured in an environment that values curiosity, collaboration, and courage. As we consider candidates for our 8th grade Math teaching role, we are looking for educators who not only have strong content knowledge and pedagogical skill, but who also bring a deep commitment to fostering a classroom culture where girls feel seen, heard, and"
vllm:prefix_cache_hits 0
Measuring

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]
Adding requests: 100%|██████████| 1/1 [00:00<00:00, 15.06it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  1.04it/s, est. speed input: 33534.00 toks/s, output: 133.07 toks/s]
APC: True 1 Generated text: " __________, and I am the Head of School at __________. I am thrilled to welcome you to our interview today. Our school is a K-12, all-girls', independent school that emphasizes community, joyful discovery, and lifelong learning. We believe that every girl has the potential to thrive when nurtured in an environment that values curiosity, collaboration, and courage. As we consider candidates for our 8th grade Math teaching role, we are looking for educators who not only have strong content knowledge and pedagogical skill, but who also bring a deep commitment to fostering a classroom culture where girls feel seen, heard, and"
vllm:prefix_cache_hits 32096

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]
Adding requests: 100%|██████████| 1/1 [00:00<00:00, 14.97it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  1.04it/s, est. speed input: 33591.79 toks/s, output: 133.30 toks/s]
(Worker_TP1 pid=698526) INFO 12-30 19:01:17 [multiproc_executor.py:709] Parent process exited, terminating worker
(Worker_TP2 pid=698527) INFO 12-30 19:01:17 [multiproc_executor.py:709] Parent process exited, terminating worker
(Worker_TP3 pid=698528) INFO 12-30 19:01:17 [multiproc_executor.py:709] Parent process exited, terminating worker
(Worker_TP0 pid=698525) INFO 12-30 19:01:17 [multiproc_executor.py:709] Parent process exited, terminating worker
APC: True 2 Generated text: " __________, and I am the Head of School at __________. I am thrilled to welcome you to our interview today. Our school is a K-12, all-girls', independent school that emphasizes community, joyful discovery, and lifelong learning. We believe that every girl has the potential to thrive when nurtured in an environment that values curiosity, collaboration, and courage. As we consider candidates for our 8th grade Math teaching role, we are looking for educators who not only have strong content knowledge and pedagogical skill, but who also bring a deep commitment to fostering a classroom culture where girls feel seen, heard, and"
vllm:prefix_cache_hits 64192
APC: True loop took --- 2.0605297088623047 seconds ---
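For a quick comparison, the two "loop took" figures from the log above can be turned into a speedup ratio. The snippet only restates numbers already printed in the log; each figure covers the two measured iterations of one run, so this is a rough indication and not a benchmark.

```python
# Wall-clock times copied from the "loop took" lines in the log above.
no_apc_s = 3.9298272132873535  # APC: False
apc_s = 2.0605297088623047     # APC: True

speedup = no_apc_s / apc_s
print(f"speedup with prefix caching: {speedup:.2f}x")
```

The vllm:prefix_cache_hits counter in the log is cumulative: it reads 32096 after the first measured request and 64192 after the second.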

peakcrosser7 and others added 30 commits November 24, 2025 14:29
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>

@tdoublep tdoublep left a comment

Member

Thanks for the great work! This feature enables prefix caching for a broader set of models.

Let's fix the issues that remain for MTP as a follow-up.

Comment thread vllm/v1/kv_cache_interface.py Outdated
Comment on lines +283 to +285
# We allocate 1 block for each request now, so max_memory_usage_bytes is
# the same as page_size_bytes.
# Need to update this when supporting prefix caching.

Member

This comment is redundant now I think

Contributor Author

Thanks for pointing that out. We can remove it later.

Comment on lines +295 to +296
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes

Member

I think that this code for "all" is wrong actually, but it is not an issue introduced by this PR. Will fix it as a follow-up.

Contributor Author

Agreed. It seems "all" mode performs allocation at the granularity of mamba_block_size, so we need to fix this later.
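For reference, the expression quoted in this thread is a ceiling division over blocks multiplied by the per-block page size. The sketch below defines its own `cdiv` so that it is self-contained; the function name `max_memory_usage_bytes` as a free function and all numeric values are illustrative assumptions. As the comments above note, this formula may not be right for "all" mode, which appears to allocate at mamba_block_size granularity.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(a // -b)


def max_memory_usage_bytes(max_model_len: int, block_size: int, page_size_bytes: int) -> int:
    # One page per block, for every block a max-length request can touch.
    return cdiv(max_model_len, block_size) * page_size_bytes


# Illustrative numbers only: 1000 tokens at block_size 64 need 16 blocks.
print(max_memory_usage_bytes(max_model_len=1000, block_size=64, page_size_bytes=1024))
```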

@heheda12345 heheda12345 merged commit 5206e5e into vllm-project:main Jan 23, 2026
64 of 65 checks passed
@MatthewBonanni MatthewBonanni mentioned this pull request Jan 23, 2026

MatthewBonanni commented Jan 23, 2026

Member

This PR appears to fail pre-commit; I have a fix: #32956

cwazai pushed a commit to cwazai/vllm that referenced this pull request Jan 25, 2026
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: 陈建华 <1647430658@qq.com>
lapy pushed a commit to lapy/vllm that referenced this pull request Jan 27, 2026
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
MengqingCao pushed a commit to vllm-project/vllm-ascend that referenced this pull request Mar 15, 2026
…de align` (#7103)

### What this PR does / why we need it?
To support prefix cache for Qwen3.5/Next in vLLM-Ascend, this PR mainly
follows the design in
[#30877](vllm-project/vllm#30877) and inherits
changes to functions which are overridden in vLLM-Ascend.

Note:
1. `--mamba-cache-mode align` combined with PD disaggregation is not yet
supported in vLLM v0.17.0 (see
https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/sched/scheduler.py#L295).
2. The current implementation of the hybrid KV cache might result in a very
large block_size when scheduling. For example, if we run Qwen3.5-35B-A3B
with `-tp 2`, the block_size is adjusted to 2048, which means that any
prefix shorter than 2048 tokens will never be cached. Although this behavior
is consistent with vLLM, it still needs improvement in the future.
3. `--mamba-cache-mode align` requires copying Mamba states during
forward steps. vLLM uses a Triton kernel to implement this. However, the
original version runs into some bugs on Ascend hardware, so we patch in a
new Triton kernel to avoid them.

### Does this PR introduce _any_ user-facing change?
To use the Mamba prefix cache, set `--enable-prefix-caching` and
`--mamba-cache-mode align`. Note that the Mamba state copy function (see
[do_mamba_copy_block](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py#L132))
does not provide a torch-native version, so it might cause trouble if
users can't use Triton.

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@4034c3d

---------

Signed-off-by: Angazenn <supperccell@163.com>
Labels

documentation, qwen, ready, v1

8 participants