Skip to content

[Quantization] Rework quantization_config to use QuantKey and allow for activation override#41566

Merged
mgoin merged 7 commits into
vllm-project:mainfrom
neuralmagic:quantization-config-rework
May 13, 2026
Merged

[Quantization] Rework quantization_config to use QuantKey and allow for activation override#41566
mgoin merged 7 commits into
vllm-project:mainfrom
neuralmagic:quantization-config-rework

Conversation

@mgoin

@mgoin mgoin commented May 3, 2026

Copy link
Copy Markdown
Member

Purpose

Replaces the OnlineQuantScheme enum and global_scheme/linear_scheme_override/moe_scheme_override fields with a QuantSpec(weight, activation) per layer kind, addressable by name from a QUANT_KEY_NAMES table. The CLI shorthands (fp8_per_tensor, fp8_per_block, mxfp8, int8_per_channel_weight_only) keep working and now desugar into the new structure leveraging QuantKey pairs directly.

Routes --quantization_config.moe.activation through the MXFP4 MoE oracle so users can opt into MXFP8 activations on already-quantized checkpoints (gpt-oss) without setting env vars. When combined with --moe-backend flashinfer_cutlass or flashinfer_trtllm, the oracle upgrades the BF16 backend variant to its MXFP8 counterpart.

Also migrates the gpt-oss eval configs and test_blackwell_moe MXFP8 tests off VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8* env vars and onto the new flag to test this

Adds tests/quantization/test_quantization_config_args.py for the parsing UX

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify

mergify Bot commented May 3, 2026

Copy link
Copy Markdown
Contributor

Documentation preview: https://vllm--41566.org.readthedocs.build/en/41566/

@mergify mergify Bot added documentation Improvements or additions to documentation frontend nvidia labels May 3, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the online quantization configuration system, replacing the OnlineQuantScheme enum with a more flexible QuantizationConfigArgs structure. It renames linear_scheme_override and moe_scheme_override to linear and moe, respectively, and introduces QuantSpec to allow independent control over weight and activation quantization. Additionally, the PR migrates activation selection logic (e.g., for MXFP8) from environment variables to the explicit quantization_config and enables user-defined overrides to coexist with checkpoint-based quantization. I have no feedback to provide.

@mgoin mgoin changed the title Rework quantization_config to per-layer-kind QuantSpec Rework quantization_config to use QuantKey and allow for activation override May 3, 2026
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed quantization labels May 3, 2026
Comment thread tests/compile/fusions_e2e/conftest.py Outdated

@BowenBao BowenBao left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mgoin ! I added some suggestions

