Skip to content

Support double sparsity#1459

Merged
merrymercy merged 6 commits intosgl-project:mainfrom
andy-yang-1:double-sparsity
Oct 14, 2024
Merged

Support double sparsity#1459
merrymercy merged 6 commits intosgl-project:mainfrom
andy-yang-1:double-sparsity

Conversation

@andy-yang-1
Copy link
Copy Markdown
Contributor

@andy-yang-1 andy-yang-1 commented Sep 18, 2024

Motivation

  • Support double sparsity (post-training sparse attention) for long context inference in SGLang
  • See paper

Modifications

  • Add triton implementation in sglang/python/sglang/srt/layers/sparse_decode_attention.py
  • Add serving-related parts

Speedup Evaluation

Run double sparsity with:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton --disable-cuda-graph \
    --ds-channel-config-path /path/to/lmsys/longchat-7b-v1.5-32k.json \
    --input-len 20000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 16 \
    --ds-heavy-token-num 1024 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 70000

Benchmark ...
Prefill. latency: 7.83636 s, throughput:   7656.62 token/s
Decode.  latency: 0.02351 s, throughput:    127.58 token/s
Decode.  latency: 0.02124 s, throughput:    141.22 token/s
Decode.  latency: 0.02037 s, throughput:    147.26 token/s
Decode.  latency: 0.01950 s, throughput:    153.81 token/s
Decode.  latency: 0.01935 s, throughput:    155.04 token/s
Decode.  median latency: 0.01923 s, median throughput:    156.04 token/s
Total. latency: 11.821 s, throughput:   5126.36 token/s

Original triton implementation:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 7.79627 s, throughput:   7695.98 token/s
Decode.  latency: 0.07196 s, throughput:     41.69 token/s
Decode.  latency: 0.06514 s, throughput:     46.05 token/s
Decode.  latency: 0.06475 s, throughput:     46.33 token/s
Decode.  latency: 0.06463 s, throughput:     46.41 token/s
Decode.  latency: 0.06457 s, throughput:     46.46 token/s
Decode.  median latency: 0.06487 s, median throughput:     46.25 token/s
Total. latency: 20.720 s, throughput:   2924.74 token/s

Original flashinfer implementation:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend flashinfer \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 5.68892 s, throughput:  10546.83 token/s
Decode.  latency: 0.03240 s, throughput:     92.60 token/s
Decode.  latency: 0.02993 s, throughput:    100.23 token/s
Decode.  latency: 0.02970 s, throughput:    101.01 token/s
Decode.  latency: 0.02959 s, throughput:    101.39 token/s
Decode.  latency: 0.02959 s, throughput:    101.38 token/s
Decode.  median latency: 0.02961 s, median throughput:    101.32 token/s
Total. latency: 11.585 s, throughput:   5231.00 token/s

With Llama-3.1-8B:

# Double Sparsity
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --ds-channel-config-path /path/to/meta-llama/Llama-3.1-8B-Instruct.json \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 32 \
    --ds-heavy-channel-type k \
    --ds-heavy-token-num 3000 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 42.96801 s, throughput:   4189.16 token/s
Decode.  latency: 0.02843 s, throughput:    105.50 token/s
Decode.  latency: 0.02518 s, throughput:    119.16 token/s
Decode.  latency: 0.02465 s, throughput:    121.72 token/s
Decode.  latency: 0.02442 s, throughput:    122.84 token/s
Decode.  latency: 0.02434 s, throughput:    123.24 token/s
Decode.  median latency: 0.02421 s, median throughput:    123.90 token/s
Total. latency: 47.793 s, throughput:   3778.77 token/s

# Triton
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 43.17160 s, throughput:   4169.41 token/s
Decode.  latency: 0.06359 s, throughput:     47.18 token/s
Decode.  latency: 0.05965 s, throughput:     50.30 token/s
Decode.  latency: 0.05927 s, throughput:     50.62 token/s
Decode.  latency: 0.05906 s, throughput:     50.80 token/s
Decode.  latency: 0.05906 s, throughput:     50.80 token/s
Decode.  median latency: 0.05913 s, median throughput:     50.73 token/s
Total. latency: 54.950 s, throughput:   3286.63 token/s

