[AMD] Add fused GemmaRMSNorm forward_hip to use aiter/vllm kernels for qwen3.5#21188
Conversation
Previously GemmaRMSNorm re-dispatched HIP to forward_native, bypassing fused kernels. Add a dedicated forward_hip that routes through aiter or vllm fused_add_rms_norm/rms_norm, matching the existing CUDA path logic with the +1 weight offset that Gemma requires.
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
HaiShaw
left a comment
There was a problem hiding this comment.
Approve for now. Let's remove/clone vllm dependencies in a different PR.
|
/tag-and-rerun-ci |
Confirmed -- forward_hip handles all three cases:
No callers check which method _forward_method points to; they all go through MultiPlatformOp.forward() which just calls self._forward_method(). The fallback path is identical to what the old _is_hip override did, just reached via forward_hip instead of being wired directly in init. |
co-author: @zhentaocc
Motivation
Previously
GemmaRMSNormre-dispatched HIP toforward_native, bypassing fused kernels. This adds a dedicatedforward_hipthat routes through aiter or vllmfused_add_rms_norm/rms_norm, matching the existing CUDA path logic with the +1 weight offset that Gemma requires.Modifications
__init__override that forced_forward_method = forward_nativeon HIP.forward_hip()method toGemmaRMSNormthat:weight + 1.0offset._use_aiteris set.forward_nativeif neither is available.Accuracy Tests
Model: Qwen3.5-397B-A17B-FP8, Hardware: 8x MI355X, Image:
rocm/sgl-dev:v0.5.9-rocm720-mi35x-20260318GSM8K (5-shot, 2000 questions, parallel=1000):
Server launch command:
SGLANG_USE_AITER=1 python3 -m sglang.launch_server \ --model-path /data/Qwen3.5-397B-A17B-FP8/ \ --tp 8 \ --attention-backend aiter \ --trust-remote-code \ --model-loader-extra-config '{"enable_multithread_load": true}' \ --watchdog-timeout 1200 \ --mem-fraction-static 0.8 \ --host 0.0.0.0 --port 9000Accuracy test command:
Benchmarking and Profiling
Workload:
sglang.bench_serving --dataset-name random --random-input 8192 --random-output 1024 --random-range-ratio 1.0Benchmark command:
Concurrency=1 comparison (8 prompts):
Latency & Throughput
TTFT & ITL
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci