
[Refactor] Reducing code duplication across FP8 CUDA quantization kernels #4163

Merged: zhyncs merged 2 commits into sgl-project:main from hebiao064:reuse_utils_for_per_token_quant on Mar 7, 2025
Conversation

@hebiao064 (Collaborator) commented on Mar 7, 2025

Motivation

Refactor quantization kernels to improve code maintainability by:

  • Centralizing shared CUDA utility functions and type definitions in utils.h
  • Reducing code duplication across quantization kernels

Part of #2965

Modifications

  • Moved shared code to sglang/sgl-kernel/src/sgl-kernel/include/utils.h:
    • FP8 type definitions and constants
    • CUDA warp utility functions (warpReduceMax, atomicMaxFloat); an illustrative sketch follows this list
    • Common preprocessor definitions
  • Updated quantization kernels to use shared code:
    • per_token_quant_fp8.cu
    • per_tensor_quant_fp8.cu
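
For context, here is a minimal sketch of the kind of warp-reduction and atomic-max helpers being centralized. This is not the actual contents of the PR's utils.h; the FP8_E4M3_MAX constant and the exact signatures are assumptions for illustration only.

// Illustrative sketch of shared CUDA quantization helpers; names and the
// FP8_E4M3_MAX value are assumptions, not copied from the repository.
#pragma once

#include <cuda_runtime.h>

// Largest finite value representable in FP8 E4M3, used to clamp scaled inputs.
#define FP8_E4M3_MAX 448.0f

// Reduce a float to the maximum across all 32 lanes of a warp via shuffles.
__device__ __forceinline__ float warpReduceMax(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
  }
  return val;
}

// Atomic max on a float, using the ordering of IEEE-754 bit patterns:
// non-negative floats compare like signed ints, while negative floats
// compare in reverse order as unsigned ints.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  return (value >= 0.0f)
             ? __int_as_float(atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value)))
             : __uint_as_float(atomicMin(reinterpret_cast<unsigned int*>(addr), __float_as_uint(value)));
}

A quantization kernel would typically use these to compute a per-token or per-tensor absolute maximum, derive the scale as max_abs / FP8_E4M3_MAX, and store the scaled values as FP8.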

Benchmark

python /home/jobuser/sglang/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
INFO 03-07 06:43:21 __init__.py:190] Automatically detected platform cuda.
✅ All implementations match
per-tensor-quant-fp8-performance:
    batch_size  seq_len         VLLM   SGL Kernel
0         16.0     64.0    37.760001    27.264001
1         16.0    128.0    61.280001    40.352002
2         16.0    256.0   119.967997    73.664002
3         16.0    512.0   232.832000   147.136003
4         16.0   1024.0   443.392009   274.271995
5         16.0   2048.0   861.968040   524.160028
6         32.0     64.0    61.087999    40.415999
7         32.0    128.0   119.887993    73.504001
8         32.0    256.0   232.960001   148.031995
9         32.0    512.0   443.744004   274.223983
10        32.0   1024.0   862.128019   524.208009
11        32.0   2048.0  1703.008056  1024.335980
12        64.0     64.0   120.127998    73.472001
13        64.0    128.0   232.991993   148.064002
14        64.0    256.0   443.264008   274.048001
15        64.0    512.0   862.240016   524.576008
16        64.0   1024.0  1706.112027  1023.136020
17        64.0   2048.0  3378.895998  2011.231899
18       128.0     64.0   232.927993   148.016006
19       128.0    128.0   443.455994   273.503989
20       128.0    256.0   861.599982   523.504019
21       128.0    512.0  1706.560016  1023.551941
22       128.0   1024.0  3375.808001  2009.759903
23       128.0   2048.0  6724.080086  3977.504015
python /home/jobuser/sglang/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py
✅ All implementations match
per-token-group-quant-fp8-performance:
    batch_size  seq_len  group_size       Triton   SGL Kernel
0          1.0     64.0       128.0    10.848000     9.568000
1          1.0    128.0       128.0    14.848000    12.224000
2          1.0    256.0       128.0    20.160001    14.944000
3          1.0    512.0       128.0    31.776000    21.248000
4          1.0   1024.0       128.0    55.039998    33.920001
5          1.0   2048.0       128.0   101.792000    62.399998
6          2.0     64.0       128.0    14.688000    12.320000
7          2.0    128.0       128.0    20.191999    14.944000
8          2.0    256.0       128.0    31.711999    21.183999
9          2.0    512.0       128.0    55.103999    33.920001
10         2.0   1024.0       128.0   101.712003    62.336002
11         2.0   2048.0       128.0   193.120003   112.864003
12         4.0     64.0       128.0    20.191999    14.912000
13         4.0    128.0       128.0    31.711999    21.183999
14         4.0    256.0       128.0    55.071998    33.920001
15         4.0    512.0       128.0   101.696000    62.431999
16         4.0   1024.0       128.0   193.151996   112.816006
17         4.0   2048.0       128.0   375.200003   212.927997
18         8.0     64.0       128.0    31.679999    21.215999
19         8.0    128.0       128.0    55.103999    33.952001
20         8.0    256.0       128.0   101.696000    62.511995
21         8.0    512.0       128.0   193.248004   112.896003
22         8.0   1024.0       128.0   375.135988   213.088006
23         8.0   2048.0       128.0   738.879979   412.943989
24        16.0     64.0       128.0    55.071998    33.952001
25        16.0    128.0       128.0   101.792000    62.560000
26        16.0    256.0       128.0   193.087995   112.768002
27        16.0    512.0       128.0   375.167996   213.024005
28        16.0   1024.0       128.0   739.232004   413.056016
29        16.0   2048.0       128.0  1467.216015   813.471973
30        32.0     64.0       128.0   101.952001    62.560000
31        32.0    128.0       128.0   193.087995   112.832002
32        32.0    256.0       128.0   375.167996   212.896004
33        32.0    512.0       128.0   739.199996   412.959993
34        32.0   1024.0       128.0  1470.479965   813.567996
35        32.0   2048.0       128.0  2922.784090  1612.895966
36        64.0     64.0       128.0   193.151996   112.896003
37        64.0    128.0       128.0   375.200003   213.152006
38        64.0    256.0       128.0   739.296019   412.831992
39        64.0    512.0       128.0  1471.168041   813.279986
40        64.0   1024.0       128.0  2922.271967  1612.576008
41        64.0   2048.0       128.0  5828.479767  3213.583946
python /home/jobuser/sglang/sgl-kernel/benchmark/bench_per_token_quant_fp8.py
INFO 03-07 06:46:42 __init__.py:190] Automatically detected platform cuda.
✅ All implementations match
per-token-dynamic-quant-fp8-performance:
    batch_size  seq_len         VLLM   SGL Kernel
0         16.0     64.0    26.144000    30.239999
1         16.0    128.0    44.415999    42.080000
2         16.0    256.0    83.839998    85.055999
3         16.0    512.0   154.944003   157.856002
4         16.0   1024.0   297.152013   290.592015
5         16.0   2048.0   581.471980   551.519990
6         16.0   4096.0  1153.807998  1075.600028
7         32.0     64.0    44.192001    42.080000
8         32.0    128.0    84.608003    83.967999
9         32.0    256.0   155.808002   156.992003
10        32.0    512.0   297.280014   291.424006
11        32.0   1024.0   581.152022   551.072001
12        32.0   2048.0  1152.511954  1074.719906
13        32.0   4096.0  2289.920092  2109.503984
14        64.0     64.0    84.608003    84.096000
15        64.0    128.0   155.727997   158.720002
16        64.0    256.0   297.087997   290.304005
17        64.0    512.0   581.344008   551.168025
18        64.0   1024.0  1153.632045  1074.320078
19        64.0   2048.0  2290.719986  2109.951973
20        64.0   4096.0  4562.816143  4179.647923
21       128.0     64.0   155.616000   157.184005
22       128.0    128.0   297.055990   291.359991
23       128.0    256.0   580.864012   551.072001
24       128.0    512.0  1152.511954  1074.847937
25       128.0   1024.0  2290.240049  2109.247923
26       128.0   2048.0  4562.655926  4180.416107
27       128.0   4096.0  9109.968185  8315.551758

Checklist

@hebiao064 changed the title from "[Refactor] Reuse shared code in utils for quant kernels" to "[Refactor] Reuse shared code in utils for fp8 quant kernels" on Mar 7, 2025
@hebiao064 marked this pull request as ready for review on March 7, 2025 at 06:53
@hebiao064 changed the title from "[Refactor] Reuse shared code in utils for fp8 quant kernels" to "[Refactor] Reducing code duplication across FP8 CUDA quantization kernels" on Mar 7, 2025
@zhyncs merged commit 95085d6 into sgl-project:main on Mar 7, 2025
@BBuf mentioned this pull request on Mar 7, 2025