Comment thread vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
Comment on lines +361 to +377
# Upgrade BF16-act variants to MXFP8-act variants when the user
# requests MXFP8 activations via quantization_config.
if activation_qk == kMxfp8Dynamic:
requested_backend = {
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16: (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8
),
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16: (
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8
),
}.get(requested_backend, requested_backend)
backend_act = _backend_activation_key(requested_backend)
if activation_qk is not None and activation_qk != backend_act:
raise ValueError(
f"moe_backend={runner_backend!r} runs with activation="
f"{backend_act}, but activation={activation_qk} was requested"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we leave it to is_supported_config to find the matching config? as long as user overriden activation_key is passed, it should find the right backend.

            supported, reason = k_cls.is_supported_config(
                k_cls, config, kMxfp4Static, activation_key, activation_format
            )

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly addressed. The runner_backend path now drops the manual upgrade dict and just iterates the candidate list from map_mxfp4_backend, leaving variant selection to is_supported_config. The one wrinkle: _supports_quant_scheme on TrtLlmMxfp4ExpertsMonolithic accepts both (kMxfp4Static, None) and (kMxfp4Static, kMxfp8Dynamic) for the same kernel class, so when both FLASHINFER_TRTLLM_MXFP4_BF16 and FLASHINFER_TRTLLM_MXFP4_MXFP8 are candidates, is_supported_config accepts both and the first listed one wins. To match the user's requested activation deterministically, I added a thin _activation_matches(backend) = _backend_activation_key(backend) == requested_activation_key filter at the oracle level (only when an override is set). Happy to push that filter down into the kernels in a follow-up if you'd prefer.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accepts both (kMxfp4Static, None) and (kMxfp4Static, kMxfp8Dynamic) for the same kernel class

I ran into a similar situation in #41436 for aiter backend, where w4a4 and w4a16 are both supported. There to resolve it I create a separate shim for w4a4. Either solution works for me, happy to hear your thoughts!

Comment thread vllm/model_executor/layers/fused_moe/oracle/mxfp4.py Outdated
Comment thread tests/compile/fusions_e2e/conftest.py Outdated
Comment thread vllm/model_executor/layers/fused_moe/oracle/mxfp4.py Outdated
Replaces the OnlineQuantScheme enum and global_scheme/linear_scheme_override/moe_scheme_override fields with a QuantSpec(weight, activation) per layer kind, addressable by name from a hand-curated QUANT_KEY_NAMES table.

Routes quantization_config.moe.activation through the MXFP4 MoE oracle so users can opt into MXFP8 activations on already-quantized checkpoints (gpt-oss) without setting env vars. The oracle combines the override with any caller-supplied activation_key from vllm-project#39136 and raises on conflict.

map_mxfp4_backend now returns a list of candidate backends per vendor name; both the runner_backend path and the priority list filter by activation_key match.

Eval configs and test_blackwell_moe migrate to --quantization-config.moe.activation mxfp8. weight_utils stops rejecting quantization_config alongside a checkpoint quant config. --quantization-config is registered as a CLI arg.

Adds parsing UX unit tests and expands docs/features/quantization/online.md.

Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin mgoin force-pushed the quantization-config-rework branch from bdb8a4c to 5be803c Compare May 8, 2026 15:23
mgoin added 4 commits May 8, 2026 11:37
is_supported_config alone cannot distinguish backend variants that share an experts class (TRTLLM_BF16 and TRTLLM_MXFP8 both use TrtLlmMxfp4ExpertsMonolithic), so iterating the candidate list with the override propagated through act_key would silently pick the BF16 backend running MXFP8 activations.

When no activation is requested and the candidate list contains a BF16 alternative, drop the non-BF16 entries; otherwise keep them (so flashinfer_trtllm_afp8 -> [TRTLLM_MXFP8] still picks MXFP8). When an activation is requested, keep only candidates whose intrinsic activation matches.

Signed-off-by: mgoin <mgoin64@gmail.com>
Per Luka: gpt-oss model_kwargs setup belongs in tests/compile/fusions_e2e/models.py keyed on is_blackwell(), not as a special case in conftest.

Trim verbose comments and docstrings added in this PR; let docs/features/quantization/online.md be the canonical reference instead of repeating format names inline where they will go stale.

Signed-off-by: mgoin <mgoin64@gmail.com>
Drops the duplicated _make_log_backend / _make_log_unsupported / _return_or_raise nested closures from select_mxfp4_moe_backend and select_deepseek_v4_mxfp4_moe_backend. _return_or_raise gains a scope parameter (defaulting to logger.info_once default of "local") so the deepseek path keeps its existing scope="local" logging.

_filter_by_activation also moves to module scope and takes requested_activation_key as a parameter instead of capturing it.

Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin mgoin changed the title Rework quantization_config to use QuantKey and allow for activation override [Quantization] Rework quantization_config to use QuantKey and allow for activation override May 12, 2026
@mergify

mergify Bot commented May 12, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 12, 2026
…ework

# Conflicts:
#	vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
"mxfp8": kMxfp8Dynamic,
"mxfp4": kMxfp4Dynamic,
"int8_per_channel_static": kInt8StaticChannelSym,
}

@juhi10071998 juhi10071998 May 12, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @mgoin - for our usecase we wanted to ingest the nvfp4 ckpt (weights and activation in nvfp4- input_scales in ckpt) as the W4A16_NVFP4 and use Marlin backend (i.e leave the activation in bf16 and ignore input_scale).

Is it possible to add a key for bfloat16/ float16 in this so we can leverage this to set the activation key in fusedMoE ModelOpt class

CT ckpts config have the input_dtype which sets the use_a16 but prevents using pure nvfp4 ckpt as-is w/o changing the config manually.

We took a parallel path in our open draft #42428 and introduced the override_activation_dtype flag which gets passed to ModelOptNVFP4Config from BaseConfig and eventually to the [ModelOptNvFp4FusedMoE class] for changing activation_key here

Appreciate your insights if we should have the override flag and your QuantSpec surface co-exist or just add the bfloat16/ float16 keys here.

Should we call the activation_override as modelopt_override_activation_dtype instead of override_activation_dtype?

@pavanimajety

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to use the marlin backend with an existing NVFP4 W4A4 model, you can just use --linear-backend marlin or --moe-backend marlin. Is that sufficient?
We can consider allowing overriding the checkpoint activation type with the path added in this PR too

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks Michael, we may want to have a finer control than the --moe-backed marlin in my understanding.

When you say added in this PR- does it refer to the current PR you have, or introduce the new override_activation_dtype standalone flag from cli?

…ework

# Conflicts:
#	vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
#	vllm/model_executor/layers/quantization/online/base.py

@BowenBao BowenBao left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA May 13, 2026
@mgoin mgoin merged commit 8efd508 into vllm-project:main May 13, 2026
92 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 13, 2026
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
rishitdholakia13 pushed a commit to rishitdholakia13/vllm that referenced this pull request May 19, 2026
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
h1t35h pushed a commit to h1t35h/vllm that referenced this pull request May 21, 2026
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…or activation override (vllm-project#41566)

Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation frontend nvidia quantization ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

5 participants