# Flashinfer
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend flashinfer \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 27.50800 s, throughput:   6543.55 token/s
Decode.  latency: 0.03014 s, throughput:     99.54 token/s
Decode.  latency: 0.02834 s, throughput:    105.86 token/s
Decode.  latency: 0.02821 s, throughput:    106.36 token/s
Decode.  latency: 0.02819 s, throughput:    106.41 token/s
Decode.  latency: 0.02823 s, throughput:    106.28 token/s
Decode.  median latency: 0.02821 s, median throughput:    106.34 token/s
Total. latency: 33.125 s, throughput:   5452.12 token/s

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@merrymercy
Copy link
Copy Markdown
Contributor

merrymercy commented Sep 19, 2024

Great work. Some tips for rebasing:

Comment thread python/sglang/srt/layers/radix_attention.py Outdated
Comment thread python/sglang/srt/layers/test_ds_kernel.py Outdated
Comment thread python/sglang/srt/mem_cache/memory_pool.py Outdated
Comment thread python/sglang/srt/model_executor/model_runner.py Outdated
@Ying1123 Ying1123 mentioned this pull request Sep 22, 2024
37 tasks
@merrymercy merrymercy mentioned this pull request Sep 22, 2024
2 tasks
@ghost
Copy link
Copy Markdown

ghost commented Sep 24, 2024

Quick question @andy-yang-1 - Does this PR support just Double Sparsity or DS-Offload as well?

@andy-yang-1
Copy link
Copy Markdown
Contributor Author

@vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in other PR if needed.

@fengyang95
Copy link
Copy Markdown

Is there a plan to merge this PR?

@merrymercy
Copy link
Copy Markdown
Contributor

merrymercy commented Oct 11, 2024

Yes. It should be merged within one week.
@andy-yang-1 please

  1. Resolve the conflicts.
  2. Add an end-to-end accuracy unit test

@merrymercy
Copy link
Copy Markdown
Contributor

Please fix the lint error and add an end-to-end accuracy test

Comment thread python/sglang/srt/model_executor/forward_batch_info.py Outdated
Comment thread python/sglang/test/Llama-3.1-8B-Instruct.jsonconfig
Comment thread test/srt/test_double_sparsity.py
@merrymercy merrymercy changed the title [WIP] Support double sparsity Support double sparsity Oct 14, 2024
Comment thread test/srt/test_double_sparsity.py Outdated
Comment thread test/srt/test_double_sparsity.py
@merrymercy
Copy link
Copy Markdown
Contributor

merrymercy commented Oct 14, 2024

Give two example commands and past their results in the description of this PR. This is for tracking the progress. It should be something like this

# baseline
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 8

# double sparsity
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 8 --enable-double-sparsity ...

@merrymercy
Copy link
Copy Markdown
Contributor

@andy-yang-1 Can you also paste the latency results?

Comment thread test/srt/test_double_sparsity.py Outdated
Comment thread test/srt/run_suite.py Outdated
@merrymercy merrymercy enabled auto-merge (squash) October 14, 2024 08:32
@merrymercy merrymercy disabled auto-merge October 14, 2024 09:00
@merrymercy merrymercy merged commit 061e546 into sgl-project:main Oct 14, 2024
@merrymercy
Copy link
Copy Markdown
Contributor

@andy-yang-1 Thanks for the contribution. It is merged.

@max99x
Copy link
Copy Markdown
Contributor

max99x commented Oct 14, 2024

How does one generate the ds-channel-config to be able to use this?

@fengyang95
Copy link
Copy Markdown

I noticed that CUDA graph is not currently supported. Are there any plans to support it? @andy-yang-1

@andy-yang-1
Copy link
Copy Markdown
Contributor Author

@max99x You can use this link to generate channel config file.

@fengyang95 We may support it in the next PR

@fengyang95
Copy link
Copy Markdown

fengyang95 commented Oct 18, 2024

hi @andy-yang-1 Does this support the deepseek-v2 architecture? How can I obtain the config for this structure? I see that the example here https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py only support llama/mixtral arch.

@fengyang95
Copy link
Copy Markdown

fengyang95 commented Oct 19, 2024

@andy-yang-1 I tried running the deepseek-v2 model, but encountered the following issue:

