[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 #8212
Conversation
Summary of Changes
Hello @chunyuan-w, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a critical data type mismatch within the biased_grouped_topk CPU kernel. It ensures that the correction_bias parameter is consistently treated as float32 throughout the kernel implementation and its corresponding unit tests. This change aligns the kernel with a previous fix in a related PR, improving the robustness and correctness of the topk operation.
Highlights
- Data Type Correction: The `correction_bias` parameter in the `biased_grouped_topk` CPU kernel is now explicitly `float32` instead of a generic `scalar_t`, resolving a data type inconsistency and aligning with a previous related fix.
- Kernel Signature and Implementation Update: The `apply_bias` helper function and the main `biased_grouped_topk_kernel_impl` function signatures were updated to accept `const float* bias` directly, simplifying the internal vectorization logic by removing the need for `convert_to_float` operations.
- Type Check Enforcement: The `biased_grouped_topk_cpu` function now explicitly checks that the `correction_bias` input tensor is of type `at::kFloat`, ensuring strict type safety at the API boundary.
- Unit Test Alignment: The unit tests for `biased_grouped_topk` were updated to generate the `correction_bias` tensor as `torch.float32`, ensuring consistency with the kernel's new expected data type and removing redundant float conversions.
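For reference, a minimal sketch of the kind of dtype guard described in the highlights, assuming the check sits at the C++ API boundary (the function name and error message below are illustrative, not the exact code in this PR):

```cpp
#include <ATen/ATen.h>

// Illustrative guard only: reject a non-float32 correction_bias up front
// instead of silently dispatching the kernel on whatever dtype it carries.
void check_correction_bias_dtype(const at::Tensor& correction_bias) {
  TORCH_CHECK(
      correction_bias.scalar_type() == at::kFloat,
      "biased_grouped_topk: correction_bias must be float32, got ",
      correction_bias.scalar_type());
}
```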
The CI failure is not related to the change in the current PR.
correction_bias.scalar_type(),
"biased_grouped_topk_kernel2",
[&] {
  using bias_scalar_t = scalar_t;
scalar_t will be the dtype from correction_bias in AT_DISPATCH_FLOATING_TYPES_AND2,
yet LAUNCH_BIASED_GROUPED_TOPK_KERNEL uses gating_output.data_ptr<scalar_t>().
Isn't this going to be problematic?
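To make the concern concrete, here is a hedged, simplified illustration (placeholder function name, not this PR's code): if the dispatch dtype comes from `correction_bias` while `gating_output` keeps its own dtype, the two can diverge.

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Simplified illustration: after the upstream fix, correction_bias is float32,
// so scalar_t below becomes float even when gating_output is bf16/fp16; the
// data_ptr<scalar_t>() call on gating_output then fails at runtime because the
// requested type no longer matches the tensor's dtype.
void illustrate_dtype_mismatch(const at::Tensor& gating_output, const at::Tensor& correction_bias) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kBFloat16, at::kHalf, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
        const scalar_t* gating = gating_output.data_ptr<scalar_t>();  // mismatch when dtypes differ
        (void)gating;
      });
}
```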
Have updated this part with the following suggestions and fixed this issue. Thanks.
To make the code change cleaner, we can do something like this: rewrite set

Also, you can put a helper function in; then the code will be neater.
Got it, will refine the code accordingly. Thanks!
mingfeima left a comment
generally LGTM! just a few minor changes needed!
fVec x0 = fVec::loadu(scores + d) + bias0;
fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
using bVec = at::vec::Vectorized<param_t>;
auto vec_size = fVec::size() * 2 == bVec::size() ? bVec::size() : fVec::size() * 2;
Isn't vec_size always 32 on AVX-512 here? Just use:
template <typename scalar_t, typename param_t, int SIZE>
inline void apply_bias(...) {
  using bVec = at::vec::Vectorized<scalar_t>;
}
will do all the work.
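For reference, a hedged sketch of that helper shape, assuming a float32 bias pointer, a reduced-precision activation type, and a `SIZE` that is a multiple of `bVec::size()` (names are illustrative, not the merged code):

```cpp
#include <ATen/cpu/vec/vec.h>
#include <type_traits>

// Illustrative only: for bf16/fp16 scalar_t on AVX-512, bVec::size() is 32,
// which is exactly two float vectors per step, so no separate vec_size
// computation is needed. Assumes SIZE is a multiple of bVec::size().
template <typename scalar_t, int SIZE>
inline void apply_bias_sketch(float* __restrict__ scores, const float* __restrict__ bias) {
  static_assert(!std::is_same_v<scalar_t, float>, "sketch covers the reduced-precision path only");
  using bVec = at::vec::Vectorized<scalar_t>;  // e.g. 32 bf16 lanes on AVX-512
  using fVec = at::vec::Vectorized<float>;     // 16 float lanes on AVX-512
  for (int d = 0; d < SIZE; d += bVec::size()) {
    fVec x0 = fVec::loadu(scores + d) + fVec::loadu(bias + d);
    fVec x1 = fVec::loadu(scores + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
    x0.store(scores + d);
    x1.store(scores + d + fVec::size());
  }
}
```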
Yes, got the point, and refined it.
template <
    typename scalar_t,
    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t> || std::is_same_v<scalar_t, float>, int> = 1>
Suggested change:
-template <
-    typename scalar_t,
-    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t> || std::is_same_v<scalar_t, float>, int> = 1>
+template <typename scalar_t,
+          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
template <>
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2<float>(const float* __restrict__ data)
Suggested change:
-template <>
-inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2<float>(const float* __restrict__ data) {
+inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
Let's see if using an overload can be compiled; we can save 2 or 3 lines here.
Yes, changed to use an overload.
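For reference, a hedged sketch of what the overload pair can look like, assuming `at::vec::convert_to_float` is available for bf16/fp16 as in the existing kernel (the function name below is illustrative):

```cpp
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <tuple>

// Illustrative overload pair: the template widens bf16/fp16 lanes to two
// float vectors, while the plain float overload just loads two float vectors
// directly, so no explicit template specialization is required.
template <typename scalar_t>
inline std::tuple<at::vec::Vectorized<float>, at::vec::Vectorized<float>>
load_float_vec2_sketch(const scalar_t* __restrict__ data) {
  auto v = at::vec::Vectorized<scalar_t>::loadu(data);
  return at::vec::convert_to_float<scalar_t>(v);
}

inline std::tuple<at::vec::Vectorized<float>, at::vec::Vectorized<float>>
load_float_vec2_sketch(const float* __restrict__ data) {
  using fVec = at::vec::Vectorized<float>;
  return std::make_tuple(fVec::loadu(data), fVec::loadu(data + fVec::size()));
}
```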
  return std::make_tuple(x0, x1);
}

// specification for `float`
Suggested change:
-// specification for `float`
// dispatch mixed precision of TYPE2 with TYPE1:
//   TYPE2: float32
//     with TYPE1: bfloat16, float16
//   TYPE2: bfloat16, float16
//     with TYPE1: same as TYPE2
Suggested change:
-// dispatch mixed precision of TYPE2 with TYPE1:
-//   TYPE2: float32
-//     with TYPE1: bfloat16, float16
-//   TYPE2: bfloat16, float16
-//     with TYPE1: same as TYPE2
+// dispatch with mixed dtypes (TYPE1, TYPE2):
+//   TYPE1: the primary dtype (input, output, weight);
+//   TYPE2: the secondary dtype (bias, etc.).
} \
} \
}()
Good job! Exactly what I meant :)
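To illustrate the (TYPE1, TYPE2) pairing that comment documents, here is a hedged, macro-free sketch (helper names are placeholders, not the macro in this PR): TYPE1 follows the activation dtype, TYPE2 is either float32 or the same as TYPE1.

```cpp
#include <ATen/ATen.h>

// Placeholder impl: scalar_t is the primary dtype (gating_output), param_t is
// the secondary dtype (correction_bias).
template <typename scalar_t, typename param_t>
void biased_topk_impl_sketch(const at::Tensor& gating_output, const at::Tensor& bias) {
  const scalar_t* gating = gating_output.data_ptr<scalar_t>();
  const param_t* b = bias.data_ptr<param_t>();
  (void)gating;
  (void)b;
}

// Dispatch policy: the secondary dtype is either float32 (the fixed
// correction_bias dtype) or the same reduced-precision type as the primary.
void dispatch_sketch(const at::Tensor& gating_output, const at::Tensor& bias) {
  const bool bias_is_float = bias.scalar_type() == at::kFloat;
  switch (gating_output.scalar_type()) {
    case at::kBFloat16:
      bias_is_float ? biased_topk_impl_sketch<at::BFloat16, float>(gating_output, bias)
                    : biased_topk_impl_sketch<at::BFloat16, at::BFloat16>(gating_output, bias);
      break;
    case at::kHalf:
      bias_is_float ? biased_topk_impl_sketch<at::Half, float>(gating_output, bias)
                    : biased_topk_impl_sketch<at::Half, at::Half>(gating_output, bias);
      break;
    default:
      TORCH_CHECK(false, "unsupported gating_output dtype");
  }
}
```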
Checked that the CI failures are not related to this PR.
Hi @zhyncs, could you kindly help review this PR? Thanks!
…loat32 (sgl-project#8212) Co-authored-by: jianan-gu <jianan.gu@intel.com> Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
Motivation
#7825 fixed the dtype of `e_score_correction_bias` to float32 instead of bfloat16, so the support in the topk kernel needs to be fixed to match. The UT has been updated accordingly.