support Hy3 preview #23533

Merged
Qiaolin-Yu merged 36 commits into sgl-project:main from JustinTong0323:support-hy3-preview on Apr 24, 2026
Conversation

@JustinTong0323 (Collaborator) commented on Apr 23, 2026

Summary

Add support for Tencent Hunyuan V3 (Hy3-preview) models in sglang.

Components

  • Model: python/sglang/srt/models/hunyuan_v3.py (+ MTP variant hunyuan_v3_nextn.py)
  • Tool-call parser: python/sglang/srt/function_call/hunyuan_detector.py — streaming HYV3 tool parser with <tool_calls>/<tool_call>/<tool_sep>/<arg_key>/<arg_value> format, schema-aware type coercion, and char-by-char streaming for string args
  • Reasoning parser: HunyuanDetector in python/sglang/srt/parser/reasoning_parser.py covering <think>...</think>; dispatch in serving_chat.py treats reasoning_effort values high / low as reasoning-enabled and none / no_think / unset as reasoning-disabled
  • Docs: docs/basic_usage/hy3_preview.md — launch, reasoning_effort semantics, cURL examples
  • Tuned MoE configs: H20 / H20-3e (140 GB) bf16 configs
  • Tests: test/registered/unit/function_call/test_hunyuan_detector.py (50 tests) covering streaming, typed-arg coercion, partial-end-tag hold-back, char-by-char value streaming, and reference parity cases
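As a rough illustration of the tag vocabulary listed above, a minimal non-streaming parse might look like the sketch below. The exact template nesting is an assumption here; the real detector in hunyuan_detector.py streams incrementally, holds back partial end tags, and applies schema-aware coercion.

```python
import re

# Hypothetical non-streaming parse of the Hunyuan tool-call tag format.
# Illustrative only; not the PR's actual parser.
def parse_tool_calls(text: str) -> list[dict]:
    calls = []
    for block in re.findall(r"<tool_call>(.*?)</tool_call>", text, re.S):
        name, _, rest = block.partition("<tool_sep>")
        keys = re.findall(r"<arg_key>(.*?)</arg_key>", rest, re.S)
        vals = re.findall(r"<arg_value>(.*?)</arg_value>", rest, re.S)
        calls.append({"name": name.strip(), "arguments": dict(zip(keys, vals))})
    return calls

text = ("<tool_calls><tool_call>get_weather<tool_sep>"
        "<arg_key>city</arg_key><arg_value>Paris</arg_value>"
        "</tool_call></tool_calls>")
parse_tool_calls(text)
# → [{'name': 'get_weather', 'arguments': {'city': 'Paris'}}]
```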

Usage

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
messages = [{"role": "user", "content": "Hello!"}]

# Non-thinking mode (default)
resp = client.chat.completions.create(model="hy3-preview", messages=messages)

# Thinking mode: set reasoning_effort in chat_template_kwargs
resp = client.chat.completions.create(
    model="hy3-preview",
    messages=messages,
    extra_body={
        "chat_template_kwargs": {
            "reasoning_effort": "high",   # "high" / "low" / "no_think"
            "interleaved_thinking": True,
        }
    },
)
print(resp.choices[0].message.reasoning_content)
print(resp.choices[0].message.content)
```

Test plan

  • pytest test/registered/unit/function_call/test_hunyuan_detector.py — 50 / 50 passing
  • End-to-end verified on Hy3-preview (TP=4, --tool-call-parser hunyuan --reasoning-parser hunyuan) with MiniMax-Provider-Verifier: 100 % Query-Success, 98.0 % ToolCalls-Match, 96.4 % Schema-Accuracy, 100 % Response-Success, 100 % Language-Following
  • pre-commit run --all-files — clean on touched files

mpjlu and others added 30 commits March 25, 2026 11:40
Add support for Hunyuan model's custom tool call format and reasoning
parsing, enabling --tool-call-parser hunyuan and --reasoning-parser hunyuan.

Tool call detector (hunyuan_detector.py):
- Parses Hunyuan's XML-like format: <tool_calls>/<tool_call>/<tool_sep>/
  <arg_key>/<arg_value> tags
- Supports both streaming and non-streaming parsing
- Type-aware argument deserialization (string, int, float, bool)

Reasoning parser (HunyuanDetector in reasoning_parser.py):
- Uses <think>/</think> tags with tool_start_token=<tool_calls>
- Handles tool call interruption during reasoning

reasoning_effort integration (serving_chat.py):
- Hunyuan defaults to no-think mode; reasoning_effort=low/high enables
  thinking and separates reasoning_content from content

Verified on tencent/HY3.0-FP8-Testing with MiniMax-Provider-Verifier:
- Query-Success-Rate: 100%
- ToolCalls-Schema-Accuracy: 95.12%
- Response-Success-Rate: 100%
Add Hunyuan tool call parser and reasoning parser
- config.expert_hidden_dim -> config.moe_intermediate_size
- config.qk_norm -> config.use_qk_norm (QK-norm always enabled)
- Fix shared_mlp prefix to match checkpoint weight names
This reverts commit a753002, reversing
changes made to 8819588.
Rewrites HunyuanDetector.parse_streaming_increment to emit true
incremental deltas instead of atomic per-block ToolCallItems.

- Phase 1 emits the tool name as soon as <tool_sep> is seen.
- Phase 2 streams argument JSON incrementally. Pure-string args may be
  streamed char-by-char with JSON escaping; non-string values are
  coerced via a schema-aware fallthrough (bool > int > number >
  json.loads for array/object > string) that supports anyOf/oneOf and
  prefix-based type normalization (int*/uint*/num*/float*/...).
- The trailing "}" is withheld until </tool_call> arrives, and any
  suffix of partial_value that could form a </arg_value> end-tag is
  held back so the `<` character does not leak into the streamed JSON
  string.
- Tests in test_hunyuan_detector.py now accumulate deltas by
  tool_index (qwen3_coder-style) before asserting, and add coverage
  for typed arg coercion, partial-end-tag hold-back, two-phase
  ordering, and reference HYV3 output patterns.
feat(function_call): streaming HYV3 tool parser
fix Hy3_preview model read rope_theta from config bug
@Qiaolin-Yu Qiaolin-Yu self-assigned this Apr 23, 2026
@JustinTong0323 (Collaborator, Author): /tag-and-rerun-ci

@JustinTong0323 (Collaborator, Author): /rerun-failed-ci

@JustinTong0323 (Collaborator, Author): /tag-and-rerun-ci

@JustinTong0323 (Collaborator, Author): /rerun-failed-ci

@andyluo7 (Contributor) commented:

Hi @stevenkuang-tencent — friendly heads-up from AMD validation:

I tested this PR end-to-end on both MI300X (gfx942) and MI355X (gfx950) at TP=8, and hit a CUDA-graph crash that blocks the recommended config from the model card on AMD.

Eager mode (--disable-cuda-graph) works ✅, confirming the model code itself is correct on AMD.
CUDA-graph mode crashes ❌ with HSA_STATUS_ERROR_EXCEPTION: 0x1016 on the first decode batch; the root cause is hipErrorStreamCaptureInvalidated.

I bisected this to AITER's custom all-reduce (AiterCustomAllreduce, default on HIP). Filed a clean writeup with reproducer + workaround at #23580.

Workaround (single env var, preserves all other AITER fast paths):

SGLANG_USE_AITER_AR=0 python3 -m sglang.launch_server ...

With that, full validation matrix on both AMD GPUs:

| Hardware    | Workload              | tok/s |
|-------------|-----------------------|-------|
| MI300X TP=8 | single long (512 tok) | 34.9  |
| MI300X TP=8 | c=8, 32 reqs          | 250.7 |
| MI355X TP=8 | single long (512 tok) | 39.6  |
| MI355X TP=8 | c=8, 32 reqs          | 295.7 |

(For reference, vLLM PR vllm-project/vllm#40681 — Hy3 support on vLLM — works on AMD without this issue, achieving 58 tok/s single-stream and 314 tok/s @ c=8 on MI300X. The vLLM piecewise CUDA-graph mode tolerates the AITER all-reduce. SGLang's monolithic capture doesn't.)

The issue is not in this PR specifically — it's in SGLang's interaction with AITER's all-reduce — but Hy3 happens to be the trigger. Wanted to flag it before AMD users hit it.

Happy to help validate any future commits on AMD if useful.

@JustinTong0323 (Collaborator, Author): /rerun-failed-ci

2 similar comments
@JustinTong0323 (Collaborator, Author): /rerun-failed-ci

@JustinTong0323 (Collaborator, Author): /rerun-failed-ci

@JustinTong0323 (Collaborator, Author): /rerun-test test_qwen3_next_models_mtp.py test_unified_radix_cache_kl.py

@github-actions:

4-gpu-h100 (2 tests): View workflow run

cd test/ && python3 registered/4-gpu-models/test_qwen3_next_models_mtp.py
cd test/ && python3 registered/radix_cache/test_unified_radix_cache_kl.py

@JustinTong0323 (Collaborator, Author): /rerun-failed-ci

1 similar comment
@JustinTong0323 (Collaborator, Author): /rerun-failed-ci

@Qiaolin-Yu (Collaborator) commented:

We may do some refactoring of the group topk kernel later, but it should be good to merge now.

@Qiaolin-Yu Qiaolin-Yu merged commit 6d03861 into sgl-project:main Apr 24, 2026
583 of 691 checks passed
Kangyan-Zhou added a commit that referenced this pull request Apr 26, 2026
…n test

The Phase-3 renormalize block in `grouped_topk_single_group_kernel` called
`warp_sum_f32` (which uses `__shfl_xor_sync(0xffffffff, ...)`) from inside
`if (lane_id < topk)`. With `topk` < 32 (e.g. nemotron-3-nano: topk=6), only
lanes 0..topk-1 reached the intrinsic, but the mask 0xffffffff named all 32
lanes. CUDA spec: every lane named in the mask must execute the intrinsic
at the same site, otherwise the result is undefined.

Empirically the UB returned values from the absent lanes' registers,
producing wrong renormalized weights — 2 of 6 weights per token were
unnormalized (~1.5x too large). The wrong values were tolerated in eager
inference, but under piecewise CUDA graph replay they cascaded into a
downstream OOB that surfaced as IMA at `piecewise_cuda_graph_runner.py:794`
on `TestNvidiaNemotron3Nano30BFP8.test_lm_eval`.

Fix: move the warp_sum out of the divergent `if`, have all 32 lanes
participate, with inactive lanes contributing the additive identity (0).
Output writes remain gated by `if (lane_id < topk)`.

Validated:
- Unit sweep across E in {16..512}, K in {1..8}, N in {1..128}: matches
  reference biased_grouped_topk_impl with max diff < 1e-7.
- 2x H200 e2e: TestNvidiaNemotron3Nano30BFP8.test_lm_eval passes
  (gsm8k strict=0.839, flexible=0.542, both within rtol=0.08).
- Buggy kernel + eager (no graphs) also passes — confirming the kernel
  itself doesn't fault, only the cascade-under-graph-replay does.

This is the surgical alternative to #23758, which reverts the entire
#23533 (~4000 lines). The model code, tool/reasoning parsers, and tuned
MoE configs from #23533 are not part of the bug.

Also re-enables `test_nvidia_nemotron_3_nano` (the stop-gap disable was
added in #23720 when this IMA started showing up).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
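The all-lanes-participate pattern of the fix can be modeled in plain Python, with a toy butterfly reduction standing in for __shfl_xor_sync. This is an illustration of the idea only, not the kernel.

```python
# Toy model of the Phase-3 fix: all 32 "lanes" join the xor-butterfly
# reduction, with lanes >= topk contributing the additive identity (0).
def warp_sum(vals):
    lanes = list(vals)          # one value per lane
    offset = 16
    while offset:               # butterfly strides: 16, 8, 4, 2, 1
        lanes = [lanes[i] + lanes[i ^ offset] for i in range(32)]
        offset //= 2
    return lanes                # every lane now holds the full sum

topk = 6
weights = [2.0, 1.0, 1.0, 0.5, 0.3, 0.2]   # unnormalized topk weights
vals = weights + [0.0] * (32 - topk)       # inactive lanes contribute 0
total = warp_sum(vals)[0]                  # 5.0, identical on every lane
renorm = [w / total for w in weights]      # renormalized weights sum to 1
```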
Wen-xuan-Xu added a commit to Wen-xuan-Xu/sglang that referenced this pull request Apr 29, 2026
After sgl-project#23019 moved the MoE config loader and the configs/ tree from
`fused_moe_triton/` to `moe_runner/triton_utils/`, two later PRs
unknowingly added 33 tuned-config JSONs to the OLD path:

- sgl-project#22791 (LFM2)        — 24 files (E=32/64, H100/B200/MI325X)
- sgl-project#23533 (Hy3 preview) —  9 files (E=192,N=192 incl. _down,
                                    H20/H20-3e/B200)

The runtime loader anchors its search via
os.path.dirname(os.path.realpath(__file__)) of the loader file
(now in moe_runner/triton_utils/), so configs in the old
directory were never read — runtime fell back to
get_default_config().

The configs themselves were properly tuned and benchmarked at
submission time via the in-process override_config() path used
by the tuning script — that is why the PR authors observed real
speedup. The bug is purely a wrong filesystem location.

Root cause: the tuning README still pointed contributors to the
old path. This PR moves the misplaced configs into the
runtime-loaded location and fixes the README.

Changes:
  * R100 git-mv 33 JSONs into moe_runner/triton_utils/configs/{triton_3_5_1,triton_3_6_0}/
  * Update benchmark/kernels/fused_moe_triton/README.md path

No content changes. No code changes.

References: sgl-project#23019 sgl-project#22791 sgl-project#23533

Labels: documentation (Improvements or additions to documentation), jit-kernel, run-ci

4 participants