
[NPU] support GPTQ quantization on npu #15203

Merged
iforgetmyname merged 5 commits into sgl-project:main from 22dimensions:gptq_npu
Jan 29, 2026

Conversation

@22dimensions (Contributor) commented Dec 15, 2025

Motivation

This PR follows #15202 and the Roadmap of NPU support (#13664).

Modifications

  1. Add GPTQLinearAscendMethod class
    • unpack qweight and qzeros to a supported dtype
    • use the torch_npu.npu_weight_quant_batchmatmul kernel in the linear layer's forward method
  2. Add a test case for GPTQ int8
  3. Fix the NPU prefetch error on Qwen3-GPTQ series models
    • root cause: GPTQ stores the quantized weight in qweight rather than weight
    • before: AttributeError: mlp.gate_up_proj object has no attribute weight
    • now: error fixed
  4. Support the GPTQv2 checkpoint format of GPTQModel
    • see: [Feature]: Support GPTQv2 checkpoint format of GPTQModel vllm-project/vllm#26343
    • The only difference between the GPTQv1 and GPTQv2 formats is how they store the zero points: GPTQv1 subtracts 1 from the zero points, while GPTQv2 does not.
    • GPU: not supported yet; the GPU quantization kernel still needs to be adapted
    • NPU: handle the zero point in process_weights_after_loading (see the sketch after this list)
      • GPTQv1: unpacked zero point += 1
      • GPTQv2: keep the original unpacked zero point
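
A minimal sketch of the unpacking and zero-point handling described above, assuming packing along the last dimension (the real checkpoint packs qweight along the input dimension, and the helper names here mirror the PR description rather than the exact implementation):

```python
import torch


def unpack_from_int32(packed: torch.Tensor, bits: int) -> torch.Tensor:
    """Unpack GPTQ-style int32 words into individual low-bit integer values.

    Each int32 word stores 32 // bits values (8 values for 4-bit, 4 for 8-bit).
    This sketch unpacks along the last dimension.
    """
    mask = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=packed.device)
    # (..., n) -> (..., n, 32 // bits), then flatten the packed axis back out.
    values = (packed.unsqueeze(-1) >> shifts) & mask
    return values.reshape(*packed.shape[:-1], -1)


def restore_zero_points(qzeros: torch.Tensor, bits: int, is_gptq_v2: bool) -> torch.Tensor:
    """GPTQv1 checkpoints store zero points minus 1; GPTQv2 stores them as-is."""
    zeros = unpack_from_int32(qzeros, bits)
    return zeros if is_gptq_v2 else zeros + 1
```

After this step both checkpoint formats share the same in-memory zero points, so the rest of process_weights_after_loading and the NPU matmul path can stay format-agnostic.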

Accuracy Tests, Benchmarking, and Profiling

The same model with different data types, to see the difference:

  • original float16 model: Qwen/Qwen3-1.7B
python3 -m sglang.launch_server --model-path Qwen/Qwen3-1.7B  --device npu --attention-backend ascend --port 30000

python ./python/sglang/test/few_shot_gsm8k.py

Accuracy: 0.710
Invalid: 0.000
Latency: 33.498 s
Output throughput: 732.379 token/s
  • 8bit quantized model: Qwen/Qwen3-1.7B-GPTQ-Int8
python3 -m sglang.launch_server --model-path Qwen/Qwen3-1.7B-GPTQ-Int8  --device npu --attention-backend ascend --port 30000 --quantization gptq

python ./python/sglang/test/few_shot_gsm8k.py

Accuracy: 0.690
Invalid: 0.005
Latency: 19.713 s
Output throughput: 1263.426 token/s
  • 4bit quantized model: JunHowie/Qwen3-1.7B-GPTQ-Int4
python3 -m sglang.launch_server --model-path JunHowie/Qwen3-1.7B-GPTQ-Int4  --device npu --attention-backend ascend --port 30000 --quantization gptq

python ./python/sglang/test/few_shot_gsm8k.py

Accuracy: 0.190
Invalid: 0.005
Latency: 26.373 s
Output throughput: 2087.469 token/s

Checklist

@gemini-code-assist (Contributor)

Summary of Changes

Hello @22dimensions, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances SGLang's capabilities by integrating GPTQ quantization support for Ascend NPU hardware. It introduces a specialized linear method (GPTQLinearAscendMethod) that leverages NPU-specific kernels for efficient quantized operations and includes necessary weight unpacking logic. Additionally, it resolves a critical cache handling issue for Qwen3 models when quantized on NPU, ensuring broader compatibility and improved performance.

Highlights

  • NPU GPTQ Support: Introduced the GPTQLinearAscendMethod class to enable GPTQ quantization specifically for Ascend NPU devices, expanding hardware compatibility.
  • Weight Unpacking Logic: Implemented a new unpack_from_int32 function to correctly unpack quantized weights from int32 format to their original bit representation for NPU operations.
  • NPU Kernel Integration: Integrated the torch_npu.npu_weight_quant_batchmatmul kernel for efficient quantized matrix multiplication, leveraging NPU hardware capabilities (see the sketch after this list).
  • Qwen3 Cache Fix: Addressed and fixed a cache error encountered in Qwen3 models when running with quantization on NPU, improving model stability.
  • Expanded Data Type Support: Enabled torch.bfloat16 as a supported activation data type for GPTQ on NPU, offering more flexibility in mixed-precision training and inference.
  • New Test Case: Added a dedicated test case (test_ascend_gptq.py) to validate GPTQ Int8 accuracy on Ascend NPU using the Qwen3-1.7B-GPTQ-Int8 model.
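
For context on the kernel integration above, here is a hedged sketch of how a weight-only quantized forward pass can be expressed with torch_npu. The argument order follows the torch_npu documentation for npu_weight_quant_batchmatmul, but the shapes, per-group handling, and dtype casts in the actual GPTQLinearAscendMethod may differ:

```python
from typing import Optional

import torch
import torch_npu  # Ascend adapter for PyTorch


def npu_gptq_forward(
    x: torch.Tensor,                 # fp16/bf16 activations, (tokens, in_features) assumed
    qweight: torch.Tensor,           # unpacked int8 weight, (in_features, out_features) assumed
    antiquant_scale: torch.Tensor,   # dequantization scale for the weight
    antiquant_offset: torch.Tensor,  # dequantization offset (derived from the zero points)
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # The NPU kernel dequantizes the int8 weight on the fly with scale/offset
    # and performs the matmul in the activation dtype.
    out = torch_npu.npu_weight_quant_batchmatmul(
        x, qweight, antiquant_scale, antiquant_offset
    )
    if bias is not None:
        out = out + bias
    return out
```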

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces GPTQ quantization support for Ascend NPUs, which is a great addition for improving performance on this hardware. The changes include a new GPTQLinearAscendMethod, NPU-specific kernels, and a test case. My review focuses on correctness and maintainability. I've found a couple of critical issues, including a syntax error and a bug in tensor reshaping that could lead to runtime errors. I've also included some suggestions to improve code clarity and test portability. Overall, this is a solid contribution, and after addressing the critical issues, it should be in good shape.

Comment thread python/sglang/srt/layers/linear.py Outdated
Comment thread python/sglang/srt/layers/quantization/gptq.py Outdated
Comment thread python/sglang/srt/layers/quantization/gptq.py
Comment thread python/sglang/srt/layers/quantization/gptq.py
Comment thread test/srt/ascend/test_ascend_gptq.py
@ping1jing2 ping1jing2 self-assigned this Dec 15, 2025
@22dimensions 22dimensions force-pushed the gptq_npu branch 2 times, most recently from a99e6c6 to ff71232 on December 16, 2025 02:21
@OrangeRedeng (Contributor)

Hi! The accuracy of the int4 model looks very low. Have you tested it on the GPU, and does it show the same accuracy?

@22dimensions (Contributor, Author)

Hi! The accuracy of the int4 model looks very low. Have you tested it on the GPU, and does it show the same accuracy?

No, I will paste the GPU accuracy later; maybe 4-bit quantization is a little too low for Qwen3-1.7B.

@TamirBaydasov (Contributor)

Could you please provide more background for "Fix qwen3 cache error in quantization case on npu"? For example, showing the status before and after the fix would be sufficient.

@OrangeRedeng (Contributor)

Hi! The accuracy of the int4 model looks very low. Have you tested it on the GPU, and does it show the same accuracy?

No, I will paste the GPU accuracy later; maybe 4-bit quantization is a little too low for Qwen3-1.7B.

It might be worth testing on another dataset or model; from my point of view, the accuracy looks strange. It should be lower than int8, but should it be that low, especially when activations are not quantized?

@22dimensions (Contributor, Author) commented Dec 16, 2025

Could you please provide more background for "Fix qwen3 cache error in quantization case on npu"? For example, showing the status before and after the fix would be sufficient.

Sorry for the unclear description. I found that PR #14884 fixes the same issue I encountered. I think I can update my branch after it is merged.

@22dimensions (Contributor, Author) commented Dec 16, 2025

Hi! The accuracy of the int4 model looks very low. Have you tested it on the GPU, and does it show the same accuracy?

No, I will paste the GPU accuracy later; maybe 4-bit quantization is a little too low for Qwen3-1.7B.

It might be worth testing on another dataset or model; from my point of view, the accuracy looks strange. It should be lower than int8, but should it be that low, especially when activations are not quantized?

I just tested the GLM-4-9B-0414 series model; here are the results:

python3 -m sglang.launch_server --model-path ZhipuAI/GLM-4-9B-0414  --device npu --attention-backend ascend --port 30000 --mem-fraction-static 0.8
python ./python/sglang/test/few_shot_gsm8k.py

Accuracy: 0.790
Invalid: 0.000
Latency: 40.490 s
Output throughput: 543.412 token/s

SGLANG_USE_MODELSCOPE=true python3 -m sglang.launch_server --model-path  tclf90/glm-4-9b-0414-gptq-int4  --device npu --attention-backend ascend --port 30000 --mem-fraction-static 0.8 --quantization gptq
python ./python/sglang/test/few_shot_gsm8k.py

Accuracy: 0.750
Invalid: 0.000
Latency: 27.975 s
Output throughput: 1000.675 token/s

This data looks reasonable.

@22dimensions (Contributor, Author)

cc: @ping1jing2

@22dimensions (Contributor, Author)

cc: @iforgetmyname

@ping1jing2 (Collaborator)

/rerun-failed-ci

1 similar comment
@ping1jing2 (Collaborator)

/rerun-failed-ci

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
@iforgetmyname iforgetmyname self-assigned this Jan 28, 2026
@iforgetmyname (Collaborator)

/rerun-failed-ci

@iforgetmyname iforgetmyname merged commit 7b79326 into sgl-project:main Jan 29, 2026
372 of 396 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Jan 30, 2026
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Chen-0210 pushed a commit to Chen-0210/sglang that referenced this pull request Jan 30, 2026
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
sfiisf pushed a commit to sfiisf/sglang that referenced this pull request Feb 5, 2026
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
