Revert #23533 (Hy3 preview) + re-enable test_nvidia_nemotron_3_nano #23758

Closed

alisonshao wants to merge 2 commits into main from revert-hy3-preview-reenable-nemotron

Conversation

@alisonshao (Collaborator) commented Apr 26, 2026

Reverts #23533 and re-enables test_nvidia_nemotron_3_nano, which #23720 disabled as a stop-gap when scheduled pr-test started failing.

Bisected the failure (`Fatal Python error: Aborted` from `piecewise_cuda_graph_runner.py:794` during FP8 Nemotron decode, surfacing as `Triton Error [CUDA]: an illegal memory access` in `_static_quant_fp8`) on 2x H200 against `TestNvidiaNemotron3Nano30BFP8.test_lm_eval`. The first bad commit is 6d0386147 (#23533); its parent (6344b546c) ran cleanly with gsm8k=0.850. Failure example: https://github.com/sgl-project/sglang/actions/runs/24936337295/job/73022450777.

#23533 added a new `grouped_topk_single_group_kernel` and wired it in for any single-group MoE with ≤512 experts and topk ≤ 8 (`python/sglang/srt/layers/moe/topk.py`); Nemotron-3-Nano-A3B falls into that gate. The kernel corrupts CUDA state, and the next sync point (`_static_quant_fp8` in the FP8 path) surfaces the illegal access. #23533's own CI was green because its branch predated #22218 (Breakable Piecewise Cuda Graph): each PR works alone, but the combination on main does not.

Conflicts during revert:

Reland #23533 once the new kernel is audited against the breakable-PCG runner.

Test plan

  • stage-b-test-2-gpu-large (2) runs the re-enabled test and reports gsm8k ≈ 0.85.
  • Other partitions stay green; no Hy3 imports remain (grepped `hunyuan_v3`, `hunyuan_v3_nextn`, `grouped_topk`, `hunyuan_detector`).

Removes the `disabled="Temporarily disabled; failing on main."` flag
added in #23720. The underlying crash is gone now that #23533 has been
reverted in the previous commit, so the test should once again run as
part of stage-b-test-2-gpu-large.
@gemini-code-assist (Bot) left a comment

Code Review

This pull request removes the implementation and support for the Hunyuan-V3 (HYV3) model architecture, including its specialized MoE routing kernels, function call detectors, reasoning parsers, and model configurations. Additionally, it re-enables a test for the Nemotron-3-Nano model that was previously disabled. I have no feedback to provide as there are no review comments to assess.

@alisonshao (Collaborator, Author)

/rerun-test test_nvidia_nemotron_3_nano.py

@github-actions

2-gpu-h100 (1 test): View workflow run

`cd test/ && python3 registered/models/test_nvidia_nemotron_3_nano.py`

Kangyan-Zhou added a commit that referenced this pull request Apr 26, 2026
…n test

The Phase-3 renormalize block in `grouped_topk_single_group_kernel` called
`warp_sum_f32` (which uses `__shfl_xor_sync(0xffffffff, ...)`) from inside
`if (lane_id < topk)`. With `topk` < 32 (e.g. nemotron-3-nano: topk=6), only
lanes 0..topk-1 reached the intrinsic, but the mask 0xffffffff named all 32
lanes. CUDA spec: every lane named in the mask must execute the intrinsic
at the same site, otherwise the result is undefined.
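
A minimal sketch of the divergent shape described above (illustrative names, not the literal kernel source; only `warp_sum_f32`, the mask, and the `if (lane_id < topk)` gate come from the description):

```cuda
// Butterfly warp reduction: the full-warp mask 0xffffffff promises that
// all 32 lanes execute every __shfl_xor_sync at this site.
__device__ float warp_sum_f32(float v) {
    for (int offset = 16; offset > 0; offset >>= 1)
        v += __shfl_xor_sync(0xffffffff, v, offset);
    return v;
}

// Buggy shape: with topk = 6, lanes 6..31 never reach the intrinsic,
// yet the mask names all 32 lanes -> the result is undefined.
if (lane_id < topk) {
    float sum = warp_sum_f32(weight);  // UB: masked-in lanes are absent
    weight /= sum;                     // some lanes see a garbage sum
}
```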

Empirically the UB returned values from the absent lanes' registers,
producing wrong renormalized weights — 2 of 6 weights per token were
unnormalized (~1.5x too large). The wrong values were tolerated in eager
inference, but under piecewise CUDA graph replay they cascaded into a
downstream OOB that surfaced as IMA at `piecewise_cuda_graph_runner.py:794`
on `TestNvidiaNemotron3Nano30BFP8.test_lm_eval`.

Fix: move the warp_sum out of the divergent `if`, have all 32 lanes
participate, with inactive lanes contributing the additive identity (0).
Output writes remain gated by `if (lane_id < topk)`.
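
A hedged sketch of that fixed shape, reusing `warp_sum_f32` from the snippet above (again with illustrative names):

```cuda
// Fixed shape: every lane executes the shuffle at the same site; lanes
// >= topk contribute the additive identity 0, so the sum is unchanged.
float contrib = (lane_id < topk) ? weight : 0.0f;
float sum = warp_sum_f32(contrib);  // uniform across all 32 lanes

if (lane_id < topk) {
    weight /= sum;  // output writes stay gated, as before
}
```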

Validated:
- Unit sweep across E in {16..512}, K in {1..8}, N in {1..128}: matches
  reference biased_grouped_topk_impl with max diff < 1e-7.
- 2x H200 e2e: TestNvidiaNemotron3Nano30BFP8.test_lm_eval passes
  (gsm8k strict=0.839, flexible=0.542, both within rtol=0.08).
- Buggy kernel + eager (no graphs) also passes — confirming the kernel
  itself doesn't fault, only the cascade-under-graph-replay does.

This is the surgical alternative to #23758, which reverts the entire
#23533 (~4000 lines). The model code, tool/reasoning parsers, and tuned
MoE configs from #23533 are not part of the bug.

Also re-enables `test_nvidia_nemotron_3_nano` (the stop-gap disable was
added in #23720 when this IMA started showing up).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>