
[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32#8212

Merged

zhyncs merged 6 commits into sgl-project:main from chunyuan-w:chunyuan/correction_bias_dtype on Aug 5, 2025

Conversation

@chunyuan-w (Contributor)

Motivation

#7825 fixed the dtype of e_score_correction_bias to be float32 instead of bfloat16. This PR updates the CPU topk kernel to support the new dtype.
The unit test has been updated accordingly.

@gemini-code-assist (Bot) left a comment


Summary of Changes

Hello @chunyuan-w, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical data type mismatch within the biased_grouped_topk CPU kernel. It ensures that the correction_bias parameter is consistently treated as float32 throughout the kernel implementation and its corresponding unit tests. This change aligns the kernel with a previous fix in a related PR, improving the robustness and correctness of the topk operation.

Highlights

  • Data Type Correction: The correction_bias parameter in the biased_grouped_topk CPU kernel has been explicitly fixed to use float32 instead of a generic scalar_t, resolving a data type inconsistency and aligning with a previous related fix.
  • Kernel Signature and Implementation Update: The apply_bias helper function and the main biased_grouped_topk_kernel_impl function signatures were updated to directly accept const float* bias, simplifying the internal vectorization logic by removing the need for convert_to_float operations.
  • Type Check Enforcement: The biased_grouped_topk_cpu function now explicitly checks that the correction_bias input tensor is of type at::kFloat, ensuring strict type safety at the API boundary (see the sketch after this list).
  • Unit Test Alignment: The unit tests for biased_grouped_topk were updated to generate the correction_bias tensor as torch.float32, ensuring consistency with the kernel's new expected data type and removing redundant float conversions.
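
For concreteness, a minimal sketch of what that boundary check can look like (the tensor name follows the PR; the exact message and placement are assumptions):

// Sketch: enforce float32 correction_bias at the API boundary (illustrative).
TORCH_CHECK(
    correction_bias.scalar_type() == at::kFloat,
    "biased_grouped_topk: expected correction_bias to be float32, got ",
    correction_bias.scalar_type());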

@gemini-code-assist (Bot) left a comment


Code Review

This pull request updates the biased_grouped_topk CPU kernel to handle correction_bias as a float32 tensor. The C++ code is simplified, and the corresponding Python tests are updated.

@mingfeima added the cpu, intel, and sgl-kernel labels on Jul 21, 2025
@chunyuan-w (Contributor, Author)

The CI failure is not related to the change in this PR.
Other PRs are hitting the same failures, for example:
https://github.com/sgl-project/sglang/actions/runs/16427578447/job/46425726939?pr=7135

Comment thread on sgl-kernel/csrc/cpu/topk.cpp (outdated):
correction_bias.scalar_type(),
"biased_grouped_topk_kernel2",
[&] {
using bias_scalar_t = scalar_t;
Collaborator


scalar_t will be the dtype of correction_bias in AT_DISPATCH_FLOATING_TYPES_AND2, yet LAUNCH_BIASED_GROUPED_TOPK_KERNEL uses gating_output.data_ptr<scalar_t>().

Isn't this going to be problematic?
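
To spell out the concern (a hypothetical sketch, not code from the PR): once correction_bias is float32, dispatching on its dtype binds scalar_t to float, so data_ptr<scalar_t>() on a bfloat16 gating_output fails its runtime dtype check.

// Hypothetical illustration of the mismatch; tensor names follow the PR.
AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::BFloat16, at::ScalarType::Half,
    correction_bias.scalar_type(),  // float32 after #7825
    "biased_grouped_topk_kernel2",
    [&] {
      // scalar_t == float here, while gating_output is bfloat16:
      // data_ptr<scalar_t>() throws a dtype-mismatch error at runtime.
      auto* gating_ptr = gating_output.data_ptr<scalar_t>();
    });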

@jianan-gu (Contributor) commented on Jul 28, 2025


This part has been updated following the suggestions below, which fixed the issue. Thanks.

@mingfeima (Collaborator) commented on Jul 24, 2025

To make the code change cleaner we can do something like this:

Rewrite AT_DISPATCH_REDUCED_FLOATING_TYPES in common.h so that the macro sets both scalar_t and param_t. The logic is similar to the mixed-dtype support in BatchNorm2d in PyTorch; let's use a similar naming convention.

// dispatch on (TYPE1, TYPE2):
// TYPE1: bfloat16, float16; TYPE2: float32, or the same reduced type as TYPE1
#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...)             \
  [&] {                                                                        \
    if (TYPE2 == at::kFloat) {                                                 \
      switch (TYPE1) {                                                         \
        case at::ScalarType::BFloat16: {                                       \
          using scalar_t = at::BFloat16;                                       \
          using param_t = float;                                               \
          return __VA_ARGS__();                                                \
        }                                                                      \
        case at::ScalarType::Half: {                                           \
          using scalar_t = at::Half;                                           \
          using param_t = float;                                               \
          return __VA_ARGS__();                                                \
        }                                                                      \
        default:                                                               \
          TORCH_CHECK(false, "Unsupported floating data type.\n");             \
      }                                                                        \
    } else {                                                                   \
      TORCH_CHECK(TYPE1 == TYPE2);                                             \
      switch (TYPE1) {                                                         \
        /* set `scalar_t` and `param_t` to the same reduced type */            \
      }                                                                        \
    }                                                                          \
  }()
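
A hypothetical call site (a sketch under the assumption that the tensors come in as in the PR; the actual kernel launch is elided):

// Hypothetical usage: dispatch on the activation and bias dtypes together.
CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(
    gating_output.scalar_type(), correction_bias.scalar_type(), [&] {
      const scalar_t* gating_ptr = gating_output.data_ptr<scalar_t>();  // bf16/fp16
      const param_t* bias_ptr = correction_bias.data_ptr<param_t>();    // float32, or same as scalar_t
      // ... launch the topk kernel with (gating_ptr, bias_ptr) ...
    });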

@mingfeima (Collaborator)

also you can put a helper function in vec.h:

// base version: allow only f32, f16, bf16
template <typename scalar_t,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t> || std::is_same_v<scalar_t, float>, int> = 1>
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {...}

// specialization for `float`
template <>
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2<float>(const float* __restrict__ data) {...}

then the code will be neater.
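
A minimal sketch of what the two bodies could look like, assuming ATen's at::vec helpers (convert_to_float as in functional_bfloat16.h); this is illustrative, not the PR's final code:

#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>

using at::vec::Vectorized;

// bf16/fp16: load one reduced-precision vector, widen it into two float vectors
template <typename scalar_t,
          typename std::enable_if_t<at::vec::is_reduced_floating_point_v<scalar_t>, int> = 0>
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
  auto x = Vectorized<scalar_t>::loadu(data);
  return at::vec::convert_to_float<scalar_t>(x);
}

// float: load two adjacent float vectors directly, no conversion needed
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
  using fVec = Vectorized<float>;
  return std::make_tuple(fVec::loadu(data), fVec::loadu(data + fVec::size()));
}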

@jianan-gu (Contributor)

(quoting @mingfeima's macro suggestion above)

Got it, will refine the code accordingly. Thanks!

@mingfeima (Collaborator) left a comment


generally LGTM! just a few minor changes needed!

Comment thread on sgl-kernel/csrc/cpu/topk.cpp (outdated):
fVec x0 = fVec::loadu(scores + d) + bias0;
fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
using bVec = at::vec::Vectorized<param_t>;
auto vec_size = fVec::size() * 2 == bVec::size() ? bVec::size() : fVec::size() * 2;
Collaborator


isn't vec_size always 32 on AVX512? just use:

template <typename scalar_t, typename param_t, int SIZE>
inline void apply_bias(...) {
  using bVec = at::vec::Vectorized<scalar_t>;
}

will do all the work.
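
For reference, a sketch of an apply_bias built this way, combining the loads quoted above with the load_float_vec2 helper (parameter names and the exact signature are assumptions; assumes SIZE is a multiple of 2 * Vectorized<float>::size()):

// Illustrative sketch: add a float32-widened bias to SIZE scores in place.
// scalar_t is kept only to mirror the suggested template signature.
template <typename scalar_t, typename param_t, int SIZE>
inline void apply_bias(float* __restrict__ scores, const param_t* __restrict__ bias) {
  using fVec = at::vec::Vectorized<float>;
  for (int d = 0; d < SIZE; d += 2 * fVec::size()) {
    auto [bias0, bias1] = load_float_vec2(bias + d);  // widen bias to float
    fVec x0 = fVec::loadu(scores + d) + bias0;
    fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
    x0.store(scores + d);
    x1.store(scores + d + fVec::size());
  }
}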

Contributor


Yes, got the point, and refined.

Comment thread on sgl-kernel/csrc/cpu/vec.h (outdated), on lines +20 to +22:
template <
typename scalar_t,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t> || std::is_same_v<scalar_t, float>, int> = 1>
Collaborator


Suggested change:

-template <
-typename scalar_t,
-typename std::enable_if_t<is_reduced_floating_point_v<scalar_t> || std::is_same_v<scalar_t, float>, int> = 1>
+template <typename scalar_t,
+          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>

Comment thread on sgl-kernel/csrc/cpu/vec.h (outdated), on lines +33 to +34:
template <>
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2<float>(const float* __restrict__ data) {
Collaborator


Suggested change:

-template <>
-inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2<float>(const float* __restrict__ data) {
+inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {

Collaborator


let's see if using an overload compiles; we can save 2 or 3 lines here.

Contributor


yes, changed to use an overload.

Comment thread on sgl-kernel/csrc/cpu/vec.h (outdated):
return std::make_tuple(x0, x1);
}

// specification for `float`
Collaborator


Suggested change:

-// specification for `float`

Contributor


removed

Comment thread on sgl-kernel/csrc/cpu/common.h (outdated), on lines +50 to +54:
// dispatch mixed precision of TYPE2 with TYPE1:
// TYPE2: float32
// with TYPE1: bfloat16, float16
// TYPE2: bfloat16, float16
// with TYPE1: same as TYPE2
Collaborator


Suggested change:

-// dispatch mixed precision of TYPE2 with TYPE1:
-// TYPE2: float32
-// with TYPE1: bfloat16, float16
-// TYPE2: bfloat16, float16
-// with TYPE1: same as TYPE2
+// dispatch with mixed dtypes (TYPE1, TYPE2):
+// TYPE1: the primary dtype (input, output, weight);
+// TYPE2: the secondary dtype (bias, etc.).

Contributor


Thanks, refined.

Comment thread on sgl-kernel/csrc/cpu/common.h:

} \
} \
}()

Collaborator


good job! exactly what i meant :)

@jianan-gu (Contributor)

Checked: the CI failures are not related to this PR.

@jianan-gu (Contributor)

Hi @zhyncs, could you kindly help review this PR? Thanks!

@Alcanderian self-assigned this on Aug 1, 2025
@Alcanderian added the ready-to-merge label (The PR is ready to merge after the CI is green.) on Aug 1, 2025
@zhyncs merged commit 08f8f49 into sgl-project:main on Aug 5, 2025
89 of 98 checks passed
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
…loat32 (sgl-project#8212)

Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
…loat32 (sgl-project#8212)

Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>

Labels

cpu (cpu backend performance optimization) · intel · ready-to-merge (The PR is ready to merge after the CI is green.) · sgl-kernel


6 participants