
[Quantization][Perf] add triton w8a8 int8 gemm kernel #10502

Closed
ZelinMa557 wants to merge 8 commits into sgl-project:main from ZelinMa557:w8a8

Conversation


@ZelinMa557 commented Sep 16, 2025

Motivation

At first I added this Triton kernel for NVIDIA sm89 GPUs because the performance of the CUTLASS kernel was poor; @HydraQYH has since helped improve the CUTLASS kernel's performance.

I think this kernel can still be kept for the HIP platform, since sglang does not support w8a8 int8 on HIP yet.

Modifications

Add a Triton w8a8 int8 GEMM kernel, and execute it only on the HIP platform.
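
For context, the w8a8 per-channel (weight) / per-token (activation) scheme dequantizes the int8 matmul result with an outer product of the two scale vectors. Below is a minimal PyTorch reference sketch of that math (illustrative only; the function name and signature are made up, and this is not the Triton kernel itself):

import torch


def w8a8_int8_reference(a_q, b_q, scale_a, scale_b, bias=None, out_dtype=torch.float16):
    """Reference math for per-token/per-channel w8a8 int8 GEMM.

    a_q:     (M, K) int8 activations, one scale per token (row)
    b_q:     (N, K) int8 weights, one scale per output channel (row)
    scale_a: (M,) float32 per-token scales
    scale_b: (N,) float32 per-channel scales
    """
    # The real kernel accumulates in int32; float32 matmul here is only for illustration.
    acc = torch.matmul(a_q.float(), b_q.float().t())
    out = acc * scale_a[:, None] * scale_b[None, :]
    if bias is not None:
        out = out + bias.float()
    return out.to(out_dtype)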

Accuracy Tests

I added a unit test in test/srt/quant/test_int8_kernel.py.

Benchmarking and Profiling

Benchmark with the following script:

import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import int8_scaled_mm
# from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from sglang.srt.layers.quantization.int8_kernel import w8a8_per_channel_per_token_matmul

def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
        x_log=False,
        line_arg="provider",
        # line_vals=["vllm", "sgl-kernel", "sglang-triton"],
        # line_names=["vllm int8 gemm", "sgl-kernel int8 gemm", "sglang triton int8 gemm"],
        line_vals=["sgl-kernel", "sglang-triton"],
        line_names=["sgl-kernel int8 gemm", "sglang triton int8 gemm"],
        styles=[("orange", "-"), ("blue", "-")],
        ylabel="GB/s",
        plot_name="int8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    origin_b = torch.randn((N, K), device="cuda")
    b_t = to_int8(origin_b.t() * 5)
    b = to_int8(origin_b * 5)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    bias = torch.randn((N,), device="cuda", dtype=torch.float16)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: int8_scaled_mm(a, b_t, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    # if provider == "vllm":
    #     ms, min_ms, max_ms = triton.testing.do_bench(
    #         lambda: w8a8_per_channel_per_token_matmul(a, b_t, scale_a, scale_b, torch.float16, bias),
    #         quantiles=quantiles,
    #     )
    if provider == "sglang-triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: w8a8_per_channel_per_token_matmul(a, b, scale_a, scale_b, bias, torch.float16),
            quantiles=quantiles,
        )
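    # Note (added): this "GB/s" figure is a throughput proxy rather than true memory
    # bandwidth: it scales the matmul op count (2*M*N*K - M*N) by the int8 element size,
    # adds 3*M*N elements at float32 size for scale/output traffic, and divides by runtime.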
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
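        # Each WEIGHT_SHAPES entry is ([K, N], tp_split_dim); the flagged dim is the one
        # sharded across tensor-parallel ranks, so it is divided by tp_size below.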
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
        )

    print("Benchmark finished!")

Benchmark result:

(pytorch) root@VM-50-185-ubuntu:~/maxinkai/sglang# python3 sgl-kernel/benchmark/bench_int8_gemm.py
meta-llama/Llama-3.1-8B-Instruct N=6144 K=4096: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1171.857152               984.360025
1         16.0          18749.714435             18313.674416
2         32.0          37499.428869             35794.908811
3         64.0          73254.697666             71488.271011
4        128.0         136954.431744            139900.710535
5        256.0         233600.149987            203222.712426
6        512.0         183938.809032            254541.583123
7       1024.0         203222.712426            298220.296609
8       2048.0         212654.987165            313038.701128
9       4096.0         218866.888741            320503.851694
10      8192.0         219843.978772            329676.088496
meta-llama/Llama-3.1-8B-Instruct N=4096 K=4096: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1093.733361               886.810772
1         16.0          17499.733775             16935.226036
2         32.0          34602.974207             33734.425126
3         64.0          66997.985131             63695.711810
4        128.0         127270.786885            123527.527245
5        256.0         182605.908991            209996.805305
6        512.0         168234.019020            247510.039638
7       1024.0         194111.501956            282459.901958
8       2048.0         208046.374564            290904.660318
9       4096.0         215058.239002            319235.037134
10      8192.0         219383.256860            319614.629843
meta-llama/Llama-3.1-8B-Instruct N=28672 K=4096: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1328.613108              1228.256699
1         16.0          21242.451716             20999.680749
2         32.0          42240.735463             41760.726947
3         64.0          83049.584531             83521.453893
4        128.0         160653.292718            161535.996977
5        256.0         295032.743741            276052.122870
6        512.0         211507.574242            312761.182444
7       1024.0         218178.505056            316576.316501
8       2048.0         220841.713384            318064.697233
9       4096.0         221312.508191            308859.386601
10      8192.0         219694.249512            309570.791732
meta-llama/Llama-3.1-8B-Instruct N=4096 K=14336: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1288.671094              1220.553156
1         16.0          20396.799691             20172.658963
2         32.0          40715.842841             39557.429897
3         64.0          79194.855206             77293.134355
4        128.0         148340.367929            148340.367929
5        256.0         201173.909857            238791.794666
6        512.0         193860.579117            255230.019827
7       1024.0         209340.055815            291595.305008
8       2048.0         215922.702206            295933.423521
9       4096.0         219189.489743            310192.916592
10      8192.0         217118.257755            306400.821909
Benchmark finished!
(pytorch) root@VM-50-185-ubuntu:~/maxinkai/sglang# python3 sgl-kernel/benchmark/bench_int8_gemm.py --models Qwen/Qwen2.5-72B-Instruct --tp-sizes 4 8
Qwen/Qwen2.5-72B-Instruct N=2560 K=8192: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1138.541626               853.906220
1         16.0          18216.666022             14593.602158
2         32.0          35598.981780             31112.823682
3         64.0          69031.582013             53534.695627
4        128.0         102870.590100            107069.391255
5        256.0         197977.359415            197977.359415
6        512.0         177844.062420            262320.006627
7       1024.0         193415.674062            297668.091858
8       2048.0         207778.222322            305592.343065
9       4096.0         216625.542165            311474.599048
10      8192.0         219170.753765            314391.016180
Qwen/Qwen2.5-72B-Instruct N=8192 K=2048: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1026.750003               909.902245
1         16.0          16428.000041             17523.200443
2         32.0          32856.000083             35046.400885
3         64.0          62333.571881             67831.742853
4        128.0         110219.637802            131424.000332
5        256.0         179198.633112            215670.161306
6        512.0         173425.490868            258804.183911
7       1024.0         188155.443204            280371.207083
8       2048.0         197328.700166            307607.249934
9       4096.0         199721.815507            319663.122641
10      8192.0         203290.303219            319639.401782
Qwen/Qwen2.5-72B-Instruct N=14784 K=8192: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1252.395843              1226.439444
1         16.0          19913.217954             21396.865079
2         32.0          39246.062206             40695.713996
3         64.0          74259.707229             82696.582321
4        128.0         131204.373086            157801.869418
5        256.0         206108.573065            279243.879426
6        512.0         211429.137799            300724.159229
7       1024.0         218130.314271            313562.325738
8       2048.0         222064.745620            312955.034507
9       4096.0         223962.737349            309557.691279
10      8192.0         217640.652741            308180.142442
Qwen/Qwen2.5-72B-Instruct N=8192 K=7392: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1220.206230              1066.306295
1         16.0          19324.082264             17225.787319
2         32.0          38257.778885             33517.875550
3         64.0          75000.397664             65302.067338
4        128.0         131739.823164            123171.376846
5        256.0         217986.768667            221118.608676
6        512.0         194856.327703            261208.269354
7       1024.0         208831.591413            293464.018253
8       2048.0         215276.455405            304142.127216
9       4096.0         217412.442438            303010.481538
10      8192.0         215416.933584            297698.849703
Qwen/Qwen2.5-72B-Instruct N=1280 K=8192: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0            788.221163               640.429689
1         16.0          12611.538604             10207.003614
2         32.0          25223.077207             16395.000414
3         64.0          46634.667468             33738.905592
4        128.0          52464.001325             55812.764343
5        256.0         107069.391255            114052.171382
6        512.0         154305.879647            201784.617659
7       1024.0         176349.584826            265956.129171
8       2048.0         192528.439447            289456.555405
9       4096.0         207265.185767            304139.123686
10      8192.0         213900.045015            304147.735010
Qwen/Qwen2.5-72B-Instruct N=8192 K=1024: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0            866.947407               716.173897
1         16.0          13724.432370             14641.777260
2         32.0          26479.321023             29182.229515
3         64.0          50200.381351             55484.634043
4        128.0          87850.663559            110969.268086
5        256.0         144628.749805            183340.517671
6        512.0         130186.808588            210021.203732
7       1024.0         149268.386904            255263.413067
8       2048.0         157753.757541            280610.606775
9       4096.0         161440.008331            297221.642941
10      8192.0         163959.449854            307377.278552
Qwen/Qwen2.5-72B-Instruct N=7392 K=8192: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1220.117635              1207.667450
1         16.0          19378.291910             20361.531737
2         32.0          38645.358406             39450.467355
3         64.0          73538.738943             78900.934709
4        128.0         127302.356546            148337.626762
5        256.0         201986.390406            233117.281417
6        512.0         199986.529381            288080.448420
7       1024.0         211504.084681            306088.307369
8       2048.0         218757.823085            311547.151778
9       4096.0         221554.122014            308572.481710
10      8192.0         217628.436525            301378.522024
Qwen/Qwen2.5-72B-Instruct N=8192 K=3696: 
int8 scaled matmul:
    batch_size  sgl-kernel int8 gemm  sglang triton int8 gemm
0          1.0           1117.433967               999.297651
1         16.0          17878.943480             16660.817537
2         32.0          35757.886959             31586.134131
3         64.0          68642.193202             64242.981109
4        128.0         118332.440395            113144.356767
5        256.0         178303.200483            166608.172706
6        512.0         131267.048661            178368.753630
7       1024.0         137737.465696            202151.243732
8       2048.0         140545.484201            207052.579753
9       4096.0         140879.623964            208026.368151
10      8192.0         140708.531250            208045.881273
Benchmark finished!

Checklist

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @ZelinMa557, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of int8 quantization within the system by integrating a new, highly optimized Triton kernel for weight-8-bit, activation-8-bit (w8a8) General Matrix Multiply (GEMM) operations. This specialized kernel is designed to leverage the capabilities of modern NVIDIA GPUs (sm89 architecture) and provides substantial speed improvements, particularly for larger batch sizes, making quantized model inference more efficient.

Highlights

  • New Triton Kernel: Introduced a Triton-based w8a8 int8 GEMM kernel for improved performance in quantized operations.
  • Performance Boost: Achieves up to 50% speedup over the existing sgl-kernel for matrix sizes where M (batch size) is greater than or equal to 1024.
  • Hardware Specific Optimization: The new kernel is specifically enabled for NVIDIA GPUs with compute capability sm89 (e.g., RTX 4090) to leverage their architectural advantages.
  • Accuracy Verified: Includes new unit tests to ensure the accuracy of the implemented int8 GEMM kernel against a native PyTorch reference.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new Triton kernel for w8a8 int8 GEMM, aimed at improving performance on sm89 architectures. The benchmarks provided demonstrate a significant speedup for larger matrix sizes, which is a great addition. The implementation is well-structured, and the inclusion of a new unit test ensures the correctness of the new kernel. My review includes a few suggestions to enhance code clarity, maintainability, and the robustness of the tests.

Comment on lines +497 to +500
if C.dtype.element_ty == tl.bfloat16:
    c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
    c = accumulator.to(tl.float16)

medium

This conditional block for converting the accumulator to the output dtype can be simplified. Using C.dtype.element_ty directly makes the code more concise and adaptable to future support for other dtypes.

    c = accumulator.to(C.dtype.element_ty)

triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
bias_ptr = bias if bias is not None else B

medium

Using the large weight tensor B as a dummy pointer when bias is None is not ideal for clarity and could have minor performance implications. Since As is a required argument and is smaller than B, it would be a better choice for a dummy tensor.

    bias_ptr = bias if bias is not None else As

Comment on lines +210 to +214
self.assertTrue(
    torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
    / torch.mean(torch.abs(ref_out.to(torch.float32)))
    < 0.05
)

medium

The current assertion checks the mean relative error, which might not catch large discrepancies in a small number of elements. Using torch.allclose provides a more robust, element-wise comparison and is generally preferred for tensor comparisons in tests. This will ensure that all elements in the output tensor are within an acceptable tolerance.

        self.assertTrue(torch.allclose(out, ref_out, rtol=0.05, atol=1e-1))

Signed-off-by: ZelinMa557 <3388706467@qq.com>
@ZelinMa557
Author

fix lint

Collaborator

@HydraQYH left a comment


Great job! Could you please reply to my comment? If it is convenient for you, could you provide a more detailed performance analysis report (e.g., ncu) to prove the source of the performance improvement?

Comment thread python/sglang/srt/layers/quantization/int8_kernel.py
Comment thread test/srt/quant/test_int8_kernel.py Outdated
@ZelinMa557
Author

Hi, thank you for your patient code review! @HydraQYH I will attach some ncu profile results tomorrow.

Signed-off-by: ZelinMa557 <3388706467@qq.com>
@ZelinMa557
Author

This is the ncu perf result; ncu indicates that the L2 access pattern of the cutlass kernel might be sub-optimal:
(Input shape: M = 2048, N = 4096, K = 4096)
[ncu screenshot]
I'm not sure why the cutlass kernel has this issue, but you can notice that the input layouts for the cutlass and triton kernels are different:
[screenshot: input layouts of the cutlass and triton kernels]

I attached the ncu perf file here:
int8_perf.ncu-rep.zip

Also, I have refactored the test code.

Maybe you can take a look? @HydraQYH

@HydraQYH
Collaborator

@ZelinMa557 I analyzed the ncu report and found that the CUTLASS-based kernel used inefficient IMMA instructions:
[ncu screenshot]
This is because CUTLASS uses an incorrect configuration; see this PR for details. Could you please help test the triton kernel performance based on the latest fixes?

@ZelinMa557
Author

Hi, I saw the performance report in the new PR; the performance boost of the fixed cutlass kernel is higher than that of the triton one, so I think there is no need to benchmark the triton kernel again. @HydraQYH

However, maybe we can keep this triton kernel for other platforms, such as HIP?

@HydraQYH
Collaborator

Hi, I saw the performance report in the new PR; the performance boost of the fixed cutlass kernel is higher than that of the triton one, so I think there is no need to benchmark the triton kernel again. @HydraQYH

However, maybe we can keep this triton kernel for other platforms, such as HIP?

@ZelinMa557 This is a good idea. But you need to adapt your code so that it is used only on the HIP platform.
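
For readers following the thread, here is a rough sketch of what HIP-only gating at the dispatch site could look like (illustrative only, not the PR's actual diff; it assumes sglang's is_hip() helper and reuses the two kernel entry points from the benchmark script above):

import torch

from sgl_kernel import int8_scaled_mm
from sglang.srt.layers.quantization.int8_kernel import w8a8_per_channel_per_token_matmul
from sglang.srt.utils import is_hip  # assumption: this helper is available in sglang

_is_hip = is_hip()


def int8_scaled_gemm(a, weight, weight_t, scale_a, scale_b, bias, out_dtype=torch.float16):
    # weight is (N, K) int8 and weight_t its (K, N) transpose; the two kernels expect
    # different layouts and argument orders, as in the benchmark script above.
    if _is_hip:
        # No CUTLASS int8 path on HIP, so fall back to the Triton kernel from this PR.
        return w8a8_per_channel_per_token_matmul(a, weight, scale_a, scale_b, bias, out_dtype)
    # On CUDA, keep the CUTLASS-based sgl-kernel implementation.
    return int8_scaled_mm(a, weight_t, scale_a, scale_b, out_dtype, bias)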

@HydraQYH left a comment

Adapting to the HIP platform requires addressing the following comments.

Comment thread python/sglang/srt/layers/quantization/w8a8_int8.py Outdated
Signed-off-by: ZelinMa557 <3388706467@qq.com>
Signed-off-by: ZelinMa557 <3388706467@qq.com>
Signed-off-by: ZelinMa557 <3388706467@qq.com>
@ZelinMa557
Author

Adapting to the HIP platform requires addressing the following comments.

Thanks, I have updated the code. Do we need to re-tune the configs of this triton kernel on HIP devices?
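
On the re-tuning question: if the kernel exposes its tile configs through triton.autotune (as Triton GEMM kernels typically do), porting to HIP mostly means re-benchmarking that config list on AMD hardware. A generic sketch of how such configs are declared (block sizes and warp counts below are illustrative, not this PR's values):

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
            num_stages=3,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
            num_stages=4,
            num_warps=8,
        ),
    ],
    key=["M", "N", "K"],  # re-select the best config whenever the problem shape changes
)
@triton.jit
def _int8_gemm_kernel_stub(
    A, B, C, M, N, K,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
):
    # Kernel body omitted; only the autotune plumbing that would need HIP-specific
    # re-tuning is shown here.
    pass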

@HydraQYH
Collaborator

@zhyncs @BBuf Could you please invite someone who knows the AMD HIP platform to review this PR?

@zhyncs
Collaborator

zhyncs commented Sep 25, 2025

@HaiShaw @saienduri may you help take a look? thanks

Signed-off-by: ZelinMa557 <3388706467@qq.com>
Signed-off-by: ZelinMa557 <3388706467@qq.com>
@ZelinMa557
Author

Hi @HaiShaw, can you take a look at this PR?
