Commit 2cdde5d

celve and BBuf authored
[Kernel Slimming] Migrate AWQ marlin repack kernel to JIT (sgl-project#18949)
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
1 parent e0e0cad commit 2cdde5d

11 files changed: 1336 additions & 1 deletion

sglang/jit_kernel/awq_dequantize.py
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args
+
+if TYPE_CHECKING:
+    from tvm_ffi.module import Module
+
+
+@cache_once
+def _jit_awq_dequantize_module(dtype: torch.dtype) -> Module:
+    args = make_cpp_args(dtype)
+    return load_jit(
+        "awq_dequantize",
+        *args,
+        cuda_files=["gemm/awq_dequantize.cuh"],
+        cuda_wrappers=[("awq_dequantize", f"awq_dequantize<{args}>")],
+    )
+
+
+def awq_dequantize(
+    qweight: torch.Tensor,
+    scales: torch.Tensor,
+    qzeros: torch.Tensor,
+) -> torch.Tensor:
+    qweight_rows = qweight.shape[0]
+    qweight_cols = qweight.shape[1]
+    output = torch.empty(
+        (qweight_rows, qweight_cols * 8),
+        dtype=scales.dtype,
+        device=scales.device,
+    )
+    module = _jit_awq_dequantize_module(scales.dtype)
+    module.awq_dequantize(output, qweight, scales, qzeros)
+    return output
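
For context, a minimal usage sketch of the new wrapper (not part of the diff; it assumes a CUDA device and an sglang build that ships sglang.jit_kernel). The shapes follow the AWQ int4 convention the kernel expects, as also used by the benchmark below: each int32 in qweight packs eight 4-bit values, and scales/qzeros carry one row per quantization group. All sizes are illustrative.

import torch

from sglang.jit_kernel.awq_dequantize import awq_dequantize

K, packed_cols, group_size = 4096, 64, 128  # illustrative AWQ sizes
qweight = torch.randint(
    0, torch.iinfo(torch.int32).max, (K, packed_cols),
    dtype=torch.int32, device="cuda",
)
scales = torch.rand(
    K // group_size, packed_cols * 8, dtype=torch.float16, device="cuda"
)
qzeros = torch.randint(
    0, torch.iinfo(torch.int32).max, (K // group_size, packed_cols),
    dtype=torch.int32, device="cuda",
)

out = awq_dequantize(qweight, scales, qzeros)
assert out.shape == (K, packed_cols * 8)  # 8 int4 values per packed int32
assert out.dtype == scales.dtype  # output adopts the dtype/device of scales

The output dtype and device come from scales, mirroring how the wrapper allocates its result buffer.
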
sglang/jit_kernel/awq_marlin_repack.py
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.jit_kernel.utils import cache_once, load_jit
+
+if TYPE_CHECKING:
+    from tvm_ffi.module import Module
+
+
+@cache_once
+def _jit_awq_marlin_repack_module() -> Module:
+    return load_jit(
+        "awq_marlin_repack",
+        cuda_files=["gemm/marlin/awq_marlin_repack.cuh"],
+        cuda_wrappers=[("awq_marlin_repack", "awq_marlin_repack")],
+    )
+
+
+def awq_marlin_repack(
+    b_q_weight: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    tile_size = 16
+    pack_factor = 32 // num_bits
+    out = torch.empty(
+        (size_k // tile_size, size_n * tile_size // pack_factor),
+        dtype=b_q_weight.dtype,
+        device=b_q_weight.device,
+    )
+    module = _jit_awq_marlin_repack_module()
+    module.awq_marlin_repack(out, b_q_weight, size_k, size_n, num_bits)
+    return out
+
+
+def awq_marlin_moe_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    num_experts = b_q_weight.shape[0]
+    assert size_k % 16 == 0
+    output = torch.empty(
+        (num_experts, size_k // 16, size_n * (num_bits // 2)),
+        device=b_q_weight.device,
+        dtype=b_q_weight.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
+    return output
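
Likewise, a shape-level sketch of the repack entry points (not part of the diff; sizes illustrative, CUDA device assumed). For num_bits = 4 the pack factor is 32 // 4 = 8, so a repacked expert has size_n * 16 // 8 = size_n * 2 columns, matching the size_n * (num_bits // 2) buffer that awq_marlin_moe_repack preallocates per expert.

import torch

from sglang.jit_kernel.awq_marlin_repack import awq_marlin_repack

size_k, size_n, num_bits = 1024, 4096, 4
pack_factor = 32 // num_bits  # 8 int4 values per int32

# AWQ packs along N: qweight is (size_k, size_n // pack_factor) int32.
b_q_weight = torch.randint(
    0,
    torch.iinfo(torch.int32).max,
    (size_k, size_n // pack_factor),
    dtype=torch.int32,
    device="cuda",
)

marlin_q = awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
# Marlin 16x16 tile layout: (size_k // 16, size_n * 16 // pack_factor).
assert marlin_q.shape == (size_k // 16, size_n * (num_bits // 2))

Note that the MoE variant simply loops awq_marlin_repack over the expert dimension; perm is accepted for signature parity with the AOT kernel but is not used in this path.
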
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+import itertools
+import os
+
+import torch
+import triton
+import triton.testing
+
+from sglang.jit_kernel.awq_dequantize import awq_dequantize as jit_awq_dequantize
+
+try:
+    from sgl_kernel import awq_dequantize as aot_awq_dequantize
+
+    AOT_AVAILABLE = True
+except ImportError:
+    AOT_AVAILABLE = False
+
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
+# CI environment uses simplified parameters
+if IS_CI:
+    qweight_row_range = [128]
+    qweight_cols_range = [16]
+else:
+    qweight_row_range = [128, 256, 512, 1024, 3584]
+    qweight_cols_range = [16, 32, 64, 128, 448]
+
+configs = list(itertools.product(qweight_row_range, qweight_cols_range))
+
+
+def check_correctness():
+    if not AOT_AVAILABLE:
+        print("sgl_kernel AOT not available, skipping correctness check")
+        return
+
+    qweight_row, qweight_col = 128, 16
+    device = torch.device("cuda")
+    qweight = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (qweight_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+    group_size = qweight_row
+    scales_row = qweight_row // group_size
+    scales_col = qweight_col * 8
+    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
+    qzeros = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (scales_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    jit_out = jit_awq_dequantize(qweight, scales, qzeros)
+    aot_out = aot_awq_dequantize(qweight, scales, qzeros)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(jit_out, aot_out, rtol=0, atol=0)
+    print("Correctness check passed (JIT vs AOT)")
+
+
+if AOT_AVAILABLE:
+    line_vals = ["jit", "aot"]
+    line_names = ["JIT Kernel", "AOT Kernel"]
+    styles = [("blue", "-"), ("green", "-")]
+else:
+    line_vals = ["jit"]
+    line_names = ["JIT Kernel"]
+    styles = [("blue", "-")]
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["qweight_row", "qweight_col"],
+        x_vals=configs,
+        line_arg="provider",
+        line_vals=line_vals,
+        line_names=line_names,
+        styles=styles,
+        ylabel="us",
+        plot_name="awq-dequantize-jit-vs-aot",
+        args={},
+    )
+)
+def benchmark(qweight_row, qweight_col, provider):
+    device = torch.device("cuda")
+    qweight = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (qweight_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+    group_size = qweight_row
+    scales_row = qweight_row // group_size
+    scales_col = qweight_col * 8
+    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
+    qzeros = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (scales_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "jit":
+        fn = lambda: jit_awq_dequantize(qweight, scales, qzeros)
+    elif provider == "aot":
+        fn = lambda: aot_awq_dequantize(qweight, scales, qzeros)
+    else:
+        raise ValueError(f"Unknown provider: {provider}")
+
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    check_correctness()
+    benchmark.run(print_data=True)
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+import os
+
+import numpy as np
+import torch
+import triton
+import triton.testing
+from sgl_kernel.scalar_type import scalar_types
+
+from sglang.jit_kernel.awq_marlin_repack import (
+    awq_marlin_moe_repack as jit_awq_marlin_moe_repack,
+)
+from sglang.srt.layers.quantization.utils import pack_cols, quantize_weights
+
+try:
+    from sgl_kernel import awq_marlin_moe_repack as aot_awq_marlin_moe_repack
+
+    AOT_AVAILABLE = True
+except ImportError:
+    AOT_AVAILABLE = False
+
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
+# Fixed parameters
+NUM_BITS = 4
+GROUP_SIZE = 128
+SIZE_N = 4096
+
+
+def awq_pack(q_w, num_bits, size_k, size_n):
+    if num_bits == 4:
+        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = np.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
+    q_w = q_w.reshape((-1, size_n)).contiguous()
+    return pack_cols(q_w, num_bits, size_k, size_n)
+
+
+def make_moe_weights(num_experts, size_k, size_n, num_bits, group_size):
+    pack_factor = 32 // num_bits
+    b_q_weight = torch.empty(
+        (num_experts, size_k, size_n // pack_factor),
+        dtype=torch.int32,
+        device="cuda",
+    )
+    for e in range(num_experts):
+        b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
+        w_ref, q_w, s, zp = quantize_weights(
+            b_weight, scalar_types.uint4, min(group_size, size_k), zero_points=True
+        )
+        b_q_weight[e] = awq_pack(q_w, num_bits, size_k, size_n)
+    perm = torch.empty((num_experts, 0), dtype=torch.int32, device="cuda")
+    return b_q_weight, perm
+
+
+def check_correctness():
+    if not AOT_AVAILABLE:
+        print("sgl_kernel AOT not available, skipping correctness check")
+        return
+
+    num_experts = 4
+    size_k = 1024
+    b_q_weight, perm = make_moe_weights(
+        num_experts, size_k, SIZE_N, NUM_BITS, GROUP_SIZE
+    )
+
+    out_jit = jit_awq_marlin_moe_repack(b_q_weight, perm, size_k, SIZE_N, NUM_BITS)
+    out_aot = aot_awq_marlin_moe_repack(b_q_weight, perm, size_k, SIZE_N, NUM_BITS)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(out_jit, out_aot, rtol=0, atol=0)
+    print("Correctness check passed (JIT vs AOT)")
+
+
+if IS_CI:
+    expert_range = [2, 4]
+else:
+    expert_range = [2, 4, 8, 16]
+
+if AOT_AVAILABLE:
+    line_vals = ["jit", "aot"]
+    line_names = ["JIT Kernel", "AOT Kernel"]
+    styles = [("blue", "-"), ("green", "-")]
+else:
+    line_vals = ["jit"]
+    line_names = ["JIT Kernel"]
+    styles = [("blue", "-")]
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_experts"],
+        x_vals=expert_range,
+        line_arg="provider",
+        line_vals=line_vals,
+        line_names=line_names,
+        styles=styles,
+        ylabel="us",
+        plot_name="awq-marlin-moe-repack-performance",
+        args={"size_k": 4096, "size_n": SIZE_N, "num_bits": NUM_BITS},
+    )
+)
+def benchmark(num_experts, size_k, size_n, num_bits, provider):
+    group_size = min(GROUP_SIZE, size_k)
+    b_q_weight, perm = make_moe_weights(
+        num_experts, size_k, size_n, num_bits, group_size
+    )
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "jit":
+        fn = lambda: jit_awq_marlin_moe_repack(
+            b_q_weight, perm, size_k, size_n, num_bits
+        )
+    elif provider == "aot":
+        fn = lambda: aot_awq_marlin_moe_repack(
+            b_q_weight, perm, size_k, size_n, num_bits
+        )
+    else:
+        raise ValueError(f"Unknown provider: {provider}")
+
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    check_correctness()
+    benchmark.run(print_data=True)
