Skip to content

[V1] [Hybrid] Move MiniMaxLinearAttention into layers/mamba#23831

Merged
vllm-bot merged 4 commits into
vllm-project:mainfrom
tdoublep:move-linear-attn
Aug 30, 2025
Merged

[V1] [Hybrid] Move MiniMaxLinearAttention into layers/mamba#23831
vllm-bot merged 4 commits into
vllm-project:mainfrom
tdoublep:move-linear-attn

Conversation

@tdoublep

@tdoublep tdoublep commented Aug 28, 2025

Copy link
Copy Markdown
Member

Purpose

This PR just moves the MinimaxLinearAttention layer into the layers/mamba directory, to be consistent with the other mamba-like layers (e.g., mamba1, mamba2, short_conv).

cc @heheda12345

Test Plan

Let's see if CI is OK (works locally).

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request successfully moves the MinimaxLinearAttention layer and its related components to the vllm/model_executor/layers/mamba/ directory, which improves code organization. However, this move introduces a critical circular dependency between the layers and models packages. Additionally, I've identified a pre-existing critical bug in the moved MiniMaxText01RMSNormTP class related to weight handling. Both issues should be addressed.

import torch
import torch.distributed

from vllm.model_executor.models.minimax_cache import MinimaxCacheParams

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This import introduces a circular dependency. The layers module should not depend on the models module. Specifically, vllm.model_executor.models.minimax_text_01 now imports MiniMaxText01LinearAttention from this file (layers/mamba/linear_attn.py), which in turn imports MinimaxCacheParams from models/minimax_cache.py. This creates a models -> layers -> models dependency cycle, which can lead to module resolution issues and makes the codebase harder to maintain.

To resolve this, MinimaxCacheParams should be moved to a more foundational location that both layers and models can depend on, for example vllm/model_executor/layers/mamba/mamba_utils.py.

weight = self.weight
if x.size(-1) != self.weight.size(0):
if self.weight.size(0) < x.size(-1):
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The calculation for repeat_count is incorrect. It should perform a ceiling division of x.size(-1) by self.weight.size(0) to determine how many times the weight tensor needs to be repeated to match the input tensor's dimension. The current logic will result in repeat_count being 1 if self.weight.size(0) < x.size(-1), which will lead to incorrect behavior or a runtime error due to shape mismatch during the multiplication.

Suggested change
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
repeat_count = (x.size(-1) + self.weight.size(0) - 1) // self.weight.size(0)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

@heheda12345 heheda12345 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to keep model-specific layers in model.py but open to further discussion. There is also a Plamo2MambaMixer in vllm/model_executor/models/plamo2.py (and maybe more, I didn't check the full list).

LucasWilkinson
LucasWilkinson previously approved these changes Aug 28, 2025

@LucasWilkinson LucasWilkinson left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! thanks for doing this; left one nit


if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
pass

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we get rid of this block if its not used anymore

@LucasWilkinson

Copy link
Copy Markdown
Collaborator

I prefer to keep model-specific layers in model.py but open to further discussion. There is also a Plamo2MambaMixer in vllm/model_executor/models/plamo2.py (and maybe more, I didn't check the full list).

Oh fair; good point! ya idk do we expect more MiniMax models using this to come?

@LucasWilkinson LucasWilkinson dismissed their stale review August 28, 2025 19:28

wait for @heheda12345 to comment

@heheda12345

Copy link
Copy Markdown
Collaborator

Oh fair; good point! ya idk do we expect more MiniMax models using this to come?

I don't think this module will be reuse unless minimax team has new model release.

@tdoublep

tdoublep commented Aug 29, 2025

Copy link
Copy Markdown
Member Author

I prefer to keep model-specific layers in model.py but open to further discussion.

OK - if that is the preferred approach then should we move short_conv into the Lfm2 modeling file then?

My thinking here was it would be useful to have all of these ops in one place so we can then start to look for commonalities (e.g., to have something like unified_attention but for mamba-like ops).

There is also a Plamo2MambaMixer in vllm/model_executor/models/plamo2.py (and maybe more, I didn't check the full list).

This model has a number of issues - it's on my to-do list to take a look at it.

@heheda12345

Copy link
Copy Markdown
Collaborator

My thinking here was it would be useful to have all of these ops in one place so we can then start to look for commonalities (e.g., to have something like unified_attention but for mamba-like ops).

Sounds good!

@heheda12345 heheda12345 enabled auto-merge (squash) August 29, 2025 19:55
@github-actions github-actions Bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 29, 2025
@DarkLight1337 DarkLight1337 disabled auto-merge August 29, 2025 23:54
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 29, 2025 23:54
@vllm-bot vllm-bot merged commit 4071c76 into vllm-project:main Aug 30, 2025
37 of 43 checks passed
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
…ject#23831)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
ABC12345anouys pushed a commit to ABC12345anouys/vllm that referenced this pull request Sep 25, 2025
…ject#23831)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
…ject#23831)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
…ject#23831)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
…ject#23831)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
0826joyce pushed a commit to 0826joyce/vllm-serving-optimization that referenced this pull request May 19, 2026
…ject#23831)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants