
[NPU]TP Communications compression For Qwen3 models for NPU #20520

Merged
sglang-npu-bot merged 44 commits into sgl-project:main from egvenediktov:communications_compression
May 2, 2026

Conversation

@egvenediktov egvenediktov commented Mar 13, 2026

Motivation

Implemented INT8 TP communications compression on prefill for Qwen3 models.
Compression achieves an average 5% performance improvement on prefill-intensive benchmarks (see Benchmarking and Profiling). Accuracy tests show no degradation on average (BoolQ, C-Eval, HellaSwag; see Accuracy Tests).

Description

TP introduces communication between devices after o_proj in attention and after down_proj in the MLP. To reduce communication latency, we can quantize activations before sending them to other devices and dequantize them right after the communication completes.
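
To make the scheme concrete, here is a minimal sketch of an INT8-compressed all_reduce with per-token symmetric quantization, assuming a PyTorch-like distributed API. The function name quant_all_reduce matches the method mentioned in the changelog below, but the all_gather-based exchange and every identifier here are illustrative; the actual PR uses NPU-specific dynamic quantization kernels in npu_communicator.py.

```python
import torch
import torch.distributed as dist

def quant_all_reduce(x: torch.Tensor, group=None) -> torch.Tensor:
    """Sketch of an INT8-compressed all_reduce (not the PR's NPU kernels)."""
    world_size = dist.get_world_size(group)

    # Per-token, symmetric quantization: one FP scale per row (token).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)

    # Exchange INT8 payloads (plus small per-token FP scales) instead of
    # full-precision tensors, roughly halving communication volume for
    # BF16/FP16 activations.
    q_list = [torch.empty_like(q) for _ in range(world_size)]
    s_list = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(q_list, q, group=group)
    dist.all_gather(s_list, scale, group=group)

    # Dequantize each rank's partial result locally and sum, which
    # reproduces the semantics of an all_reduce over the FP inputs.
    out = torch.zeros_like(x)
    for qi, si in zip(q_list, s_list):
        out = out + qi.to(x.dtype) * si
    return out
```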

Below is profiling of TP communications for FP and compressed cases.

profiling comparison

Modifications

9 files changed:

  • python/sglang/srt/server_args.py (Added a server argument for communications compression and validation checks for it)

  • python/sglang/srt/models/qwen2.py (Added ForwardBatch pass-through from the MLP layer to the down_proj linear)

  • python/sglang/srt/models/qwen3.py (Added ForwardBatch pass-through to the MLP layer)

  • python/sglang/srt/layers/linear.py (Added logic for enabling quantization of communications; see the sketch after this list)

  • python/sglang/srt/distributed/communication_op.py (Added interface to all_reduce operation to enable communications quantization)

  • python/sglang/srt/distributed/parallel_state.py (Added interface and logic of communications quantization for all_reduce operation)

  • python/sglang/srt/distributed/device_communicators/npu_communicator.py (Added implementation of all_reduce with quantized communications; quantization scheme - per-token, symmetric)

  • python/sglang/srt/layers/communicator.py (Added logic of communications quantization for residual add operation with all_reduce)

  • benchmark/boolq/bench_sglang.py (import fix)
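
As a rough illustration of the gating added in linear.py, the sketch below selects between the full-precision and compressed all_reduce based on the forward mode. The fp_comm keyword follows the bot changelog further down; the helper name and the exact condition are assumptions, not the PR's verbatim code.

```python
import torch
from sglang.srt.distributed import tensor_model_parallel_all_reduce

def reduce_partial_output(
    partial: torch.Tensor,
    forward_batch,
    quantize_tp_communications: bool,
) -> torch.Tensor:
    """Sketch: pick FP vs INT8-compressed all_reduce for a TP partial sum."""
    # Compress communications only on prefill (extend) batches; decode and
    # any call without a forward context stay on the full-precision path.
    use_quant = (
        quantize_tp_communications
        and forward_batch is not None
        and forward_batch.forward_mode.is_extend()
    )
    # fp_comm=True keeps the standard all_reduce; fp_comm=False routes to
    # the NPU communicator's quant_all_reduce (flag name from the changelog).
    return tensor_model_parallel_all_reduce(partial, fp_comm=not use_quant)
```

Decode presumably stays on the full-precision path because each decode step communicates only a few tokens, so the quantize/dequantize overhead would likely outweigh the bandwidth savings there.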

Accuracy Tests

Server launch:

  • baseline
    ASCEND_RT_VISIBLE_DEVICES=3,6 python -m sglang.launch_server --device npu --attention-backend ascend --trust-remote-code --tp-size 2 --model-path /home/ckpt/Qwen3-32B/ --port 30088 --cuda-graph-max-bs 64 --max-prefill-tokens 32768 --chunked-prefill-size -1 --mem-fraction-static 0.8

  • with communications quantization
    ASCEND_RT_VISIBLE_DEVICES=3,6 python -m sglang.launch_server --device npu --attention-backend ascend --trust-remote-code --tp-size 2 --model-path /home/ckpt/Qwen3-32B/ --port 30088 --cuda-graph-max-bs 64 --max-prefill-tokens 32768 --chunked-prefill-size -1 --mem-fraction-static 0.8 --quantize-tp-communications

BoolQ:

Client launch
python ./benchmark/boolq/bench_sglang.py --port 30088 --train-data-path "path/to/train/data/file" --test-data-path "path/to/test/data/file" --parallel 64 --num-questions 3270

  • baseline
BoolQ baseline accuracy
  • with communications quantization
BoolQ accuracy with communications quantization

C-Eval:

Client launch
python ./benchmark/ceval/bench_sglang.py --port 30088 --data-path "path/to/data" --num-questions 1346

  • baseline
C-Eval baseline accuracy
  • with communications quantization
C-Eval accuracy with communications quantization

HellaSwag:

Client launch
python ./benchmark/hellaswag/bench_sglang.py --port 30088 --data-path "path/to/data" --num-questions 50000

  • baseline
HellaSwag baseline accuracy
  • with communications quantization
HellaSwag accuracy with communications quantization

Benchmarking and Profiling

Acceleration of prefill-intensive benchmarks for Qwen3-32B on an A2 server (acceleration = FP_time / INT8_time for latency metrics, or INT8_throughput / FP_throughput for throughput metrics)

Server args:
ASCEND_RT_VISIBLE_DEVICES=3,6 python -m sglang.launch_server --device npu --attention-backend ascend --trust-remote-code --tp-size 2 --model-path /home/ckpt/Qwen3-32B/ --port 30088 --cuda-graph-max-bs 64 --max-prefill-tokens 32768 --chunked-prefill-size -1 --mem-fraction-static 0.8 --disable-radix-cache --quantize-tp-communications

Performance results

Checklist

@github-actions github-actions Bot added the npu label Mar 13, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization by enabling INT8 quantization for tensor parallel communications during the prefill stage of Qwen3 models when running on NPU devices. This enhancement aims to reduce communication bandwidth and latency, leading to improved performance without compromising model accuracy, as validated by extensive benchmarking and accuracy tests.

Highlights

  • Tensor Parallel Communications Compression: Enabled INT8 tensor parallel communications compression during the prefill phase specifically for Qwen3 models running on NPU devices.
  • New Server Argument and Validation: Introduced a new server argument --quantize-tp-communications to activate this feature, complete with validation checks to ensure correct usage (e.g., requiring tp_size > 1, Qwen3 models, and NPU devices).
  • Conditional Quantization Integration: Integrated conditional communication quantization logic into linear layers, core communication operations, and the forward passes of Qwen2 and Qwen3 models.
  • Performance and Accuracy Validation: Provided extensive accuracy test results across BoolQ, C-Eval, and HellaSwag benchmarks, demonstrating comparable accuracy, and presented benchmarking results showing performance acceleration for prefill-intensive workloads.

Changelog
  • benchmark/boolq/bench_sglang.py
    • Updated an import path.
  • python/sglang/srt/distributed/communication_op.py
    • Modified the tensor_model_parallel_all_reduce function to accept a fp_comm flag, allowing for conditional floating-point communication.
  • python/sglang/srt/distributed/device_communicators/npu_communicator.py
    • Added NPU-specific dynamic quantization imports and implemented a quant_all_reduce method for INT8 all-reduce operations.
  • python/sglang/srt/distributed/parallel_state.py
    • Updated the all_reduce method within the ProcessGroup class to conditionally invoke the new quant_all_reduce method on NPU communicators based on the fp_comm flag.
  • python/sglang/srt/layers/communicator.py
    • Integrated logic to dynamically determine whether to use floating-point or quantized communications for tensor_model_parallel_all_reduce during residual add operations, based on the forward mode and server arguments.
  • python/sglang/srt/layers/linear.py
    • Imported server arguments and modified the forward method of linear layers to accept a forward_batch object, enabling conditional application of communication quantization based on the batch's forward mode and global server settings.
  • python/sglang/srt/models/qwen2.py
    • Modified the forward method of the Qwen2MLP class to pass the forward_batch object to the down_proj linear layer, propagating the necessary context for communication quantization.
  • python/sglang/srt/models/qwen3.py
    • Modified the forward method of the Qwen3DecoderLayer to pass the forward_batch object to the mlp layer, ensuring that communication quantization can be applied within the MLP.
  • python/sglang/srt/server_args.py
    • Introduced a new boolean server argument quantize_tp_communications and added comprehensive validation checks to ensure it is used correctly (e.g., with tp_size > 1, Qwen3 models, and NPU devices); a sketch follows this changelog.
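
A hedged sketch of what that validation might look like; the function name and the model-path check are assumptions, though tp_size, device, and model_path are the usual ServerArgs fields:

```python
# Hypothetical validation mirroring the constraints described above:
# tp_size > 1, an NPU device, and a Qwen3 checkpoint.
def check_quantize_tp_communications(args) -> None:
    if not args.quantize_tp_communications:
        return
    if args.tp_size <= 1:
        raise ValueError(
            "--quantize-tp-communications requires tp_size > 1; "
            "there is no TP communication to compress otherwise."
        )
    if args.device != "npu":
        raise ValueError(
            "--quantize-tp-communications is only implemented for NPU devices."
        )
    if "qwen3" not in args.model_path.lower():
        raise ValueError(
            "--quantize-tp-communications currently supports Qwen3 models only."
        )
```
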
Activity
  • The author has provided a detailed description of the motivation, modifications, and extensive accuracy and performance benchmarks.
  • Accuracy tests were conducted on BoolQ, C-Eval, and HellaSwag datasets, showing minimal impact from communications quantization.
  • Benchmarking results indicate performance acceleration for prefill-intensive tasks with the new quantization feature.
  • A checklist is included, though all items are currently unchecked, suggesting ongoing work or pending final verification.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces INT8 communication compression for Tensor Parallelism on NPU devices, specifically for Qwen3 models. The implementation adds a new server argument --quantize-tp-communications and a quantized all_reduce operation for NPUs. This feature is designed to be active only during the prefill phase to optimize performance. The changes are well-contained within the distributed communication and model-specific layers. I've identified a potential issue in python/sglang/srt/layers/linear.py where quantization might be unsafely applied when the forward context is unavailable, and I have provided a suggested fix.

Comment thread on python/sglang/srt/layers/linear.py (outdated)
@ping1jing2 ping1jing2 changed the title TP Communications compression For Qwen3 models for NPU [NPU]TP Communications compression For Qwen3 models for NPU Mar 13, 2026
@ping1jing2 ping1jing2 self-assigned this Mar 13, 2026
@github-actions github-actions Bot added the quant LLM Quantization label Mar 18, 2026
@egvenediktov egvenediktov marked this pull request as ready for review March 19, 2026 08:08
@ping1jing2
Collaborator

@egvenediktov please resolve the lint issue

@ping1jing2
Collaborator

/tag-and-rerun-ci

@OrangeRedeng
Contributor

@ping1jing2
Collaborator

/rerun-failed-ci

@ping1jing2
Collaborator

ping1jing2 commented May 2, 2026

I merged it since only one GPU CI job failed, due to an environment issue:
https://github.com/sgl-project/sglang/actions/runs/25151885349/job/73777463688?pr=20520

@ping1jing2 ping1jing2 dismissed ssshinigami’s stale review May 2, 2026 11:28

All issues had been resolved.

@sglang-npu-bot sglang-npu-bot merged commit 83bf5d6 into sgl-project:main May 2, 2026
190 of 214 checks passed
@ping1jing2
Collaborator

@egvenediktov could you please create another PR for the documentation in docs_new, which is the new directory created by the community

vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
@egvenediktov egvenediktov deleted the communications_compression branch May 6, 2026 08:40

Labels

documentation (Improvements or additions to documentation) · npu · quant (LLM Quantization) · run-ci
