[Kernel Slimming] Migrate AWQ marlin repack kernel to JIT#18949
[Kernel Slimming] Migrate AWQ marlin repack kernel to JIT#18949BBuf merged 6 commits intosgl-project:mainfrom
Conversation
Summary of ChangesHello @celve, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
|
There was a problem hiding this comment.
Code Review
This pull request successfully migrates the AWQ marlin repack kernel to a JIT-compiled version, which is a great step for kernel slimming. The implementation is solid, including the new Python wrapper, CUDA kernel, tests, and benchmarks. I've identified a few minor areas for improvement regarding code duplication in test/benchmark files, a magic number that could be a constant, and some unreachable code in the CUDA host wrapper. Addressing these points will enhance the code's maintainability and clarity. Overall, this is a well-executed migration.
| size_n: int, | ||
| num_bits: int, | ||
| ) -> torch.Tensor: | ||
| tile_size = 16 |
| def awq_pack(q_w, num_bits, size_k, size_n): | ||
| if num_bits == 4: | ||
| interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) | ||
| elif num_bits == 8: | ||
| interleave = np.array([0, 2, 1, 3]) | ||
| else: | ||
| raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) | ||
|
|
||
| q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() | ||
| q_w = q_w.reshape((-1, size_n)).contiguous() | ||
| return pack_cols(q_w, num_bits, size_k, size_n) |
There was a problem hiding this comment.
This awq_pack function is duplicated in python/sglang/jit_kernel/tests/test_awq_marlin_repack.py. To avoid code duplication and improve maintainability, consider moving this function to a shared test utility module.
Additionally, it's better to raise a more specific exception like ValueError instead of a generic Exception.
| def awq_pack(q_w, num_bits, size_k, size_n): | |
| if num_bits == 4: | |
| interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) | |
| elif num_bits == 8: | |
| interleave = np.array([0, 2, 1, 3]) | |
| else: | |
| raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) | |
| q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() | |
| q_w = q_w.reshape((-1, size_n)).contiguous() | |
| return pack_cols(q_w, num_bits, size_k, size_n) | |
| def awq_pack(q_w, num_bits, size_k, size_n): | |
| if num_bits == 4: | |
| interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) | |
| elif num_bits == 8: | |
| interleave = np.array([0, 2, 1, 3]) | |
| else: | |
| raise ValueError(f"num_bits must be 4 or 8, got {num_bits}") | |
| q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() | |
| q_w = q_w.reshape((-1, size_n)).contiguous() | |
| return pack_cols(q_w, num_bits, size_k, size_n) |
| } else { | ||
| RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); | ||
| } |
| def awq_pack( | ||
| q_w: torch.Tensor, | ||
| num_bits: int, | ||
| size_k: int, | ||
| size_n: int, | ||
| ): | ||
| assert q_w.shape == (size_k, size_n) | ||
|
|
||
| if num_bits == 4: | ||
| interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) | ||
| elif num_bits == 8: | ||
| interleave = np.array([0, 2, 1, 3]) | ||
| else: | ||
| raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) | ||
|
|
||
| q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() | ||
| q_w = q_w.reshape((-1, size_n)).contiguous() | ||
|
|
||
| return pack_cols(q_w, num_bits, size_k, size_n) |
There was a problem hiding this comment.
This awq_pack function is duplicated in python/sglang/jit_kernel/benchmark/bench_awq_marlin_repack.py. To improve maintainability, it would be best to extract it into a shared test utility file.
Also, it's good practice to raise a more specific ValueError instead of a generic Exception.
| def awq_pack( | |
| q_w: torch.Tensor, | |
| num_bits: int, | |
| size_k: int, | |
| size_n: int, | |
| ): | |
| assert q_w.shape == (size_k, size_n) | |
| if num_bits == 4: | |
| interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) | |
| elif num_bits == 8: | |
| interleave = np.array([0, 2, 1, 3]) | |
| else: | |
| raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) | |
| q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() | |
| q_w = q_w.reshape((-1, size_n)).contiguous() | |
| return pack_cols(q_w, num_bits, size_k, size_n) | |
| def awq_pack( | |
| q_w: torch.Tensor, | |
| num_bits: int, | |
| size_k: int, | |
| size_n: int, | |
| ): | |
| assert q_w.shape == (size_k, size_n) | |
| if num_bits == 4: | |
| interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) | |
| elif num_bits == 8: | |
| interleave = np.array([0, 2, 1, 3]) | |
| else: | |
| raise ValueError(f"num_bits must be 4 or 8, got {num_bits}") | |
| q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() | |
| q_w = q_w.reshape((-1, size_n)).contiguous() | |
| return pack_cols(q_w, num_bits, size_k, size_n) |
There was a problem hiding this comment.
Pull request overview
This PR migrates the AWQ marlin repack kernel from ahead-of-time (AOT) compilation in sgl-kernel to just-in-time (JIT) compilation, as part of the kernel slimming initiative to reduce the sgl-kernel wheel size by approximately 97.5MB.
Changes:
- Moved awq_marlin_repack kernel implementation from sgl-kernel AOT to JIT compilation
- Added comprehensive tests comparing JIT vs AOT implementations for correctness
- Added benchmarking to compare JIT vs AOT performance
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| python/sglang/srt/layers/quantization/awq.py | Updated import to use JIT-compiled awq_marlin_repack instead of sgl_kernel AOT version |
| python/sglang/jit_kernel/awq_marlin_repack.py | Python wrapper for JIT kernel with output tensor allocation |
| python/sglang/jit_kernel/csrc/gemm/marlin/awq_marlin_repack.cuh | JIT-compiled CUDA kernel ported from AOT implementation |
| python/sglang/jit_kernel/tests/test_awq_marlin_repack.py | Comprehensive tests for correctness (JIT vs AOT and expected behavior) |
| python/sglang/jit_kernel/benchmark/bench_awq_marlin_repack.py | Performance benchmarking comparing JIT vs AOT implementations |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @@ -60,7 +60,9 @@ | |||
| import torch_npu | |||
|
|
|||
| if _is_cuda: | |||
| from sgl_kernel import awq_dequantize, awq_marlin_moe_repack, awq_marlin_repack | |||
| from sgl_kernel import awq_dequantize, awq_marlin_moe_repack | |||
There was a problem hiding this comment.
todo: we should also remove awq_dequantize and awq_marlin_moe_repack to jit_kernel
|
/tag-and-rerun-ci |
|
/rerun-failed-ci |
|
awq marlin moe repack benchmark: Serve With AOT: |
|
awq dequant benchmark: Serve |
|
/rerun-failed-ci |
…o xverse_moe * 'xverse_moe' of https://github.com/xiaobaicxy/sglang: (275 commits) fix: add missing blank line after docstring in serving_transcription.py (sgl-project#19206) Whisper model support & `/v1/audio/transcriptions` endpoint & benchmark (sgl-project#16983) fix: patch docker image fixes (sgl-project#19100) [PD-Disagg] Unify prefill info data transition flow, all with `PrefillServerInfo` (sgl-project#19195) [CI] Tiny enhance the dp attention load blance benchmark (sgl-project#19194) add new ci user (sgl-project#19133) [CI] fix the teardown output of disaggregation test (sgl-project#19193) [PD-Disagg] Support query dp rank from bootstrap server. (sgl-project#19168) [Kernel Slimming] Migrate AWQ marlin repack kernel to JIT (sgl-project#18949) [Diffusion] Match rotary_embedding module name style (sgl-project#19179) [Refactor] Split rotary_embedding.py into a modular package (sgl-project#19144) [NPU] bump sgl-kernel-npu to 2026.02.01.post2 (sgl-project#19178) Use single mma warp group for short q_len in FA to optimize decoding performance (sgl-project#18985) Reorganize topk logic to clean up code and expose logical experts (sgl-project#16945) [ROCm] Use unreg path for custom all-reduce during CUDA graph capture (sgl-project#19162) [diffusion] feat: detect Flux2 custom VAE path from component_paths (sgl-project#19170) [AMD] ENV flags tuning and cleanup (sgl-project#19176) Fix bench_one_batch_server by moving the print statements (sgl-project#19175) Update rocm7.2 Dockerfile to install amdsmi for QuickReduce Initialization (sgl-project#19091) Revert "Refactor graph input buffers (sgl-project#18991)" (sgl-project#19173) ...
…t#18949) Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
…t#18949) Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
…t#18949) Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Motivation
See #17865
Modifications
New files:
Modified files:
Accuracy Tests
Pass all tests defined in python/sglang/jit_kernel/tests/test_awq_marlin_repack.py
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci