 
 namespace deep_gemm {
 
-template <int kNumFormerIters, int kGap, int kEnd>
-__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
+template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
+__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
     if (num_former_iters == kNumFormerIters) {
         inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
         return;
@@ -54,7 +54,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
 
     // Shared memory
-    static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
+    static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
     static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
@@ -101,7 +101,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
     // Fill shared memory pointers
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
         smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
         smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
@@ -111,7 +111,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // Fill barriers
     auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         full_barriers[i] = barrier_start_ptr + i;
         empty_barriers[i] = barrier_start_ptr + kNumStages + i;
     }
@@ -122,7 +122,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
         // even with TMA multicast disabled, we want to make the behavior aligned
         #pragma unroll
-        for (int i = 0; i < kNumStages; ++ i) {
+        for (uint32_t i = 0; i < kNumStages; ++ i) {
             full_barriers[i]->init(1);
             empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
         }
@@ -138,28 +138,33 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // For pipeline unrolling
     struct DivisibleK {};
     struct NotDivisibleK {};
-    auto launch_k_iterations = [](const auto& func, int num_former_iters) {
+    struct SkipComputation {};
+    struct NotSkipComputation {};
+    auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) {
         constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
-        constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
-        constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
+        constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
+        constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
 
         // NOTES: for too-many branches (> 5), we disable this optimization
         // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
-        outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
-            if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
-                for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
-                    func(k_iter, DivisibleK{}, num_former_iters_type);
+        outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
+            if (skip_computation) {
+                for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
+            } else if (SHAPE_K % kFullKOfAllStages == 0) {
+                for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
             } else {
-                for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
-                    func(k_iter, DivisibleK{}, num_former_iters_type);
-                func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
+                for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
+                    func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
+                func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
             }
         }, func, kShouldOptimize ? num_former_iters : 0);
     };
 
     // Register reconfigurations
-    constexpr int kNumTMARegisters = 40;
-    constexpr int kNumMathRegisters = 232;
+    constexpr uint32_t kNumTMARegisters = 40;
+    constexpr uint32_t kNumMathRegisters = 232;
 
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
@@ -173,10 +178,9 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         if (threadIdx.x == kNumMathThreads) {
             // Persistently schedule over blocks
             while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
-                launch_k_iterations([&](int k_iter, auto type, auto _) {
-                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-                    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-                    DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+                launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
+                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
+                    constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
 
                     // Assign TMA multicast number into A and B
                     // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
@@ -194,7 +198,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
                         // Issue TMA A
                         auto& full_barrier = *full_barriers[s];
-                        int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
+                        uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
                         tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                  smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
                                  num_tma_multicast_a);
@@ -216,7 +220,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                         empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                         full_barriers[s]->arrive();
                     }
-                }, 0);
+                }, false, 0);
             }
 
             // To safely deconstruct distributed shared barriers, we need another round of empty waits
@@ -257,12 +261,12 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
             cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
             // Accumulation for WGMMA or CUDA promotion
-            constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
+            constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
             DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
             float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
 
             // Empty barrier arrival
-            auto empty_barrier_arrive = [&](int s) {
+            auto empty_barrier_arrive = [&](uint32_t s) {
                 if constexpr (kNumTMAMulticast == 1) {
                     lane_idx == 0 ? empty_barriers[s]->arrive() : void();
                 } else {
@@ -272,13 +276,14 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
             };
 
             // Launch MMAs
-            launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
-                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+            launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
+                constexpr bool kSkipComputation = std::is_same_v<decltype(skip_type), SkipComputation>;
+                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
+                constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 :
+                    (kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K);
 
                 #pragma unroll
-                for (int s = 0; s < kNumInnerStages; ++ s) {
+                for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                     // Read B scales
                     float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
                     // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
@@ -300,18 +305,18 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
                         // Commit WGMMA instructions
                         #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                        for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                             warpgroup_fence_operand(accum[i]);
                         warpgroup_arrive();
                         #pragma unroll
-                        for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                        for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                             auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
                             auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                             WGMMA::wgmma(desc_a, desc_b, accum, k);
                         }
                         warpgroup_commit_batch();
                         #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                        for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                             warpgroup_fence_operand(accum[i]);
                         warpgroup_wait<0>();
 
@@ -328,7 +333,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
                         auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                         #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                        for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                             // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
                             bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
                             shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
@@ -345,7 +350,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                     full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                     empty_barrier_arrive(s);
                 }
-            }, num_former_iters);
+            }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);
 
             // TMA checks
             constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
@@ -355,7 +360,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
             DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
                              "Unaligned TMA store or too many TMA store instructions");
             DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
-            DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
+            DG_STATIC_ASSERT(static_cast<uint32_t>(kSwizzleDMode > 0) + static_cast<uint32_t>(BLOCK_N_PADDING > 0) <= 1,
                              "Swizzling and padding are not compatible");
 
             // Wait last TMA store to be finished
@@ -375,7 +380,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                 uint8_t* smem_ptr = nullptr;
                 if constexpr (kSwizzleDMode > 0) {
                     // Calculate the swizzling atom offset and in-atom offset
-                    constexpr int kNumBankGroupBytes = 16;
+                    constexpr uint32_t kNumBankGroupBytes = 16;
                     auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
 
                     // Calculate the index of the bank group to be written in the atom
@@ -436,4 +441,4 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
 }; // namespace deep_gemm
 
-#pragma clang diagnostic pop
+#pragma clang diagnostic pop