Commit 0b97584

minor opt for cast input

1 parent a995bdd, commit 0b97584
1 file changed: sgl-kernel/csrc/cpu/gemm.cpp (37 additions, 20 deletions)
@@ -84,6 +84,26 @@ inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input
   }
 }
 
+template <typename scalar_t>
+inline void copy_stub(float* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int kVecSize = bVec::size();
+
+  int64_t d;
+#pragma GCC unroll 4
+  for (d = 0; d <= size - kVecSize; d += kVecSize) {
+    fVec data0, data1;
+    bVec b_vec = bVec::loadu(input + d);
+    std::tie(data0, data1) = at::vec::convert_to_float(b_vec);
+    data0.store(out + d);
+    data1.store(out + d + fVec::size());
+  }
+  for (; d < size; ++d) {
+    out[d] = static_cast<float>(input[d]);
+  }
+}
+
 template <typename scalar_t>
 inline void copy_add_stub(
     scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
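Note on the new copy_stub overload above: it widens a row of reduced-precision activations (at::BFloat16 or at::Half) into float32, converting bVec::size() elements per iteration via at::vec::convert_to_float and finishing with a scalar tail loop. The following standalone sketch shows the same computation in scalar form for the bf16 case; bf16_to_float and copy_cast_reference are illustrative names, not part of this commit.

#include <cstdint>
#include <cstring>

// bf16 stand-in: the value is the upper 16 bits of an IEEE-754 float.
static float bf16_to_float(uint16_t x) {
  uint32_t bits = static_cast<uint32_t>(x) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Scalar reference for the widening copy: out[d] = float(input[d]) for d in [0, size).
// The vectorized overload in the diff produces the same result, two fVec stores
// per loaded bVec.
static void copy_cast_reference(float* out, const uint16_t* input, int64_t size) {
  for (int64_t d = 0; d < size; ++d) {
    out[d] = bf16_to_float(input[d]);
  }
}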
@@ -281,15 +301,6 @@ struct brgemm {
       int64_t ldc) {
     constexpr int BLOCK_N = block_size_n();
     at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
-
-    // copy from Ctmp to C
-    for (int64_t m = 0; m < M; ++m) {
-      if constexpr (has_bias) {
-        copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
-      } else {
-        copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
-      }
-    }
   }
 };
 
@@ -461,7 +472,7 @@ void weight_packed_linear_kernel_impl(
 template <typename scalar_t>
 void weight_packed_linear_kernel_impl(
     scalar_t* __restrict__ out,
-    const float* __restrict__ mat1,
+    const scalar_t* __restrict__ mat1,
     const float* __restrict__ mat2,
     const float* __restrict__ bias,
     const scalar_t* __restrict__ post_mul_mat,
@@ -476,21 +487,23 @@ void weight_packed_linear_kernel_impl(
   const int64_t NB = div_up(N, BLOCK_N);
 
   const bool use_brgemm = true; // TODO: add intrinsic path
-
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
     parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
       // for brgemm, use float32 for accumulate
+      alignas(64) float Atmp[BLOCK_M * K];
       alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
 
       loop_2d<float>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         int64_t mb_start = mb * BLOCK_M;
         int64_t mb_size = std::min(M - mb_start, BLOCK_M);
         int64_t nb_start = nb * BLOCK_N;
         int64_t nb_size = std::min(N - nb_start, BLOCK_N);
-
+        for (int64_t m = 0; m < mb_size; ++m) {
+          copy_stub<scalar_t>(Atmp + m * K, mat1 + mb_start * mat1_strideM + m * K, K);
+        }
         tinygemm_kernel<scalar_t, has_bias>(
-            /* A */ mat1 + mb_start * mat1_strideM,
+            /* A */ Atmp,
            /* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
            /* C */ out + mb_start * out_strideM + nb_start,
            /* Ctmp*/ Ctmp,
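The hunk above is the core of the "minor opt for cast input": instead of upcasting the whole mat1 tensor to float ahead of time, each task now widens only its own BLOCK_M x K activation panel into the stack scratch Atmp just before the microkernel call. A minimal standalone sketch of that per-panel pack in scalar form follows; pack_a_panel is an illustrative name, and the strided row indexing shown here is an assumption of the sketch rather than a copy of the patch's addressing.

#include <cstdint>

// Widen one BLOCK_M x K activation panel into a contiguous float scratch buffer.
// src_t stands for the reduced-precision element type (at::BFloat16 / at::Half in
// the kernel); the inner loop is the scalar form of the new copy_stub overload.
template <typename src_t>
static void pack_a_panel(
    float* Atmp, const src_t* panel, int64_t mb_size, int64_t K, int64_t row_stride) {
  for (int64_t m = 0; m < mb_size; ++m) {
    const src_t* src = panel + m * row_stride;  // row m of the source panel
    float* dst = Atmp + m * K;                  // packed row m of the scratch
    for (int64_t k = 0; k < K; ++k) {
      dst[k] = static_cast<float>(src[k]);      // widen to fp32 for accumulation
    }
  }
}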
@@ -512,6 +525,15 @@ void weight_packed_linear_kernel_impl(
               post_mul_mat + mb_start * out_strideM + m * out_strideM,
               out_strideM);
           }
+        } else {
+          for (int64_t m = 0; m < mb_size; ++m) {
+            if constexpr (has_bias) {
+              copy_add_stub(
+                  out + mb_start * out_strideM + nb_start + m * out_strideM, Ctmp + m * BLOCK_N, bias + nb_start, N);
+            } else {
+              copy_stub(out + mb_start * out_strideM + nb_start + m * out_strideM, Ctmp + m * BLOCK_N, N);
+            }
+          }
         }
       });

@@ -657,10 +679,6 @@ weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at:
   int64_t out_strideM = out.stride(0);
   int64_t mat1_strideM = mat1.stride(0);
 
-  if (use_fma_gemm) {
-    mat1 = mat1.to(at::kFloat);
-  }
-
   const bool has_bias = bias.has_value();
   const float* bias_data = nullptr;
   if (has_bias) {
@@ -672,7 +690,7 @@ weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at:
     if (use_fma_gemm) {
       weight_packed_linear_kernel_impl<scalar_t>(
           out.data_ptr<scalar_t>(),
-          mat1.data_ptr<float>(),
+          mat1.data_ptr<scalar_t>(),
          packed_w.data_ptr<float>(),
          bias_data,
          nullptr,
@@ -728,7 +746,6 @@ at::Tensor fused_linear_sigmoid_mul(
   int64_t mat1_strideM = mat1.stride(0);
   auto dispatch_type = mat1.scalar_type();
   auto out = at::empty({M, out_strideM}, mat1.options());
-  mat1 = mat1.to(at::kFloat);
 
   TORCH_CHECK(
       N == 1 && out_strideM % 32 == 0,
@@ -744,7 +761,7 @@ at::Tensor fused_linear_sigmoid_mul(
   AT_DISPATCH_REDUCED_FLOATING_TYPES(dispatch_type, "fused_linear_sigmoid_mul", [&] {
     weight_packed_linear_kernel_impl<scalar_t>(
         out.data_ptr<scalar_t>(),
-        mat1.data_ptr<float>(),
+        mat1.data_ptr<scalar_t>(),
        packed_w.data_ptr<float>(),
        bias_data,
        post_mul_mat.data_ptr<scalar_t>(),