File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
    k_label = torch.gather(
              ^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument index in method wrapper_CUDA_gather)
  File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/__init__.py", line 49, in forward
    return self.forward_extend(q, k, v, layer, forward_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
    k_label = torch.gather(
              ^^^^^^^^^^^^^
RuntimeError: Size does not match at dimension 1 expected index [7, 128, 16] to be smaller than self [7, 1, 576] apart from dimension 2

@andy-yang-1
Copy link
Copy Markdown
Contributor Author

@fengyang95 I haven't added support for deepseek-v2 model. I may add support for this later

@fengyang95
Copy link
Copy Markdown

@fengyang95 I haven't added support for deepseek-v2 model. I may add support for this later

@andy-yang-1 Thank you very much! Looking forward to support for deepseek-v2 and cuda graph.

@shreyansh26
Copy link
Copy Markdown

@andy-yang-1 - Loved the paper! I was trying this out and I am facing a few issues generating the config file using the mentioned script.

  1. The line cos, sin = m.rotary_emb(v, seq_len=kv_seq_len) in stat_qk_max_hook of get_calib_qk_feat gives an error
TypeError: LlamaRotaryEmbedding got an unexpected keyword argument 'seq_len'

I replaced it with cos, sin = m.rotary_emb(v, position_ids=position_ids) which works. I'm not sure if that is correct but LlamaRotaryEmbedding indeed doesn't have the seq_len param

  1. In the config file that gets generated, I only get keys of the form model.layers.{layer_num}.self_attn but the config file present in the test folder has keys in the form of model.layers.{layer_num}.self_attn.q_proj, model.layers.{layer_num}.self_attn.k_proj and model.layers.{layer_num}.self_attn.qk_proj. How were these generated?
    On using my generated config with sglang, I am getting error of the type - Key model.layers.0.self_attn.k_proj was not found.

Any help on how to run this would be appreciated.

@andy-yang-1
Copy link
Copy Markdown
Contributor Author

@shreyansh26 The first problem is caused by older version of transformers, and I will update the base repo to fix it this week.
The q_outlier_config/k_outlier_config is generated with get_calib_feat function, and the qk_outlier_config is generated with get_qk_calib_feat function. You can merge this two config together to get all configs. I will also update it this week.

@shreyansh26
Copy link
Copy Markdown

shreyansh26 commented Nov 7, 2024

Thank you.
There may be another discrepancy, in get_calib_feat, with the following condition, k_proj gets filtered out because of GQA.

if y.shape[-1] != model.config.hidden_size:
    return

But in the Llama-3.1-8B-Instruct config file, k_proj keys are also present.

@andy-yang-1
Copy link
Copy Markdown
Contributor Author

@shreyansh26 Hi, I have updated the main repo. Can you try with this code?

@shreyansh26
Copy link
Copy Markdown

Thank you @andy-yang-1!! This is working perfectly now.

@yuguo-Jack
Copy link
Copy Markdown

@vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in other PR if needed.
Is there a plan to support DS-Offload in Sglang?

@hcyz33
Copy link
Copy Markdown
Contributor

hcyz33 commented Jan 13, 2025

Motivation

  • Support double sparsity (post-training sparse attention) for long context inference in SGLang
  • See paper

Modifications

  • Add triton implementation in sglang/python/sglang/srt/layers/sparse_decode_attention.py
  • Add serving-related parts

Speedup Evaluation

Run double sparsity with:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton --disable-cuda-graph \
    --ds-channel-config-path /path/to/lmsys/longchat-7b-v1.5-32k.json \
    --input-len 20000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 16 \
    --ds-heavy-token-num 1024 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 70000

Benchmark ...
Prefill. latency: 7.83636 s, throughput:   7656.62 token/s
Decode.  latency: 0.02351 s, throughput:    127.58 token/s
Decode.  latency: 0.02124 s, throughput:    141.22 token/s
Decode.  latency: 0.02037 s, throughput:    147.26 token/s
Decode.  latency: 0.01950 s, throughput:    153.81 token/s
Decode.  latency: 0.01935 s, throughput:    155.04 token/s
Decode.  median latency: 0.01923 s, median throughput:    156.04 token/s
Total. latency: 11.821 s, throughput:   5126.36 token/s

Original triton implementation:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 7.79627 s, throughput:   7695.98 token/s
Decode.  latency: 0.07196 s, throughput:     41.69 token/s
Decode.  latency: 0.06514 s, throughput:     46.05 token/s
Decode.  latency: 0.06475 s, throughput:     46.33 token/s
Decode.  latency: 0.06463 s, throughput:     46.41 token/s
Decode.  latency: 0.06457 s, throughput:     46.46 token/s
Decode.  median latency: 0.06487 s, median throughput:     46.25 token/s
Total. latency: 20.720 s, throughput:   2924.74 token/s

Original flashinfer implementation:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend flashinfer \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 5.68892 s, throughput:  10546.83 token/s
Decode.  latency: 0.03240 s, throughput:     92.60 token/s
Decode.  latency: 0.02993 s, throughput:    100.23 token/s
Decode.  latency: 0.02970 s, throughput:    101.01 token/s
Decode.  latency: 0.02959 s, throughput:    101.39 token/s
Decode.  latency: 0.02959 s, throughput:    101.38 token/s
Decode.  median latency: 0.02961 s, median throughput:    101.32 token/s
Total. latency: 11.585 s, throughput:   5231.00 token/s

With Llama-3.1-8B:

# Double Sparsity
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --ds-channel-config-path /path/to/meta-llama/Llama-3.1-8B-Instruct.json \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 32 \
    --ds-heavy-channel-type k \
    --ds-heavy-token-num 3000 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 42.96801 s, throughput:   4189.16 token/s
Decode.  latency: 0.02843 s, throughput:    105.50 token/s
Decode.  latency: 0.02518 s, throughput:    119.16 token/s
Decode.  latency: 0.02465 s, throughput:    121.72 token/s
Decode.  latency: 0.02442 s, throughput:    122.84 token/s
Decode.  latency: 0.02434 s, throughput:    123.24 token/s
Decode.  median latency: 0.02421 s, median throughput:    123.90 token/s
Total. latency: 47.793 s, throughput:   3778.77 token/s

# Triton
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 43.17160 s, throughput:   4169.41 token/s
Decode.  latency: 0.06359 s, throughput:     47.18 token/s
Decode.  latency: 0.05965 s, throughput:     50.30 token/s
Decode.  latency: 0.05927 s, throughput:     50.62 token/s
Decode.  latency: 0.05906 s, throughput:     50.80 token/s
Decode.  latency: 0.05906 s, throughput:     50.80 token/s
Decode.  median latency: 0.05913 s, median throughput:     50.73 token/s
Total. latency: 54.950 s, throughput:   3286.63 token/s

# Flashinfer
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend flashinfer \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 27.50800 s, throughput:   6543.55 token/s
Decode.  latency: 0.03014 s, throughput:     99.54 token/s
Decode.  latency: 0.02834 s, throughput:    105.86 token/s
Decode.  latency: 0.02821 s, throughput:    106.36 token/s
Decode.  latency: 0.02819 s, throughput:    106.41 token/s
Decode.  latency: 0.02823 s, throughput:    106.28 token/s
Decode.  median latency: 0.02821 s, median throughput:    106.34 token/s
Total. latency: 33.125 s, throughput:   5452.12 token/s

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

I found that the throughput of prefill is lower when enable DS attention(from 6543.55 to 4189.16 ). The possible reason is that you use triton as attention-backend. Is it possible to use flashinfer attention in prefill to increase the throughput of prefill.

@alex1720-web
Copy link
Copy Markdown

Hi, @andy-yang-1 , that is a great work!
I have encountered a problem when using double sparsity in the latest version of SGLang. I followed the same command as you, but it turned out to fail. The error log is in the following. Could you help me to fix it ?

CUDA_VISIBLE_DEVICES=0 python -m sglang.bench_one_batch --model-path /home/lyy/model
s/Mistral-7B-v0.1 --tensor-parallel-size 1 --attention-backend triton --disable-cuda-graph --ds-channel-config-path /home/zongyi/DoubleSparse/config/mistralai/Mistral-7B-v0.1.json --input-len 20000 --output-len 200 --batch-size 1 --enable-double-sparsity --ds-heavy-channel-num 16 --ds-heavy-token-num 1024 --ds-sparse-decode-threshold 0 --max-total-tokens 70000

/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/cuda/init.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
torch_dtype is deprecated! Use dtype instead!
[2025-10-09 14:21:16 TP0] Double sparsity optimization is turned on. Use triton backend without CUDA graph.
[2025-10-09 14:21:16 TP0] Init torch distributed begin.
[rank0]:[W1009 14:21:17.340194889 ProcessGroupGloo.cpp:514] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-10-09 14:21:17 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-10-09 14:21:18 TP0] Ignore import error when loading sglang.srt.models.kimi_vl: cannot import name 'GELUTanh' from 'transformers.activations' (/home/zongyi/transformers/activations.py)
[2025-10-09 14:21:18 TP0] Ignore import error when loading sglang.srt.models.kimi_vl_moonvit: cannot import name 'GELUTanh' from 'transformers.activations' (/home/zongyi/transformers/activations.py)
[2025-10-09 14:21:18 TP0] Load weight begin. avail mem=46.94 GB
Loading safetensors checkpoint shards: 0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 33% Completed | 1/3 [00:00<00:01, 1.30it/s]
Loading safetensors checkpoint shards: 67% Completed | 2/3 [00:01<00:00, 1.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00, 1.19it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00, 1.21it/s]

[2025-10-09 14:21:21 TP0] Load weight end. type=MistralForCausalLM, dtype=torch.bfloat16, avail mem=33.37 GB, mem usage=13.57 GB.
[2025-10-09 14:21:21 TP0] Using KV cache dtype: torch.bfloat16
[2025-10-09 14:21:21 TP0] Memory pool end. avail mem=23.76 GB
max_total_num_tokens=70000
Warmup ...
[rank0]: Traceback (most recent call last):
[rank0]: File "", line 198, in _run_module_as_main
[rank0]: File "", line 88, in _run_code
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 660, in
[rank0]: main(server_args, bench_args)
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 624, in main
[rank0]: work_func(server_args, port_args, bench_args, 0)
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 531, in latency_test
[rank0]: latency_test_run_once(
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 433, in latency_test_run_once
[rank0]: next_token_ids, _, batch = extend(reqs, model_runner)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 275, in extend
[rank0]: logits_output, _ = model_runner.forward(forward_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 1982, in forward
[rank0]: output = self._forward_raw(
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 2033, in _forward_raw
[rank0]: ret = self.forward_extend(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 1927, in forward_extend
[rank0]: return self.model.forward(
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 469, in forward
[rank0]: hidden_states = self.model(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 342, in forward
[rank0]: hidden_states, residual = layer(
[rank0]: ^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 266, in forward
[rank0]: hidden_states = self.self_attn(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 197, in forward
[rank0]: attn_output = self.attn(q, k, v, forward_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/layers/radix_attention.py", line 108, in forward
[rank0]: return forward_batch.attn_backend.forward(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/layers/attention/base_attn_backend.py", line 82, in forward
[rank0]: return self.forward_extend(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/layers/attention/double_sparsity_backend.py", line 128, in forward_extend
[rank0]: k_label = torch.gather(
[rank0]: ^^^^^^^^^^^^^
[rank0]: RuntimeError: Size does not match at dimension 1 expected index [20000, 32, 16] to be no larger than self [20000, 8, 128] apart from dimension 2
[rank0]:[W1009 14:21:23.273092827 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

@IgniteGo
Copy link
Copy Markdown

Thank you for your excellent work! I have 3 questions I'd like to ask:

(1) Performance Comparison
Based on the test results you provided in the PR and my own local tests, Double Sparsity shows a slight improvement over SGLang + Triton, but compared to SGLang + FlashInfer (the default backend), there is no significant degradation (and sometimes it's roughly on par).
Is this mainly due to inherent efficiency limitations of the Triton sparse operator itself? Or are factors such as kernel optimization, lack of CUDA Graph support, etc., having a bigger impact?

(2) Model Differences
In the PR benchmarks, the relative improvement of Double Sparsity on LongChat-7B is clearly larger than on Llama-3.1-8B. What is the main reason for this difference?

(3) Main Advantage Scenarios of Double Sparsity
What are the primary scenarios where Double Sparsity shows the strongest advantages?
How does its speedup correlate with hardware (compute vs. memory bandwidth on H100/A100), input/output lengths, batch size, and model architecture (e.g., GQA, context length)?

I would greatly appreciate it if you could spare some time to answer these questions. Thank you in advance!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants