[Refactor] Reducing code duplication across FP8 CUDA quantization kernels #4163
Merged: zhyncs merged 2 commits into sgl-project:main on Mar 7, 2025
zhyncs approved these changes on Mar 7, 2025
Motivation
Refactor the FP8 quantization kernels to reduce code duplication and improve maintainability.
Part of #2965
Modifications
sglang/sgl-kernel/src/sgl-kernel/include/utils.h
per_token_quant_fp8.cu
per_tensor_quant_fp8.cu
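As a rough illustration of the kind of logic the two kernels could share through utils.h, here is a minimal sketch of a device-side clamp-and-cast helper. The name scaled_fp8_cast, its signature, and its placement are hypothetical and not taken from this PR.

#include <cuda_fp8.h>

// Hypothetical shared helper: scale a value, clamp it into the FP8 E4M3
// range, and convert, so the per-token and per-tensor kernels do not
// each carry their own copy of this logic.
__device__ __forceinline__ __nv_fp8_e4m3 scaled_fp8_cast(float val, float inv_scale) {
  constexpr float kFp8E4M3Max = 448.0f;  // largest magnitude representable in E4M3
  float x = fminf(fmaxf(val * inv_scale, -kFp8E4M3Max), kFp8E4M3Max);
  return static_cast<__nv_fp8_e4m3>(x);
}

A per-tensor kernel would pass a single inv_scale for the whole input, while a per-token kernel would compute one scale per row before calling the same helper.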
Benchmark

python /home/jobuser/sglang/sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
INFO 03-07 06:43:21 __init__.py:190] Automatically detected platform cuda.
✅ All implementations match
per-tensor-quant-fp8-performance:
    batch_size  seq_len  VLLM  SGL Kernel
0   16.0    64.0    37.760001    27.264001
1   16.0   128.0    61.280001    40.352002
2   16.0   256.0   119.967997    73.664002
3   16.0   512.0   232.832000   147.136003
4   16.0  1024.0   443.392009   274.271995
5   16.0  2048.0   861.968040   524.160028
6   32.0    64.0    61.087999    40.415999
7   32.0   128.0   119.887993    73.504001
8   32.0   256.0   232.960001   148.031995
9   32.0   512.0   443.744004   274.223983
10  32.0  1024.0   862.128019   524.208009
11  32.0  2048.0  1703.008056  1024.335980
12  64.0    64.0   120.127998    73.472001
13  64.0   128.0   232.991993   148.064002
14  64.0   256.0   443.264008   274.048001
15  64.0   512.0   862.240016   524.576008
16  64.0  1024.0  1706.112027  1023.136020
17  64.0  2048.0  3378.895998  2011.231899
18  128.0   64.0   232.927993   148.016006
19  128.0  128.0   443.455994   273.503989
20  128.0  256.0   861.599982   523.504019
21  128.0  512.0  1706.560016  1023.551941
22  128.0 1024.0  3375.808001  2009.759903
23  128.0 2048.0  6724.080086  3977.504015

python /home/jobuser/sglang/sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py
✅ All implementations match
per-token-group-quant-fp8-performance:
    batch_size  seq_len  group_size  Triton  SGL Kernel
0    1.0    64.0  128.0    10.848000     9.568000
1    1.0   128.0  128.0    14.848000    12.224000
2    1.0   256.0  128.0    20.160001    14.944000
3    1.0   512.0  128.0    31.776000    21.248000
4    1.0  1024.0  128.0    55.039998    33.920001
5    1.0  2048.0  128.0   101.792000    62.399998
6    2.0    64.0  128.0    14.688000    12.320000
7    2.0   128.0  128.0    20.191999    14.944000
8    2.0   256.0  128.0    31.711999    21.183999
9    2.0   512.0  128.0    55.103999    33.920001
10   2.0  1024.0  128.0   101.712003    62.336002
11   2.0  2048.0  128.0   193.120003   112.864003
12   4.0    64.0  128.0    20.191999    14.912000
13   4.0   128.0  128.0    31.711999    21.183999
14   4.0   256.0  128.0    55.071998    33.920001
15   4.0   512.0  128.0   101.696000    62.431999
16   4.0  1024.0  128.0   193.151996   112.816006
17   4.0  2048.0  128.0   375.200003   212.927997
18   8.0    64.0  128.0    31.679999    21.215999
19   8.0   128.0  128.0    55.103999    33.952001
20   8.0   256.0  128.0   101.696000    62.511995
21   8.0   512.0  128.0   193.248004   112.896003
22   8.0  1024.0  128.0   375.135988   213.088006
23   8.0  2048.0  128.0   738.879979   412.943989
24  16.0    64.0  128.0    55.071998    33.952001
25  16.0   128.0  128.0   101.792000    62.560000
26  16.0   256.0  128.0   193.087995   112.768002
27  16.0   512.0  128.0   375.167996   213.024005
28  16.0  1024.0  128.0   739.232004   413.056016
29  16.0  2048.0  128.0  1467.216015   813.471973
30  32.0    64.0  128.0   101.952001    62.560000
31  32.0   128.0  128.0   193.087995   112.832002
32  32.0   256.0  128.0   375.167996   212.896004
33  32.0   512.0  128.0   739.199996   412.959993
34  32.0  1024.0  128.0  1470.479965   813.567996
35  32.0  2048.0  128.0  2922.784090  1612.895966
36  64.0    64.0  128.0   193.151996   112.896003
37  64.0   128.0  128.0   375.200003   213.152006
38  64.0   256.0  128.0   739.296019   412.831992
39  64.0   512.0  128.0  1471.168041   813.279986
40  64.0  1024.0  128.0  2922.271967  1612.576008
41  64.0  2048.0  128.0  5828.479767  3213.583946

python /home/jobuser/sglang/sgl-kernel/benchmark/bench_per_token_quant_fp8.py
INFO 03-07 06:46:42 __init__.py:190] Automatically detected platform cuda.
✅ All implementations match
per-token-dynamic-quant-fp8-performance:
    batch_size  seq_len  VLLM  SGL Kernel
0   16.0    64.0    26.144000    30.239999
1   16.0   128.0    44.415999    42.080000
2   16.0   256.0    83.839998    85.055999
3   16.0   512.0   154.944003   157.856002
4   16.0  1024.0   297.152013   290.592015
5   16.0  2048.0   581.471980   551.519990
6   16.0  4096.0  1153.807998  1075.600028
7   32.0    64.0    44.192001    42.080000
8   32.0   128.0    84.608003    83.967999
9   32.0   256.0   155.808002   156.992003
10  32.0   512.0   297.280014   291.424006
11  32.0  1024.0   581.152022   551.072001
12  32.0  2048.0  1152.511954  1074.719906
13  32.0  4096.0  2289.920092  2109.503984
14  64.0    64.0    84.608003    84.096000
15  64.0   128.0   155.727997   158.720002
16  64.0   256.0   297.087997   290.304005
17  64.0   512.0   581.344008   551.168025
18  64.0  1024.0  1153.632045  1074.320078
19  64.0  2048.0  2290.719986  2109.951973
20  64.0  4096.0  4562.816143  4179.647923
21  128.0   64.0   155.616000   157.184005
22  128.0  128.0   297.055990   291.359991
23  128.0  256.0   580.864012   551.072001
24  128.0  512.0  1152.511954  1074.847937
25  128.0 1024.0  2290.240049  2109.247923
26  128.0 2048.0  4562.655926  4180.416107
27  128.0 4096.0  9109.968185  8315.551758

Checklist