Skip to content

[NVFP4] NVFP4 MOE emulation fallback for H100/MI300/MI350, standardize TritonExperts usage for OCP MX emulation#35737

Merged
vllm-bot merged 65 commits into
vllm-project:mainfrom
fxmarty-amd:upstream-nvfp4-simulation-support-moe
Apr 22, 2026
Merged

[NVFP4] NVFP4 MOE emulation fallback for H100/MI300/MI350, standardize TritonExperts usage for OCP MX emulation#35737
vllm-bot merged 65 commits into
vllm-project:mainfrom
fxmarty-amd:upstream-nvfp4-simulation-support-moe

Conversation

@fxmarty-amd

@fxmarty-amd fxmarty-amd commented Mar 2, 2026

Copy link
Copy Markdown
Contributor

Purpose

This PR enables running NVFP4 MOE models on Hopper and AMD Instinct MI300, MI350.

This is useful for researchers, anybody trying out microscaling formats, and people who would like to run e.g. https://huggingface.co/nvidia/Qwen3-30B-A3B-NVFP4 or https://huggingface.co/RedHatAI/Qwen3-30B-A3B-NVFP4 on non-Blackwell devices.

This PR also refactors quark_moe.py to stop using the functional fused_experts function for OCP MX quantization emulation, and instead purely rely on TritonExperts.

Test Plan

See

  • CUDA_VISIBLE_DEVICES="0,1" pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-mi3xx.txt running on AMD Instinct MI325X (MXFP4 & NVFP4 emulation fallback) and passing.

✅ GSM8K test passed for amd/Qwen3.5-35B-A3B-MXFP4
✅ GSM8K test passed for nvidia/Qwen3-30B-A3B-FP4

  • pytest tests/models/quantization/test_nvfp4.py -s -vvvvv -k "test_nvfp4_moe" running on MI355X.
  • CUDA_VISIBLE_DEVICES="6,7" pytest tests/quantization/test_quark.py -s -vvvvv -k "test_ocp_mx_wikitext_correctness" (testing MXFP4/MXFP6 Qwen MOE emulation)

And see as of 1e1d139

export PRETRAINED_PATH="Qwen/Qwen3-30B-A3B"

CUDA_VISIBLE_DEVICES=4 nohup lm_eval \
  --model vllm \
  --model_args '{"pretrained":"'"${PRETRAINED_PATH}"'","dtype":"auto","tensor_parallel_size":1,"enable_thinking": false,"chat_template_args":{"enable_thinking":false},"gpu_memory_utilization":0.8}' \
  --device "cuda" \
  --tasks wikitext,piqa \
  --batch_size auto &> lm_eval_bf16.log &

giving:

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|------:|---|------|
|piqa    |      1|none  |     0|acc            |↑  | 0.7905|±  |0.0095|
|        |       |none  |     0|acc_norm       |↑  | 0.8052|±  |0.0092|
|wikitext|      2|none  |     0|bits_per_byte  |↓  | 0.6444|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  | 1.5631|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |10.8968|±  |   N/A|
export PRETRAINED_PATH="nvidia/Qwen3-30B-A3B-NVFP4"

CUDA_VISIBLE_DEVICES=4 nohup lm_eval \
  --model vllm \
  --model_args '{"pretrained":"'"${PRETRAINED_PATH}"'","dtype":"auto","tensor_parallel_size":1,"enable_thinking": false,"chat_template_args":{"enable_thinking":false},"gpu_memory_utilization":0.8}' \
  --device "cuda" \
  --tasks wikitext,piqa \
  --batch_size auto &> lm_eval_modelopt.log &
