
perf: optimize PCG inductor path for FP8 models (redo of #21734) #23227

Merged
ch-wan merged 2 commits into sgl-project:main from jasperjiaguo:jiaguo/fp8-inductor-fusion-v2 on Apr 27, 2026

Conversation

@jasperjiaguo
Contributor

Summary

Re-applies the inductor-fusion optimizations from #21734 that were reverted in #23159, with the AMD crash fixed.

cc @HaiShaw

Changes vs. #21734

  • models/utils.py: use reshape(*q.shape[:-1], -1, head_dim) instead of view(...). reshape is safe on non-contiguous QKV-split tensors (the AMD crash root cause), while still giving inductor the same multi-dim shape information as view for fusion with surrounding ops.
  • fp8_utils.py: unchanged from perf: optimize PCG inductor path for FP8 models #21734 (FP8 per-tensor activation quant inductor-fusion path).
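For illustration, a minimal sketch (not the actual sglang `apply_qk_norm`; the norm-module call signature is simplified) of what the multi-dim reshape looks like and why it gives inductor useful shape information:

```python
import torch

def apply_qk_norm_sketch(q, k, q_norm, k_norm, head_dim):
    # Keep the leading dims explicit so inductor sees (tokens, heads, head_dim)
    # and can fuse the norm with neighboring elementwise / quant ops.
    q_shape, k_shape = q.shape, k.shape
    q = q_norm(q.reshape(*q_shape[:-1], -1, head_dim)).reshape(q_shape)
    k = k_norm(k.reshape(*k_shape[:-1], -1, head_dim)).reshape(k_shape)
    return q, k
```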

Why #21734's view() crashed on AMD

  • Under PCG capture, q/k are stride-trick views from the QKV split and are frequently non-contiguous.
  • view() on non-contiguous input yields a tensor whose strides disagree with what q_norm / k_norm kernels expect, leading to writes into unmapped pages (Memory access fault by GPU) on AMD.
  • reshape(*q.shape[:-1], -1, head_dim) preserves the same multi-dim shape for inductor but copies when strides don't allow a view, avoiding the fault.
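For concreteness, a small standalone demonstration (toy head counts, not sglang code) of the stride situation described above:

```python
import torch

head_dim = 8
# Fused QKV projection output: 2 query heads, 1 key head, 1 value head (toy sizes).
qkv = torch.randn(4, (2 + 1 + 1) * head_dim)
q, k, v = qkv.split([2 * head_dim, head_dim, head_dim], dim=-1)

print(q.is_contiguous())        # False: the row stride still spans H_total * head_dim
q3d = q.view(*q.shape[:-1], -1, head_dim)
print(q3d.shape, q3d.stride())  # torch.Size([4, 2, 8]) (32, 8, 1)
# A kernel that assumes the contiguous strides (16, 8, 1) will index past the
# data it was handed, which is the class of fault seen on AMD.
```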

Intention (same as #21734)

Preserve the multi-dim shape for inductor fusion while skipping the opaque fused_inplace_qknorm under the inductor compiler so inductor can fuse QK norm with RMSNorm / quant / residual add.

Test plan

🤖 Generated with Claude Code

@jasperjiaguo
Copy link
Copy Markdown
Contributor Author

/tag-and-rerun-ci

Contributor

gemini-code-assist (bot) left a comment


Code Review

This pull request optimizes FP8 linear layers and QK normalization for the Torch Inductor compiler by utilizing standard PyTorch operations that facilitate operator fusion. Specifically, it introduces a path for static per-tensor activation scales to use native ops and updates apply_qk_norm to preserve tensor dimensions during reshaping, which provides better shape information for the compiler. Feedback suggests refactoring duplicated normalization logic in apply_qk_norm into a helper function to improve code clarity and maintainability.
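As a rough illustration of the fp8_utils.py side (a sketch only; the function name, scale handling, and dtype are assumptions, not the exact sglang code), static per-tensor activation quant expressed with native PyTorch ops is something inductor can fuse, unlike an opaque custom kernel:

```python
import torch

def per_tensor_fp8_quant_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static per-tensor activation quant built from plain aten ops, so inductor
    # can fuse it with the preceding norm / elementwise work instead of breaking
    # the graph at an opaque custom kernel.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
```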

Comment thread: python/sglang/srt/models/utils.py
jasperjiaguo force-pushed the jiaguo/fp8-inductor-fusion-v2 branch 7 times, most recently from afc79c8 to eeec2ff on April 21, 2026 22:23
jasperjiaguo and others added 2 commits April 23, 2026 09:43
…21734)

Re-applies the inductor-fusion optimizations from sgl-project#21734 that were
reverted in sgl-project#23159, with the AMD crash fixed.

Changes vs. sgl-project#21734:
- models/utils.py: use `reshape(*q.shape[:-1], -1, head_dim)` instead
  of `view(...)`. reshape is safe on non-contiguous QKV-split tensors
  (the AMD crash root cause), while still giving inductor the same
  multi-dim shape information as view for fusion with surrounding ops.
- fp8_utils.py: unchanged from sgl-project#21734 (FP8 per-tensor activation quant
  inductor-fusion path).

Why sgl-project#21734's view() crashed on AMD:
- Under PCG capture, q/k are stride-trick views from the QKV split
  and are frequently non-contiguous.
- `view()` on non-contiguous input yields a tensor whose strides
  disagree with what q_norm/k_norm kernels expect, leading to writes
  into unmapped pages (`Memory access fault by GPU`) on AMD.
- `reshape(*q.shape[:-1], -1, head_dim)` preserves the same multi-dim
  shape for inductor but copies when strides don't allow a view,
  avoiding the fault.

Intention matches sgl-project#21734: preserve the multi-dim shape for inductor
fusion while skipping the opaque fused_inplace_qknorm under the
inductor compiler so inductor can fuse QK norm with RMSNorm / quant
/ residual add.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
AMD CI hit the same Memory access fault on the jiaguo/fp8-inductor-fusion-v2
branch (this PR) as on the original sgl-project#21734, because reshape() returns a
stride-preserving view when the new shape is compatible with the input
strides -- which (*q.shape[:-1], -1, head_dim) is for a QKV-split tensor.
The resulting non-contiguous tensor is then passed to ROCm's RMSNorm
kernel, which assumes contiguous inputs and faults.

Fix: keep the stride-preserving view only on the exact path that needs
inductor fusion (CUDA + piecewise_cuda_graph_compiler == "inductor"). On
ROCm and on CUDA with the eager PCG fallback, revert to the flat 2D
reshape(-1, head_dim), which is guaranteed to copy on a stride-tricked
input (the flat shape is not stride-viewable when the inner stride
spans H_total*D rather than H_q*D).

Also extract the reshape into a single `_reshape_for_qk_norm` helper,
replacing four inline copies in `apply_qk_norm`'s alt-stream and
straight-line branches.

Notes on semantics:
  - Original upstream `reshape(-1, head_dim)` happened to force a copy
    for the same stride reason, but only because the flat shape was
    un-viewable -- not by any explicit contiguity guarantee. We rely on
    the same property here for the ROCm path.
  - On the CUDA+inductor path, the stride-preserving view is what lets
    inductor fuse this reshape with the surrounding RMSNorm and FP8
    quant into a single triton kernel -- the original motivation of
    sgl-project#21734. Forcing .contiguous() would lose that fusion.
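A minimal sketch of the shape selection this commit describes (the gating flag here is a stand-in for the real "CUDA + piecewise_cuda_graph_compiler == 'inductor'" check, and the body is simplified from the `_reshape_for_qk_norm` helper in python/sglang/srt/models/utils.py):

```python
import torch

def _reshape_for_qk_norm_sketch(x: torch.Tensor, head_dim: int, use_inductor_pcg: bool) -> torch.Tensor:
    if use_inductor_pcg:
        # CUDA + inductor PCG: stride-preserving multi-dim view, so inductor can
        # fuse QK norm with the surrounding RMSNorm / FP8 quant / residual add.
        return x.reshape(*x.shape[:-1], -1, head_dim)
    # ROCm / eager PCG fallback: the flat shape is not stride-viewable for a
    # QKV-split slice, so this reshape copies and hands contiguity-assuming
    # kernels (e.g. ROCm's RMSNorm) a contiguous tensor.
    return x.reshape(-1, head_dim)
```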
jasperjiaguo force-pushed the jiaguo/fp8-inductor-fusion-v2 branch from eeec2ff to bd9ea0d on April 23, 2026 16:43
@jasperjiaguo
Contributor Author

/rerun-failed-ci

Collaborator

@ch-wan left a comment


LGTM

ch-wan merged commit bead2e3 into sgl-project:main on Apr 27, 2026
754 of 853 checks passed
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
…21734) (sgl-project#23227)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>