Support double sparsity #1459
Conversation
Great work. Some tips for rebasing:
Quick question @andy-yang-1 - Does this PR support just Double Sparsity or DS-Offload as well?
@vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in another PR if needed.
Force-pushed from 9798dc2 to 57c998b.
Is there a plan to merge this PR?
Yes. It should be merged within one week.
Force-pushed from 6b07a3d to 5f71afa.
Please fix the lint error and add an end-to-end accuracy test.
Give two example commands and paste their results in the description of this PR. This is for tracking the progress. It should be something like this
@andy-yang-1 Can you also paste the latency results?
@andy-yang-1 Thanks for the contribution. It is merged.
How does one generate the ds-channel-config to be able to use this?
I noticed that CUDA graph is not currently supported. Are there any plans to support it? @andy-yang-1
@max99x You can use this link to generate the channel config file. @fengyang95 We may support it in the next PR.
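For anyone looking for the generator: the script referenced below lives at https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py. As a rough illustration of the idea only, a config can be derived by recording per-channel magnitude statistics over a short calibration run and keeping the top channels per layer. Everything in this sketch (the hook placement, the mean-|k| statistic, the JSON layout) is an assumption for illustration and may differ from the real script and from the format expected by --ds-channel-config-path:

```python
# Hypothetical sketch of channel-config generation, NOT the actual
# group_channel_config.py: record mean |k| per output channel of every
# k_proj on a calibration run, then keep the top-N channels per layer.
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "lmsys/longchat-7b-v1.5-32k"
heavy_channel_num = 16  # matches --ds-heavy-channel-num at serving time

tok = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16
).cuda()  # assumes a GPU is available
model.eval()

stats = {}  # module name -> running per-channel sum of mean |activation|

def make_hook(name):
    def hook(module, inputs, output):
        # output: [batch, seq, channels]; accumulate mean |k| per channel
        acc = output.detach().abs().float().mean(dim=(0, 1))
        stats[name] = stats.get(name, 0) + acc
    return hook

handles = [
    m.register_forward_hook(make_hook(n))
    for n, m in model.named_modules()
    if n.endswith("self_attn.k_proj")
]

calib = "The quick brown fox jumps over the lazy dog. " * 64
with torch.no_grad():
    model(**tok(calib, return_tensors="pt").to(model.device))
for h in handles:
    h.remove()

# Keep the indices of the largest-magnitude channels for each layer.
config = {
    name: torch.topk(acc, heavy_channel_num).indices.sort().values.tolist()
    for name, acc in stats.items()
}
with open("channel_config.json", "w") as f:
    json.dump(config, f, indent=2)
```

The resulting JSON would then be passed via --ds-channel-config-path, as in the benchmark commands in the description below. Note the real script may choose channels from q/k outlier statistics rather than plain magnitudes, and may use a different schema.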
Hi @andy-yang-1, does this support the deepseek-v2 architecture? How can I obtain the config for this structure? I see that the example here https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py only supports the llama/mixtral arch.
@andy-yang-1 I tried running the deepseek-v2 model, but encountered the following issue:

```
  File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/__init__.py", line 49, in forward
    return self.forward_extend(q, k, v, layer, forward_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
    k_label = torch.gather(
              ^^^^^^^^^^^^^
RuntimeError: Size does not match at dimension 1 expected index [7, 128, 16] to be smaller than self [7, 1, 576] apart from dimension 2
```
@fengyang95 I haven't added support for the deepseek-v2 model. I may add support for this later.
@andy-yang-1 Thank you very much! Looking forward to support for deepseek-v2 and CUDA graph.
@andy-yang-1 - Loved the paper! I was trying this out and I am facing a few issues generating the config file using the mentioned script.
I replaced it with
Any help on how to run this would be appreciated.
@shreyansh26 The first problem is caused by an older version of transformers, and I will update the base repo to fix it this week.
Thank you. But in the Llama-3.1-8B-Instruct config file,
@shreyansh26 Hi, I have updated the main repo. Can you try with this code?
Thank you @andy-yang-1!! This is working perfectly now.
I found that prefill throughput is lower when DS attention is enabled (from 6543.55 to 4189.16 token/s). A possible reason is that Triton is used as the attention backend. Is it possible to use FlashInfer attention for prefill to increase prefill throughput?
Hi @andy-yang-1, that is great work!

```
CUDA_VISIBLE_DEVICES=0 python -m sglang.bench_one_batch --model-path /home/lyy/model
/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
[2025-10-09 14:21:21 TP0] Load weight end. type=MistralForCausalLM, dtype=torch.bfloat16, avail mem=33.37 GB, mem usage=13.57 GB.
```
Thank you for your excellent work! I have 3 questions I'd like to ask:
(1) Performance Comparison
(2) Model Differences
(3) Main Advantage Scenarios of Double Sparsity
I would greatly appreciate it if you could spare some time to answer these questions. Thank you in advance!
Motivation
Modifications
sglang/python/sglang/srt/layers/sparse_decode_attention.py
Speedup Evaluation
Run double sparsity with:
```
python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
  --attention-backend triton --disable-cuda-graph \
  --ds-channel-config-path /path/to/lmsys/longchat-7b-v1.5-32k.json \
  --input-len 20000 --output-len 200 \
  --batch-size 3 \
  --enable-double-sparsity \
  --ds-heavy-channel-num 16 \
  --ds-heavy-token-num 1024 \
  --ds-sparse-decode-threshold 0 \
  --max-total-tokens 70000
```
```
Benchmark ...
Prefill. latency: 7.83636 s, throughput: 7656.62 token/s
Decode. latency: 0.02351 s, throughput: 127.58 token/s
Decode. latency: 0.02124 s, throughput: 141.22 token/s
Decode. latency: 0.02037 s, throughput: 147.26 token/s
Decode. latency: 0.01950 s, throughput: 153.81 token/s
Decode. latency: 0.01935 s, throughput: 155.04 token/s
Decode. median latency: 0.01923 s, median throughput: 156.04 token/s
Total. latency: 11.821 s, throughput: 5126.36 token/s
```
Original triton implementation:
```
python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
  --attention-backend triton \
  --input-len 20000 --output-len 200 \
  --batch-size 3
```
```
Benchmark ...
Prefill. latency: 7.79627 s, throughput: 7695.98 token/s
Decode. latency: 0.07196 s, throughput: 41.69 token/s
Decode. latency: 0.06514 s, throughput: 46.05 token/s
Decode. latency: 0.06475 s, throughput: 46.33 token/s
Decode. latency: 0.06463 s, throughput: 46.41 token/s
Decode. latency: 0.06457 s, throughput: 46.46 token/s
Decode. median latency: 0.06487 s, median throughput: 46.25 token/s
Total. latency: 20.720 s, throughput: 2924.74 token/s
```
Original flashinfer implementation:
```
python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
  --attention-backend flashinfer \
  --input-len 20000 --output-len 200 \
  --batch-size 3
```
```
Benchmark ...
Prefill. latency: 5.68892 s, throughput: 10546.83 token/s
Decode. latency: 0.03240 s, throughput: 92.60 token/s
Decode. latency: 0.02993 s, throughput: 100.23 token/s
Decode. latency: 0.02970 s, throughput: 101.01 token/s
Decode. latency: 0.02959 s, throughput: 101.39 token/s
Decode. latency: 0.02959 s, throughput: 101.38 token/s
Decode. median latency: 0.02961 s, median throughput: 101.32 token/s
Total. latency: 11.585 s, throughput: 5231.00 token/s
```
With Llama-3.1-8B:
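As background for the flags in the commands above: conceptually, double-sparsity decode keeps a small "label" view of --ds-heavy-channel-num channels per head, uses it to approximate attention scores, and runs exact attention over only the --ds-heavy-token-num best-scoring tokens (--ds-sparse-decode-threshold presumably gates when this sparse path activates). The sketch below is an illustrative PyTorch rendering of that selection under simplified shapes; it is not the actual Triton kernel in sparse_decode_attention.py, which differs in layout and detail.

```python
# Illustrative double-sparsity decode step (simplified; not the Triton kernel).
import torch

def double_sparsity_decode(q, k_cache, v_cache, heavy_channels, heavy_token_num):
    """q: [heads, head_dim]; k_cache/v_cache: [seq, heads, head_dim];
    heavy_channels: [heads, heavy_channel_num] int64 channel indices."""
    seq, heads, head_dim = k_cache.shape

    # Label phase: approximate scores from the heavy channels only.
    q_label = torch.gather(q, 1, heavy_channels)                  # [heads, hc]
    k_label = torch.gather(k_cache, 2,
                           heavy_channels.expand(seq, -1, -1))    # [seq, heads, hc]
    approx = torch.einsum("hc,shc->hs", q_label, k_label)         # [heads, seq]

    # Select the heavy tokens per head, then do exact attention on them.
    top = approx.topk(min(heavy_token_num, seq), dim=1).indices   # [heads, t]
    idx = top.transpose(0, 1).unsqueeze(-1).expand(-1, -1, head_dim)
    k_sel = torch.gather(k_cache, 0, idx)                         # [t, heads, d]
    v_sel = torch.gather(v_cache, 0, idx)
    scores = torch.einsum("hd,thd->ht", q, k_sel) / head_dim ** 0.5
    return torch.einsum("ht,thd->hd", scores.softmax(dim=-1), v_sel)
```

With --ds-heavy-token-num 1024, the exact-attention cost per decode step depends on the 1024 selected tokens rather than the full 20k-token context, which is consistent with the decode throughput gap above (median ~156 vs ~46 token/s on the Triton backend).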
Checklist