(EngineCore pid=910813) INFO 04-09 13:00:57 [nvfp4_utils.py:150] Using NvFp4LinearBackend.EMULATION for NVFP4 GEMM
(EngineCore pid=910813) INFO 04-09 13:00:58 [nvfp4.py:285] Using 'EMULATION' NvFp4 MoE backend out of potential backends: ['FLASHINFER_TRTLLM', 'FLASHINFER_CUTEDSL', 'FLASHINFER_CUTEDSL_BATCHED', 'FLASHINFER_CUTLASS', 'VLLM_CUTLASS', 'MARLIN', 'EMULATION'].
...
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|------:|---|------|
|piqa    |      1|none  |     0|acc            |↑  | 0.7818|±  |0.0096|
|        |       |none  |     0|acc_norm       |↑  | 0.7933|±  |0.0094|
|wikitext|      2|none  |     0|bits_per_byte  |↓  | 0.6531|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  | 1.5726|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |11.2551|±  |   N/A|

And

export PRETRAINED_PATH="RedHatAI/Qwen3-30B-A3B-NVFP4"

CUDA_VISIBLE_DEVICES=7 nohup lm_eval \
  --model vllm \
  --model_args '{"pretrained":"'"${PRETRAINED_PATH}"'","dtype":"auto","tensor_parallel_size":1,"enable_thinking": false,"chat_template_args":{"enable_thinking":false},"gpu_memory_utilization":0.8}' \
  --device "cuda" \
  --tasks wikitext,piqa \
  --batch_size auto &> lm_eval_ct.log &

gives

(EngineCore pid=915159) INFO 04-09 13:01:13 [nvfp4_utils.py:150] Using NvFp4LinearBackend.EMULATION for NVFP4 GEMM
(EngineCore pid=915159) INFO 04-09 13:01:14 [nvfp4.py:285] Using 'EMULATION' NvFp4 MoE backend out of potential backends: ['FLASHINFER_TRTLLM', 'FLASHINFER_CUTEDSL', 'FLASHINFER_CUTEDSL_BATCHED', 'FLASHINFER_CUTLASS', 'VLLM_CUTLASS', 'MARLIN', 'EMULATION'].
...
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|------:|---|------|
|piqa    |      1|none  |     0|acc            |↑  | 0.7878|±  |0.0095|
|        |       |none  |     0|acc_norm       |↑  | 0.7889|±  |0.0095|
|wikitext|      2|none  |     0|bits_per_byte  |↓  | 0.6557|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  | 1.5753|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |11.3622|±  |   N/A|

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
@mergify mergify Bot added nvidia rocm Related to AMD ROCm labels Mar 2, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Mar 2, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for NVFP4 MOE models on a wider range of hardware, including AMD Instinct, Nvidia Ampere, and Hopper, through an emulation backend. The changes are extensive, touching quantization layers, model execution, and tests to accommodate this new emulation path. The implementation appears solid and well-integrated. I've found one critical issue that needs to be addressed.

Comment thread vllm/model_executor/layers/fused_moe/quantization_emulation_moe.py Outdated
Comment thread vllm/model_executor/layers/fused_moe/quantization_emulation_moe.py Outdated
Comment thread vllm/model_executor/layers/fused_moe/config.py Outdated
Comment thread vllm/model_executor/layers/fused_moe/utils.py Outdated
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
@fxmarty-amd fxmarty-amd requested a review from kylesayrs April 16, 2026 11:42
Comment thread tests/models/quantization/test_nvfp4.py Outdated
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed quantization labels Apr 17, 2026
@mergify

mergify Bot commented Apr 17, 2026

Copy link
Copy Markdown
Contributor

Hi @fxmarty-amd, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Apr 17, 2026
@fxmarty-amd

Copy link
Copy Markdown
Contributor Author

Failing tests are:

https://buildkite.com/vllm/ci/builds/62078/steps/canvas?jid=019da9d5-f2c5-41df-a6f0-7d60c43bf5fb&tab=output#019da9d5-f2c5-41df-a6f0-7d60c43bf5fb

