
[AMD][Kimi K2.5 Day 0] ROCm: route W4A16 MoE to Triton and fix packed-weight loading #17863

Merged
HaiShaw merged 3 commits into sgl-project:main from jhinpan:k2.5-support on Jan 28, 2026

Conversation

@jhinpan
Collaborator

@jhinpan jhinpan commented Jan 28, 2026

Motivation

As reported in issue #17854, CompressedTensorsWNA16MoEMethod currently routes to Marlin kernels by default on ROCm. Marlin is NVIDIA-only, which breaks Kimi-K2.5 (native INT4) on MI300X. This patch dispatches to Triton on ROCm and fixes the weight-loading transpose path to avoid shape mismatches.
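For illustration, a minimal sketch of the routing idea (not the literal diff): select_wna16_moe_method is a hypothetical helper name and the constructor argument is shown schematically; in the PR the dispatch sits inside compressed_tensors_moe.py itself.

```python
# Minimal sketch of the dispatch idea described above, not the literal patch.
# select_wna16_moe_method is a hypothetical name used only for illustration.
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsWNA16MoEMethod,
    CompressedTensorsWNA16TritonMoEMethod,
)
from sglang.srt.utils import is_hip  # ROCm/HIP detection helper referenced by the PR


def select_wna16_moe_method(quant_config):
    if is_hip():
        # Marlin kernels are CUDA-only, so ROCm (e.g. MI300X) must take the
        # Triton-backed method added by this PR.
        return CompressedTensorsWNA16TritonMoEMethod(quant_config)
    # NVIDIA GPUs keep the existing Marlin-backed default.
    return CompressedTensorsWNA16MoEMethod(quant_config)
```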

Modifications

  • python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
    • Add CompressedTensorsWNA16TritonMoEMethod and convert weights/scales to Triton layout.
    • Dispatch to Triton when is_hip() is true.
  • python/sglang/srt/layers/moe/fused_moe_triton/layer.py
    • Include CompressedTensorsWNA16TritonMoEMethod in packed-weight transpose checks (a minimal sketch of this check follows the list).
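
Below is an illustrative sketch of the transpose check, under the assumption that packed-weight methods are recognized by class name; PACKED_WNA16_MOE_METHODS and _should_transpose_packed_weights are hypothetical names, not the actual symbols in fused_moe_triton/layer.py.

```python
# Illustrative sketch only; names below are hypothetical placeholders.
PACKED_WNA16_MOE_METHODS = (
    "CompressedTensorsWNA16MoEMethod",
    "CompressedTensorsWNA16TritonMoEMethod",  # added so the ROCm/Triton path matches too
)


def _should_transpose_packed_weights(quant_method) -> bool:
    # Packed WNA16 MoE weights (e.g. w13_weight_packed) need a transpose at
    # load time; without recognizing the new Triton method here, the ROCm path
    # hits the shape mismatches described in the Motivation.
    return type(quant_method).__name__ in PACKED_WNA16_MOE_METHODS
```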

Accuracy Tests

Tested on GSM8K:

[Screenshot: GSM8K accuracy results]

Benchmarking and Profiling

Benchmark results here: notion.so/Kimi-K2-5-on-MI300X-2f5651cb22e580cb9395d6169ee59d66?pvs=73

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@HaiShaw
Collaborator

HaiShaw commented Jan 28, 2026

/tag-and-rerun-ci

@HaiShaw HaiShaw added the amd label Jan 28, 2026

@HaiShaw HaiShaw left a comment


  1. This should work for MI350/355 as well.
  2. Can we add MoE Triton tuning later?

@HaiShaw HaiShaw merged commit 1953efb into sgl-project:main Jan 28, 2026
167 of 189 checks passed
@jhinpan
Collaborator Author

jhinpan commented Jan 28, 2026

  > 1. This should work for MI350/355 as well.
  > 2. Can we add MoE Triton tuning later?

  1. Yes, I will test it this weekend when I get the MI350 back.
  2. Sure. Once we have the kernel optimization pipeline, we can give it a try as well.

if getattr(layer, "is_triton_converted", False):
    return

num_experts = layer.w13_weight_packed.shape[0]
Collaborator


I don't see the variable "num_experts" used anywhere in this function block; does it still need to exist?

Collaborator Author


Thanks for catching this nit! I will clean it up today.

@HaiShaw HaiShaw changed the title [AMD] ROCm: route W4A16 MoE to Triton and fix packed-weight loading [AMD][Kimi K2.5 Day 0] ROCm: route W4A16 MoE to Triton and fix packed-weight loading Jan 29, 2026
@jhinpan jhinpan deleted the k2.5-support branch January 29, 2026 20:59
