
Add Mistral Small 4 (Pixtral) support #20708

Merged

Kangyan-Zhou merged 30 commits into sgl-project:main from JustinTong0323:mistral4-support on Mar 18, 2026

Conversation

@JustinTong0323 (Collaborator) commented Mar 16, 2026

Summary

  • Add Mistral Small 4 (119B) model support, reusing the MistralLarge3/DeepSeekV3 backend with Pixtral vision encoder
  • Handle Mistral-native config format (params.json) for Mistral Small 4 and LeanStral model variants
  • Add Mistral reasoning parser ([THINK]/[/THINK] format) with reasoning_effort="high" gating; a minimal parsing sketch follows below
  • Fix Pixtral vision processor: proper spatial_merge_size handling, rope_parameters compatibility, and fallback PixtralProcessor wrapping when processor_config.json is missing
  • Load chat_template.jinja from model repo when tokenizer has no chat template
  • Work around the Mistral tokenizer marking [THINK]/[/THINK] as special tokens (upstream issue), which causes skip_special_tokens=True to strip the reasoning markers before the parser can see them

Co-authored-by: Alex Nails alexnails@radixark.ai
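
For context, the [THINK]/[/THINK] extraction amounts to splitting the generated text around these markers. A minimal sketch of the idea (a hypothetical simplified extractor; the real parser in python/sglang/srt/parser/reasoning_parser.py also handles streaming and partial markers):

import re

# Hypothetical simplified extractor, not the PR's implementation.
THINK_RE = re.compile(r"\[THINK\](.*?)\[/THINK\]", re.DOTALL)

def split_reasoning(text: str) -> tuple[str, str]:
    """Return (reasoning_content, content) for a finished generation."""
    match = THINK_RE.search(text)
    if match is None:
        return "", text
    reasoning = match.group(1).strip()
    content = (text[: match.start()] + text[match.end():]).strip()
    return reasoning, content

reasoning, answer = split_reasoning("[THINK]2 + 2 = 4[/THINK]The answer is 4.")
# reasoning == "2 + 2 = 4", answer == "The answer is 4."

This also shows why the tokenizer workaround above matters: if skip_special_tokens=True strips the markers first, there is nothing left for this split to find.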

Usage

# FP8
python -m sglang.launch_server \
  --model-path mistralai/Mistral-Small-4-119B-2603 \
  --tp 2 \
  --reasoning-parser mistral \
  --tool-call-parser mistral

# NVFP4
python -m sglang.launch_server \
  --model-path mistralai/Mistral-Small-4-119B-2603-NVFP4 \
  --tp 2 \
  --reasoning-parser mistral \
  --tool-call-parser mistral
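
Once a server is up, reasoning can be toggled per request through the top-level reasoning_effort field on the OpenAI-compatible endpoint. A minimal client sketch (port and model name assume the launch defaults; per this PR, "high" enables [THINK]/[/THINK] reasoning and "none" disables it):

import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # default sglang port
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "What is 17 * 23?"}],
        # Top-level request field; gates the [THINK]/[/THINK] blocks.
        "reasoning_effort": "high",
    },
)
message = resp.json()["choices"][0]["message"]
print(message.get("reasoning_content"))  # extracted by --reasoning-parser mistral
print(message["content"])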

Eval results (GSM8K)

Checkpoint                        GSM8K Accuracy
Mistral-Small-4-119B-2603 (FP8)   0.835
Mistral-Small-4-119B-2603-NVFP4   0.826

Test plan

  • Verify mistralai/Mistral-Small-4-119B-2603 loads and generates correct output with --tp 2
  • GSM8K eval on FP8 (0.835) and NVFP4 (0.826)
  • Verify --reasoning-parser mistral correctly extracts [THINK]/[/THINK] blocks into reasoning_content
  • Verify reasoning_effort="high" triggers thinking, "none" does not
  • Verify tool calls (single and multi) work with --tool-call-parser mistral
  • Verify streaming (chat, reasoning, tool calls)
  • Verify vision (image) inputs work through the Pixtral processor

JustinTong0323 and others added 15 commits February 28, 2026 13:57
…size

…processor

- Use patch_size * spatial_merge_size as the effective patch size in
  PixtralImageProcessor so images resize to multiples of 28 (not 14),
  matching PatchMerger requirements with spatial_merge_size=2
- Remove manual _resize and get_patch_grid_size methods, relying on
  the correctly configured HF image processor instead
- Add multi-image offset splitting for per-image MultimodalDataItem
- Remove unused torch import
- Add --model flag (default "default") to avoid hardcoded model name
- Add --reasoning-effort flag passed as top-level request field
- Support local image paths via base64 data URI encoding
- Pass reasoning_effort and model as explicit parameters instead of
  smuggling through sampling_params dict
…riable

The flashinfer trtllm_fp8_per_tensor_scale_moe already defaults activation_type
to Swiglu (3), which matches Mistral-Small-4's silu+gated config. Also replace
unused ncols with _ in pixtral processor.
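
To illustrate the spatial_merge_size commit above: with patch_size=14 and spatial_merge_size=2, the effective patch size is 28, so the image processor must resize to multiples of 28 for the PatchMerger to pool 2x2 blocks of raw patches. A rough sketch of the arithmetic (hypothetical helper, not the PR's code):

def merged_patch_grid(height: int, width: int,
                      patch_size: int = 14,
                      spatial_merge_size: int = 2) -> tuple[int, int]:
    # Images are resized to multiples of the effective patch size
    # (14 * 2 = 28); each merged cell covers a 2x2 block of raw patches.
    effective = patch_size * spatial_merge_size
    assert height % effective == 0 and width % effective == 0
    return height // effective, width // effective

print(merged_patch_grid(896, 1120))  # (32, 40) merged tokens per dimension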

JustinTong0323 and others added 2 commits March 16, 2026 17:20

tokenizer = get_tokenizer_from_processor(processor)

if tokenizer.chat_template is None:
A collaborator commented:
do we keep this? (I actually think this is a useful fallback but it should be improved at a later point)
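
A sketch of what this fallback can look like, assuming the template ships as chat_template.jinja in the model repo (hypothetical helper, not necessarily the PR's exact code):

from huggingface_hub import hf_hub_download

def apply_chat_template_fallback(tokenizer, model_path: str) -> None:
    # Only kicks in when the tokenizer itself carries no template.
    if tokenizer.chat_template is not None:
        return
    template_file = hf_hub_download(model_path, "chat_template.jinja")
    with open(template_file) as f:
        tokenizer.chat_template = f.read()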

Comment thread: python/sglang/srt/parser/reasoning_parser.py (outdated)
Comment thread: benchmark/mmmu/eval_utils.py
The EAGLE draft model for Mistral Small 4 (mistralai/Mistral-Small-4-119B-2603-eagle)
uses dense MLA layers without MoE, unlike the Mistral Large 3 EAGLE which has MoE.
This caused three issues:

1. `adapt_config_dict` in mistral_utils.py did not handle dense EAGLE models
   (moe=null in params.json), falling through to an unsupported architecture.
   Fix: add a branch for `is_eagle and not is_moe` that sets model_type=deepseek_v3
   with all-dense MoE overrides (first_k_dense_replace=num_layers).

2. `_remap_mistral_yarn_args` did not include rope_theta in rope_scaling,
   causing transformers yarn validation to fail.
   Fix: copy rope_theta into the rope_scaling dict.

3. `MistralLarge3ForCausalLMEagle.__init__` set `self.model_cls` but
   `DeepseekV2ForCausalLM.__init__` hardcodes `self.model = DeepseekV2Model`,
   so the EAGLE fc layer was never created. The draft model ran without fusing
   token embeddings with target hidden states, producing garbage draft tokens
   (accept rate 0.25).
   Fix: call super().__init__() then replace self.model with
   MistralLarge3EagleModel which has the fc layer. Accept rate: 0.25 -> 0.83.
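
A condensed sketch of the fix for issue 3 (class names are taken from the commit message; the constructor signature is an assumption):

class MistralLarge3ForCausalLMEagle(DeepseekV2ForCausalLM):
    def __init__(self, config, *args, **kwargs):
        # Let the parent build everything first (it hardcodes
        # self.model = DeepseekV2Model), then swap in the EAGLE model
        # that has the fc layer fusing token embeddings with the
        # target model's hidden states.
        super().__init__(config, *args, **kwargs)
        self.model = MistralLarge3EagleModel(config)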
@JustinTong0323 (Collaborator Author)

/rerun-failed-ci

5 similar comments
@JustinTong0323 (Collaborator Author)

/rerun-failed-ci

@alexnails (Collaborator)

/rerun-failed-ci

@JustinTong0323 (Collaborator Author)

/rerun-failed-ci

@alexnails (Collaborator)

/rerun-failed-ci

@alexnails (Collaborator)

/rerun-failed-ci

@dbari (Contributor) commented Mar 17, 2026

Here is a diff to improve the gsm8k score:

Tasks  Version  Filter            n-shot  Metric       Value   Stderr
gsm8k  3        flexible-extract  5       exact_match  0.9105  ± 0.0079
                strict-match      5       exact_match  0.9083  ± 0.0080
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 8f0617142..4750a6532 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1198,7 +1198,7 @@ class DeepseekV2AttentionMLA(
                 device=get_global_server_args().device,
             )
 
-            if rope_scaling:
+            if rope_scaling and rope_scaling.get("apply_yarn_scaling", True):
                 mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
                 scaling_factor = rope_scaling["factor"]
                 mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py
index 4955c0575..dc9e08d94 100644
--- a/python/sglang/srt/utils/mistral_utils.py
+++ b/python/sglang/srt/utils/mistral_utils.py
@@ -134,11 +134,11 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
         "original_max_position_embeddings": "original_max_position_embeddings",
         "beta": "beta_fast",
         "alpha": "beta_slow",
-        "apply_scale": None,
+        "apply_scale": "apply_yarn_scaling",
     }
     yarn_config = config.get("yarn") or {}
     config["rope_scaling"] = {
-        "rope_type": "yarn",
+        "rope_type": "deepseek_yarn",
         "mscale_all_dim": 1,
     }
     # Include rope_theta in rope_scaling if present at the top level,

@JustinTong0323 (Collaborator Author)

/rerun-failed-ci

Mistral Small 4's params.json sets "apply_scale": false in the yarn
config, meaning the mscale factor should NOT be applied to attention
logits scaling. Previously this field was discarded, causing an
incorrect 2.2x mscale to be applied unconditionally.

Changes:
- Map "apply_scale" to "apply_yarn_scaling" in rope_scaling dict
  instead of dropping it
- Use "deepseek_yarn" rope_type to avoid transformers yarn validation
  issues
- Gate mscale application in DeepseekV2AttentionMLA on apply_yarn_scaling

gsm8k 5-shot exact_match: 0.7976 -> 0.8901 (+9.3%)
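
For scale, assuming the usual DeepSeek-style yarn mscale formula (a rough illustration, not the exact sglang code):

import math

def yarn_get_mscale(scale: float, mscale: float = 1.0) -> float:
    # mscale grows with the log of the context-extension factor.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# Attention logits are scaled by mscale**2; the ~2.2x figure above
# would correspond to a large extension factor (~128):
print(yarn_get_mscale(128.0) ** 2)  # ~2.21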
@JustinTong0323 (Collaborator Author)

@dbari I've just pushed your rope fix. Thanks a lot for that, and apologies for earlier deciding against including it.

@alexnails (Collaborator)

/rerun-failed-ci

1 similar comment
@JustinTong0323 (Collaborator Author)

/rerun-failed-ci

@Kangyan-Zhou Kangyan-Zhou merged commit 6b8a654 into sgl-project:main Mar 18, 2026
29 of 46 checks passed
Qiaolin-Yu added a commit that referenced this pull request Mar 18, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
JustinTong0323 added a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026