[2026-04-20T08:15:44Z] FAILED compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-True-dtype0] - AssertionError: Tensor-likes are not close!
[2026-04-20T08:15:44Z]
[2026-04-20T08:15:44Z] Mismatched elements: 64 / 2048 (3.1%)
[2026-04-20T08:15:44Z] Greatest absolute difference: 0.7900390625 at index (15, 6) (up to 1e-05 allowed)
[2026-04-20T08:15:44Z] Greatest relative difference: 2.34375 at index (15, 0) (up to 0.001 allowed)
[2026-04-20T08:15:44Z]
[2026-04-20T08:15:44Z] The failure occurred for item [0]
[2026-04-20T08:15:44Z] ===== 1 failed, 361 passed, 155 skipped, 51 warnings in 1331.73s (0:22:11) =====

https://buildkite.com/vllm/ci/builds/62078/steps/canvas?jid=019da9d5-f276-45e6-99b5-8917e3d888b5&tab=output#019da9d5-f276-45e6-99b5-8917e3d888b5


�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/oracle/fp8.py", line 568, in make_fp8_moe_kernel
�_bk;t=1776672180490     prepare_finalize = maybe_make_prepare_finalize(
�_bk;t=1776672180490                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/all2all_utils.py", line 139, in maybe_make_prepare_finalize
�_bk;t=1776672180490     handle = all2all_manager.get_handle(all_to_all_args)
�_bk;t=1776672180490              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/all2all.py", line 354, in get_handle
�_bk;t=1776672180490     handle: deep_ep.Buffer = self.handle_cache.get_or_create(
�_bk;t=1776672180490                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 23, in get_or_create
�_bk;t=1776672180490     instance = func(**kwargs)
�_bk;t=1776672180490                ^^^^^^^^^^^^^^
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/deep_ep/buffer.py", line 95, in __init__
�_bk;t=1776672180490     device_ids = all_gather_object(local_device_id)
�_bk;t=1776672180490                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/deep_ep/buffer.py", line 74, in all_gather_object
�_bk;t=1776672180490     dist.all_gather_object(object_list, obj, group)
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
�_bk;t=1776672180490     return func(*args, **kwargs)
�_bk;t=1776672180490            ^^^^^^^^^^^^^^^^^^^^^
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 3312, in all_gather_object
�_bk;t=1776672180490     all_gather(object_size_list, local_size, group=group)
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
�_bk;t=1776672180490     return func(*args, **kwargs)
�_bk;t=1776672180490            ^^^^^^^^^^^^^^^^^^^^^
�_bk;t=1776672180490   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4086, in all_gather
�_bk;t=1776672180490     work.wait()
�_bk;t=1776672180490 RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:547] Connection closed by peer [10.48.60.108]:8330
�_bk;t=1776672180490 subtest: [1-128-256-8-6-bfloat16-modelopt_fp8-True-True-False-False-False-deepep_high_throughput-2-2-1]
�_bk;t=1776672180490 [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:547] Connection closed by peer [10.48.60.108]:8330
�_bk;t=1776672180490 FAILED <class 'RuntimeError'>

=================================================================== short test summary info ====================================================================
FAILED kernels/moe/test_moe_layer.py::test_moe_layer[False-deepep_high_throughput-2-1-True] - torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGABRT
============================================= 1 failed, 144 passed, 279 skipped, 18 warnings in 372.90s (0:06:12) ==============================================

which look unrelated?

@vllm-bot vllm-bot merged commit d622e27 into vllm-project:main Apr 22, 2026
71 of 73 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Apr 22, 2026
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Apr 22, 2026
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Apr 23, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
yzong-rh pushed a commit to yzong-rh/vllm that referenced this pull request Apr 23, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Yifan <yzong@redhat.com>
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
Lafunamor pushed a commit to Lafunamor/vllm that referenced this pull request May 1, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Adrian <info@zzit.ch>
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request May 7, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
brian-dellabetta pushed a commit to neuralmagic/vllm that referenced this pull request May 29, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…e `TritonExperts` usage for OCP MX emulation (vllm-project#35737)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

nvidia quantization ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.

5 participants