@@ -84,6 +84,29 @@ inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input
   }
 }
 
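+// Widening counterpart of copy_stub above: converts reduced-precision (bf16/fp16) lanes to fp32.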
+template <typename scalar_t>
+inline void copy_stub(float* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int kVecSize = bVec::size();
+
+  int64_t d;
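+  // Main loop: one scalar_t vector load widens into two fp32 vector stores per iteration.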
+#pragma GCC unroll 4
+  for (d = 0; d <= size - kVecSize; d += kVecSize) {
+    fVec data0, data1;
+    bVec b_vec = bVec::loadu(input + d);
+    std::tie(data0, data1) = at::vec::convert_to_float(b_vec);
+    data0.store(out + d);
+    data1.store(out + d + fVec::size());
+  }
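+  // Scalar tail for the remainder that does not fill a full vector.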
+  for (; d < size; ++d) {
+    out[d] = static_cast<float>(input[d]);
+  }
+}
+
 template <typename scalar_t>
 inline void copy_add_stub(
     scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
@@ -281,15 +301,6 @@ struct brgemm {
       int64_t ldc) {
     constexpr int BLOCK_N = block_size_n();
     at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
-
-    // copy from Ctmp to C
-    for (int64_t m = 0; m < M; ++m) {
-      if constexpr (has_bias) {
-        copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
-      } else {
-        copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
-      }
-    }
   }
 };
 
@@ -461,7 +472,7 @@ void weight_packed_linear_kernel_impl(
 template <typename scalar_t>
 void weight_packed_linear_kernel_impl(
     scalar_t* __restrict__ out,
-    const float* __restrict__ mat1,
+    const scalar_t* __restrict__ mat1,
     const float* __restrict__ mat2,
     const float* __restrict__ bias,
     const scalar_t* __restrict__ post_mul_mat,
@@ -476,21 +487,24 @@ void weight_packed_linear_kernel_impl(
   const int64_t NB = div_up(N, BLOCK_N);
 
   const bool use_brgemm = true; // TODO: add intrinsic path
-
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
     parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
       // for brgemm, use float32 for accumulate
+      alignas(64) float Atmp[BLOCK_M * K];
       alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
 
       loop_2d<float>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         int64_t mb_start = mb * BLOCK_M;
         int64_t mb_size = std::min(M - mb_start, BLOCK_M);
         int64_t nb_start = nb * BLOCK_N;
         int64_t nb_size = std::min(N - nb_start, BLOCK_N);
-
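+        // Stage the A tile in fp32: widen each scalar_t row of mat1 into Atmp (row stride K).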
+        for (int64_t m = 0; m < mb_size; ++m) {
+          copy_stub<scalar_t>(Atmp + m * K, mat1 + (mb_start + m) * mat1_strideM, K);
+        }
         tinygemm_kernel<scalar_t, has_bias>(
-            /* A */ mat1 + mb_start * mat1_strideM,
+            /* A */ Atmp,
             /* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
             /* C */ out + mb_start * out_strideM + nb_start,
             /* Ctmp*/ Ctmp,
@@ -512,6 +525,16 @@ void weight_packed_linear_kernel_impl(
                 post_mul_mat + mb_start * out_strideM + m * out_strideM,
                 out_strideM);
           }
+        } else {
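+          // No post-mul epilogue: flush the fp32 accumulator tile to the output (moved here from brgemm::apply).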
+          for (int64_t m = 0; m < mb_size; ++m) {
+            if constexpr (has_bias) {
+              copy_add_stub(
+                  out + (mb_start + m) * out_strideM + nb_start, Ctmp + m * BLOCK_N, bias + nb_start, nb_size);
+            } else {
+              copy_stub(out + (mb_start + m) * out_strideM + nb_start, Ctmp + m * BLOCK_N, nb_size);
+            }
+          }
         }
       });
 
@@ -657,10 +679,6 @@ weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at:
   int64_t out_strideM = out.stride(0);
   int64_t mat1_strideM = mat1.stride(0);
 
-  if (use_fma_gemm) {
-    mat1 = mat1.to(at::kFloat);
-  }
-
   const bool has_bias = bias.has_value();
   const float* bias_data = nullptr;
   if (has_bias) {
@@ -672,7 +690,7 @@ weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at:
     if (use_fma_gemm) {
       weight_packed_linear_kernel_impl<scalar_t>(
           out.data_ptr<scalar_t>(),
-          mat1.data_ptr<float>(),
+          mat1.data_ptr<scalar_t>(),
           packed_w.data_ptr<float>(),
           bias_data,
           nullptr,
@@ -728,7 +746,6 @@ at::Tensor fused_linear_sigmoid_mul(
   int64_t mat1_strideM = mat1.stride(0);
   auto dispatch_type = mat1.scalar_type();
   auto out = at::empty({M, out_strideM}, mat1.options());
-  mat1 = mat1.to(at::kFloat);
 
   TORCH_CHECK(
       N == 1 && out_strideM % 32 == 0,
@@ -744,7 +761,7 @@ at::Tensor fused_linear_sigmoid_mul(
   AT_DISPATCH_REDUCED_FLOATING_TYPES(dispatch_type, "fused_linear_sigmoid_mul", [&] {
     weight_packed_linear_kernel_impl<scalar_t>(
         out.data_ptr<scalar_t>(),
-        mat1.data_ptr<float>(),
+        mat1.data_ptr<scalar_t>(),
         packed_w.data_ptr<float>(),
         bias_data,
         post_mul_mat.data_ptr<scalar_t>(),