
[Roadmap] JIT kernel development #17035

Background

Over the past few months, the SGLang community has jointly landed a JIT-kernel framework that enables high-performance custom CUDA/C++ kernels to be compiled and deployed at runtime. Powered by tvm-ffi, we achieve extremely low compilation overhead while retaining full C++ and CUDA expressiveness, making it suitable for both production deployment and rapid kernel prototyping.

So far, several important kernels have already been upstreamed.

We also provide a basic development guide for adding custom kernels in #14570. This establishes a solid foundation, but we are still in the early phase of building a robust, scalable, and maintainable JIT-kernel ecosystem.

Next Step

Going forward, we plan to expand JIT coverage to more kernels while significantly improving developer experience, code quality, and platform support.

Ongoing PRs

Future Work

  • AMD ROCm support. Since tvm-ffi supports AMD backends, adding ROCm support is feasible. With [Refactor] Clean up JIT kernel utilites #16884, most architecture-specific instructions have been moved into the include directory. Supporting AMD backends would require (1) implementing ROCm support in the include directory and (2) performance tuning (without regressing CUDA). A rough sketch of what such an architecture-gated primitive could look like follows after this list.
  • Enhance tests on JIT kernels. Recently, Fix wrong kernel selection for int32/int64 indices #16912 fixed a wrong kernel selection introduced in [Feature] Support JIT set kv cache #16273. This should have been caught by the unit tests. We definitely need stronger tests that cover all the edge cases: every feature we ship must be covered by tests.
  • Better documentation. We've introduced a number of utilities in [Refactor] Clean up JIT kernel utilites #16884. However, most of them are not well documented. We need to document them properly, both in C++ docstrings and in https://github.com/sgl-project/sglang/blob/main/docs/developer_guide/JIT_kernels.md.
  • Cleaner and simpler code structure. For example ([Refactor] Use is_in_ci() utility in JIT kernel benchmarks #17118), various JIT kernel benchmarks share redundant code such as
    IS_CI = (
        os.getenv("CI", "false").lower() == "true"
        or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
    )
    In this example, we should unify the code behind one common utility function:
    def is_in_ci():
        return (
            os.getenv("CI", "false").lower() == "true"
            or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
        )
    The same applies to C++ functions. Any new utilities/abstractions that clarify the code are welcome.
  • Support/accelerate more template functions. A recent example is the norm template:
    #pragma once

    #include <sgl_kernel/math.cuh>
    #include <sgl_kernel/type.cuh>
    #include <sgl_kernel/utils.cuh>
    #include <sgl_kernel/vec.cuh>
    #include <sgl_kernel/warp.cuh>

    #include <cstdint>
    #include <type_traits>

    namespace host::norm {

    /**
     * \brief Check if the given configuration is supported.
     * \tparam T Element type (only fp16_t/bf16_t is supported)
     * \tparam kDim Dimension size (usually hidden size)
     */
    template <typename T, int64_t kDim>
    inline constexpr bool is_config_supported() {
      if (!std::is_same_v<T, fp16_t> && !std::is_same_v<T, bf16_t>) return false;
      if (kDim <= 256) {
        return (kDim == 64 || kDim == 128 || kDim == 256);
      } else {
        return (kDim % 256 == 0 && kDim <= 8192);
      }
    }

    /**
     * \brief Determine whether to use cta norm based on dimension size.
     * TL;DR: use warp norm for dim <= 256, cta norm otherwise.
     * \tparam T Element type (fp16_t or bf16_t)
     * \tparam kDim Dimension size (usually hidden size)
     * \note This function assumes that the configuration is supported.
     * \see `is_config_supported`
     */
    template <typename T, int64_t kDim>
    inline constexpr bool should_use_cta() {
      static_assert(is_config_supported<T, kDim>(), "Unsupported norm configuration");
      return kDim > 256;
    }

    /**
     * \brief Get the number of threads per CTA for cta norm.
     * \tparam T Element type (fp16_t or bf16_t)
     * \tparam kDim Dimension size (usually hidden size)
     * \return Number of threads per CTA
     */
    template <typename T, int64_t kDim>
    inline constexpr uint32_t get_cta_threads() {
      static_assert(should_use_cta<T, kDim>());
      return (kDim / 256) * device::kWarpThreads;
    }

    } // namespace host::norm

    namespace device::norm {

    namespace details {

    template <int64_t kDim, bool kUseCTA, typename PackedFloat, std::size_t N>
    SGL_DEVICE AlignedVector<PackedFloat, N> apply_norm_impl(
        const AlignedVector<PackedFloat, N> input,
        const AlignedVector<PackedFloat, N> weight,
        const float eps,
        [[maybe_unused]] float* smem_buffer,
        [[maybe_unused]] uint32_t num_warps) {
      float sum_of_squares = 0.0f;
      #pragma unroll
      for (auto i = 0u; i < N; ++i) {
        const auto fp32_input = cast<fp32x2_t>(input[i]);
        sum_of_squares += fp32_input.x * fp32_input.x;
        sum_of_squares += fp32_input.y * fp32_input.y;
      }
      sum_of_squares = warp::reduce_sum(sum_of_squares);

      float norm_factor;
      if constexpr (kUseCTA) {
        // need to synchronize across the cta
        const auto warp_id = threadIdx.x / kWarpThreads;
        smem_buffer[warp_id] = sum_of_squares;
        __syncthreads();
        // use the first warp to reduce
        if (warp_id == 0) {
          const auto tx = threadIdx.x;
          const auto local_sum = tx < num_warps ? smem_buffer[tx] : 0.0f;
          sum_of_squares = warp::reduce_sum(local_sum);
          smem_buffer[32] = math::rsqrt(sum_of_squares / kDim + eps);
        }
        __syncthreads();
        norm_factor = smem_buffer[32];
      } else {
        norm_factor = math::rsqrt(sum_of_squares / kDim + eps);
      }

      AlignedVector<PackedFloat, N> output;
      #pragma unroll
      for (auto i = 0u; i < N; ++i) {
        const auto fp32_input = cast<fp32x2_t>(input[i]);
        const auto fp32_weight = cast<fp32x2_t>(weight[i]);
        output[i] = cast<PackedFloat, fp32x2_t>({
            fp32_input.x * norm_factor * fp32_weight.x,
            fp32_input.y * norm_factor * fp32_weight.y,
        });
      }
      return output;
    }

    } // namespace details

    /**
     * \brief Apply norm using warp-level implementation.
     * \tparam kDim Dimension size
     * \tparam T Element type (fp16_t or bf16_t)
     * \param input Input vector
     * \param weight Weight vector
     * \param eps Epsilon value for numerical stability
     * \return Normalized output vector
     */
    template <int64_t kDim, typename T>
    SGL_DEVICE T apply_norm_warp(const T& input, const T& weight, float eps) {
      static_assert(kDim <= 256, "Warp norm only supports dim <= 256");
      return details::apply_norm_impl<kDim, false>(input, weight, eps, nullptr, 0);
    }

    /**
     * \brief Apply norm using CTA-level implementation.
     * \tparam kDim Dimension size
     * \tparam T Element type (fp16_t or bf16_t)
     * \param input Input vector
     * \param weight Weight vector
     * \param eps Epsilon value for numerical stability
     * \param smem Shared memory buffer
     * \param num_warps Number of warps in the CTA
     * \return Normalized output vector
     */
    template <int64_t kDim, typename T>
    SGL_DEVICE T apply_norm_cta(
        const T& input, const T& weight, float eps, float* smem, uint32_t num_warps = blockDim.x / kWarpThreads) {
      static_assert(kDim > 256, "CTA norm only supports dim > 256");
      return details::apply_norm_impl<kDim, true>(input, weight, eps, smem, num_warps);
    }

    /**
     * \brief Storage type for norm operation.
     * For warp norm, the storage size depends on kDim.
     * For cta norm, the storage size is fixed to 16B.
     * We will also pack the input 16-bit floats into 32-bit types
     * for faster CUDA core operations.
     *
     * \tparam T Element type (fp16_t or bf16_t)
     * \tparam kDim Dimension size
     */
    template <typename T, int64_t kDim>
    using StorageType = std::conditional_t<                     // storage type
        (kDim > 256),                                           // whether to use cta norm
        AlignedVector<packed_t<T>, 4>,                          // cta norm storage, fixed to 16B
        AlignedVector<packed_t<T>, kDim / (2 * kWarpThreads)>   // warp norm storage
        >;

    /**
     * \brief Minimum shared memory size (in floats) required for cta norm.
     */
    inline constexpr uint32_t kSmemBufferSize = 33;

    } // namespace device::norm
    This is not a highly tuned baseline and still has much room for improvement. A hypothetical usage sketch of this header follows after this list.
  • Unify C++ coding standard.
  • Support third-party libraries, such as cutlass.
  • Multi-platform support: currently, the JIT kernels are mostly tested on CUDA 12.9 with gcc 13.3. We may extend support to CUDA 12.1 with gcc 11.1, or even to other compilers/toolchains.
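
To make the ROCm item above more concrete, here is a rough sketch of how an architecture-gated primitive inside the include directory could look. The names SGL_DEVICE, device::kWarpThreads, and warp::reduce_sum mirror those used in the norm template above, but this is not the actual content of sgl_kernel/warp.cuh; the __HIP_PLATFORM_AMD__ guard, the 64-lane wavefront width, and the unmasked __shfl_xor call are assumptions that would still need to be validated and tuned on real AMD hardware.

    // arch_dispatch_sketch.cuh -- illustrative only, not the real sgl_kernel header.
    #pragma once
    #include <cstdint>

    #ifndef SGL_DEVICE
    #define SGL_DEVICE __device__ __forceinline__
    #endif

    namespace device {

    #if defined(__HIP_PLATFORM_AMD__)
    // Assumption: AMD wavefronts are 64 lanes wide on CDNA; this would also
    // affect host-side launch math such as get_cta_threads().
    inline constexpr uint32_t kWarpThreads = 64;
    #else
    inline constexpr uint32_t kWarpThreads = 32;
    #endif

    namespace warp {

    // Butterfly reduction across one warp/wavefront.
    SGL_DEVICE float reduce_sum(float value) {
      #pragma unroll
      for (uint32_t offset = kWarpThreads / 2; offset > 0; offset /= 2) {
    #if defined(__HIP_PLATFORM_AMD__)
        // HIP exposes the unmasked shuffle; masked __shfl_xor_sync availability
        // depends on the ROCm version, so we hedge on the portable form here.
        value += __shfl_xor(value, offset);
    #else
        value += __shfl_xor_sync(0xffffffffu, value, offset);
    #endif
      }
      return value;
    }

    } // namespace warp
    } // namespace device

Keeping such backend differences behind one shared header is what would let the JIT kernels themselves stay backend-agnostic.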
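
To illustrate how the norm template is meant to be consumed, here is a hypothetical caller built only from the functions shown in that header (apply_norm_cta, StorageType, kSmemBufferSize, and the host::norm helpers). The file name, the kernel and launcher names, the assumed header path sgl_kernel/norm.cuh, and the 16-byte row alignment are illustrative assumptions, not upstreamed code; only the CTA path (kDim > 256) is covered.

    // rmsnorm_usage_sketch.cu -- hypothetical caller of the norm header above.
    #include <sgl_kernel/norm.cuh>  // assumed location of the header shown above

    #include <cstdint>

    // One row per CTA. Only valid when host::norm::should_use_cta<T, kDim>() holds,
    // and assumes each row base pointer is at least 16-byte aligned.
    template <typename T, int64_t kDim>
    __global__ void rmsnorm_cta_kernel(
        const T* __restrict__ input, const T* __restrict__ weight, T* __restrict__ output, float eps) {
      using Storage = device::norm::StorageType<T, kDim>;  // 16B per thread on the CTA path
      __shared__ float smem[device::norm::kSmemBufferSize];

      // Each thread owns a contiguous 16B slice of the current row.
      constexpr uint32_t kElemsPerThread = kDim / host::norm::get_cta_threads<T, kDim>();
      const int64_t offset = static_cast<int64_t>(blockIdx.x) * kDim + threadIdx.x * kElemsPerThread;

      const auto in = *reinterpret_cast<const Storage*>(input + offset);
      const auto w = *reinterpret_cast<const Storage*>(weight + threadIdx.x * kElemsPerThread);
      const auto out = device::norm::apply_norm_cta<kDim>(in, w, eps, smem);

      *reinterpret_cast<Storage*>(output + offset) = out;
    }

    // Host-side launch sketch: one CTA per row/token.
    template <typename T, int64_t kDim>
    void launch_rmsnorm(
        const T* input, const T* weight, T* output, float eps, int64_t num_rows, cudaStream_t stream) {
      static_assert(host::norm::is_config_supported<T, kDim>(), "Unsupported norm configuration");
      static_assert(host::norm::should_use_cta<T, kDim>(), "This sketch only covers the CTA path");
      constexpr uint32_t threads = host::norm::get_cta_threads<T, kDim>();
      rmsnorm_cta_kernel<T, kDim>
          <<<static_cast<uint32_t>(num_rows), threads, 0, stream>>>(input, weight, output, eps);
    }

For kDim <= 256, one would instead load a per-lane StorageType slice and call apply_norm_warp with one warp per row.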

Known Limitations

  • QK norm does not seem to be compatible with head_dim 512 or 1024. We need to add support for them (the following check is wrong):
    if head_dim not in [64, 128, 256, 512, 1024]:
        logger.warning(f"Unsupported head_dim={head_dim} for JIT QK-Norm kernel")
        return False
  • After [Refactor] Clean up JIT kernel utilites #16884, the early JIT HiCache kernels look a little outdated. We need to bring them in line with our latest code standard. In addition, we currently only enable the JIT HiCache kernel for the normal layer-first layout, even though it is compatible with the page-first layout without any kernel changes.
  • Currently, the C++ radix tree is still implemented with torch JIT ([Feature] Radix Tree in C++ #7369). We may try to rewrite it with our JIT kernel framework.
  • Add missing acknowledgments to https://github.com/flashinfer-ai/flashinfer. We learned a lot from their design.

We welcome contributors who are interested in these tasks and willing to join us for rapid development. Keep in mind that code quality and performance are equally important.
