
Support cuda graph in the triton attention backend #1401

Merged
merrymercy merged 4 commits into main from triton-cuda-graph
Sep 12, 2024
Conversation

merrymercy (Contributor) commented Sep 12, 2024

Llama 3 8B (1.3x faster)

# triton w/ cuda graph
# Decode.  median latency: 0.00706 s, median throughput:    141.63 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --attention-backend triton

# triton w/o cuda graph
# Decode.  median latency: 0.00928 s, median throughput:    107.79 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --attention-backend triton --disable-cuda-graph


# flashinfer w/ cuda graph
# Decode.  median latency: 0.00735 s, median throughput:    135.98 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --attention-backend flashinfer

# flashinfer w/o cuda graph
# Decode.  median latency: 0.00823 s, median throughput:    121.46 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --attention-backend flashinfer --disable-cuda-graph

DeepSeek-Coder-V2-Lite (4x faster)

# triton w/ cuda graph
# Decode.  median latency: 0.00622 s, median throughput:    160.82 token/s
python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --batch-size 1 --input 128 --output 8 --enable-mla

# triton w/o cuda graph
# Decode.  median latency: 0.02453 s, median throughput:     40.77 token/s
python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --batch-size 1 --input 128 --output 8 --enable-mla --disable-cuda-graph
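
For context, CUDA graph decoding follows a capture-once, replay-many pattern: allocate fixed-size buffers, record one decode step into a graph, then replay it for every token with fresh inputs copied into the same buffers. The sketch below is a generic PyTorch illustration of that pattern, not the actual SGLang/triton-backend code; `decode_step`, the buffer shapes, and `max_bs` are hypothetical stand-ins.

```python
import torch

# Toy stand-in for one decode step; the real backend would launch the triton
# attention kernels here. Shapes and names are illustrative only.
hidden, vocab, max_bs = 4096, 32000, 1
weight = torch.randn(hidden, vocab, device="cuda", dtype=torch.float16)

def decode_step(hidden_states, out_logits):
    out_logits.copy_(hidden_states @ weight)

# CUDA graphs replay fixed memory addresses, so inputs and outputs live in
# pre-allocated static buffers that are reused on every iteration.
static_hidden = torch.zeros(max_bs, hidden, device="cuda", dtype=torch.float16)
static_logits = torch.zeros(max_bs, vocab, device="cuda", dtype=torch.float16)

# Warm up eagerly once, then capture a single decode step into a graph.
decode_step(static_hidden, static_logits)
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    decode_step(static_hidden, static_logits)

# Per-token decode: copy new activations into the static buffer and replay
# the graph, skipping the per-kernel launch overhead.
static_hidden.copy_(torch.randn_like(static_hidden))
graph.replay()
logits = static_logits.clone()
```

Replaying one pre-recorded graph per decode step removes the Python and kernel-launch overhead that dominates batch-size-1 decoding, which lines up with the larger gain above on DeepSeek-Coder-V2-Lite (which launches many small kernels per step) compared to the dense Llama 3 8B.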

merrymercy merged commit 3efa798 into main on Sep 12, 2024
merrymercy deleted the triton-cuda-graph branch on September 12, 2024 at 07:36
zhyncs (Collaborator) commented Sep 12, 2024

Significant improvement, especially in small batch latency. Accuracy is similar to before.

ref #1285 (comment)

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --enable-mla --trust-remote-code --disable-radix

lm_eval --model local-completions --tasks gsm8k --model_args model=deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,base_url=http://127.0.0.1:30000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=False
# run 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7695|±  |0.0116|
|     |       |strict-match    |     5|exact_match|↑  |0.7559|±  |0.0118|

# run 2
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7801|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7688|±  |0.0116|

# run 3
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7741|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7672|±  |0.0116|

The impact on max throughput is not significant, because with CUDA graph enabled at TP 1, `--mem-fraction-static` has to be lowered to 0.85, otherwise it results in OOM.

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --enable-mla --trust-remote-code --disable-radix --mem-fraction-static 0.85
python3 -m sglang.bench_serving --backend sglang --num-prompts 5000

zhyncs (Collaborator) commented Sep 12, 2024

python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --batch-size 1 --input 128 --output 8 --attention-backend triton --trust-remote-code
python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --batch-size 1 --input 128 --output 8 --attention-backend triton --trust-remote-code --disable-cuda-graph
Decode.  median latency: 0.00793 s, median throughput:    126.09 token/s
Decode.  median latency: 0.03645 s, median throughput:     27.44 token/s

zhyncs (Collaborator) commented Sep 12, 2024

python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --batch-size 1 --input 128 --output 8 --attention-backend triton --trust-remote-code --enable-mla
python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --batch-size 1 --input 128 --output 8 --attention-backend triton --trust-remote-code --enable-mla --disable-cuda-graph
Decode.  median latency: 0.00621 s, median throughput:    161.09 token/s
Decode.  median latency: 0.01916 s, median throughput:     52.19 token/s

fengyang95 commented Sep 13, 2024

Hi @zhyncs @merrymercy, does this support sm_89 (L40)? I see that the CUDA graph path relies on vLLM's fused_moe, but from what I can tell, it does not support sm_89?

merrymercy (Contributor, Author) replied

@fengyang95 It should support L40, but I haven't tested it. I think CUDA graph does not depend on specific ops; it just captures the existing ops.
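
A rough sketch of that point (generic PyTorch, not SGLang code): stream capture records whatever kernels actually get launched inside the capture region, whether they come from triton, cuBLAS, or anything else, so the ops themselves need no CUDA-graph-specific support as long as they already run on the target GPU. The function `some_op` below is a hypothetical placeholder.

```python
import torch

def some_op(x):
    # Whatever kernels this launches (triton JIT kernels, cuBLAS matmuls, ...)
    # are recorded as launched; the graph does not care which backend they came from.
    return torch.relu(x @ x) + 1.0

x = torch.randn(256, 256, device="cuda")
some_op(x)                 # warm-up launch (also triggers any JIT compilation)
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = some_op(x)         # capture: kernels are recorded into the graph

x.copy_(torch.randn_like(x))
g.replay()                 # replay the recorded kernels on the new contents of x
print(y.sum().item())
```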
