Commit 8dfa329

Authored by yukuai26 and LyricZhao
Grouped GEMM skip useless computation for unaligned Ms (#103)
* Grouped GEMM skip useless computation for unaligned Ms
* Update readme.md
* small typo
* Rename variables
* Restore previous indent
* Format
* Refactor tests
* Add `SkipComputation` types
* Bug fixed
* Format
* Fix tests
* Add assertions
* Minor fix

---------

Co-authored-by: yukuai <yukuai@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
1 parent 391755a · commit 8dfa329

5 files changed: 106 additions & 93 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 - [ ] Larger block size on N (up to 256)
 - [x] MoE scheduler with TMA multicast compatibility
 - [x] Fix TMA multicast compatibility for indivisible shapes
-- [ ] Skip useless computation on M
+- [x] Skip useless computation on M
 - [x] NVRTC as a faster compiler
 - [ ] Stolen JIT cache
 - [ ] Sanitizer for testing

deep_gemm/include/deep_gemm/fp8_gemm.cuh

Lines changed: 44 additions & 39 deletions
@@ -17,8 +17,8 @@
 
 namespace deep_gemm {
 
-template <int kNumFormerIters, int kGap, int kEnd>
-__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
+template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
+__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
     if (num_former_iters == kNumFormerIters) {
         inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
         return;
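Though the recursion itself lies outside this hunk, the visible context shows the pattern: `outer_launch_k_iterations` compares the runtime `num_former_iters` against the template constant `kNumFormerIters` and, on a match, hands the value back as `cute::Int<kNumFormerIters>{}`, so the callee receives it as a type. A minimal standalone sketch of this runtime-to-compile-time dispatch, with hypothetical names and `std::integral_constant` standing in for `cute::Int`:

    #include <cstdio>
    #include <cstdint>
    #include <type_traits>

    // Try each candidate constant in steps of kGap up to kEnd; on a match,
    // re-expose the runtime value as a compile-time constant (a type).
    template <uint32_t kCurrent, uint32_t kGap, uint32_t kEnd>
    void dispatch_to_constant(const auto& func, uint32_t runtime_value) {
        if (runtime_value == kCurrent) {
            func(std::integral_constant<uint32_t, kCurrent>{});
            return;
        }
        if constexpr (kCurrent + kGap <= kEnd)
            dispatch_to_constant<kCurrent + kGap, kGap, kEnd>(func, runtime_value);
    }

    int main() {
        // The lambda receives the value as a type, so it can feed `constexpr` logic
        dispatch_to_constant<0, 4, 16>([](auto constant) {
            constexpr auto v = decltype(constant)::value;
            std::printf("dispatched constant: %u\n", static_cast<unsigned>(v));
        }, 8u);
    }

Presumably the template parameters move from `int` to `uint32_t` so the ladder's arithmetic stays in the same unsigned type as the runtime argument.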
@@ -54,7 +54,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
 
     // Shared memory
-    static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
+    static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
     static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
@@ -101,7 +101,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
     // Fill shared memory pointers
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
         smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
         smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
@@ -111,7 +111,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // Fill barriers
     auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         full_barriers[i] = barrier_start_ptr + i;
         empty_barriers[i] = barrier_start_ptr + kNumStages + i;
     }
@@ -122,7 +122,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
     // even with TMA multicast disabled, we want to make the behavior aligned
     #pragma unroll
-    for (int i = 0; i < kNumStages; ++ i) {
+    for (uint32_t i = 0; i < kNumStages; ++ i) {
         full_barriers[i]->init(1);
         empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
     }
@@ -138,28 +138,33 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     // For pipeline unrolling
     struct DivisibleK {};
     struct NotDivisibleK {};
-    auto launch_k_iterations = [](const auto& func, int num_former_iters) {
+    struct SkipComputation {};
+    struct NotSkipComputation {};
+    auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) {
         constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
-        constexpr int kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
-        constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
+        constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
+        constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
 
         // NOTES: for too-many branches (> 5), we disable this optimization
         // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
-        outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) {
-            if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
-                for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
-                    func(k_iter, DivisibleK{}, num_former_iters_type);
+        outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
+            if (skip_computation) {
+                for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
+            } else if (SHAPE_K % kFullKOfAllStages == 0) {
+                for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
             } else {
-                for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
-                    func(k_iter, DivisibleK{}, num_former_iters_type);
-                func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
+                for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
+                    func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
+                func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
             }
         }, func, kShouldOptimize ? num_former_iters : 0);
     };
 
     // Register reconfigurations
-    constexpr int kNumTMARegisters = 40;
-    constexpr int kNumMathRegisters = 232;
+    constexpr uint32_t kNumTMARegisters = 40;
+    constexpr uint32_t kNumMathRegisters = 232;
 
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
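The new `SkipComputation` / `NotSkipComputation` structs reuse the tag-dispatch idiom already employed for `DivisibleK` / `NotDivisibleK`: a runtime branch picks a tag type once, and inside the worker lambda `std::is_same_v` recovers the decision as a `constexpr bool`, so the compiler emits a separate, fully pruned instantiation for the skip path instead of a runtime guard on every stage. A reduced host-side sketch of the idiom, with hypothetical names:

    #include <cstdio>
    #include <type_traits>

    struct Skip {};
    struct NoSkip {};

    // The runtime flag chooses the tag; each choice instantiates its own body
    void launch(const auto& func, bool skip) {
        skip ? func(Skip{}) : func(NoSkip{});
    }

    int main() {
        auto worker = [](auto tag) {
            // The runtime decision is now a compile-time constant
            constexpr bool kSkip = std::is_same_v<decltype(tag), Skip>;
            constexpr int kInnerStages = kSkip ? 0 : 8;
            // With a constant bound of 0 the loop disappears entirely
            for (int s = 0; s < kInnerStages; ++ s)
                std::printf("stage %d\n", s);
        };
        launch(worker, true);   // body compiles to nothing
        launch(worker, false);  // full 8-stage body
    }

Note also that the inner lambda now captures by value (`[=]`) so `skip_computation` is visible inside the callback passed to `outer_launch_k_iterations`.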
@@ -173,10 +178,9 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
     if (threadIdx.x == kNumMathThreads) {
         // Persistently schedule over blocks
         while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
-            launch_k_iterations([&](int k_iter, auto type, auto _) {
-                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+            launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
+                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
+                constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
 
                 // Assign TMA multicast number into A and B
                 // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
@@ -194,7 +198,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
                     // Issue TMA A
                     auto& full_barrier = *full_barriers[s];
-                    int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
+                    uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
                     tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                              smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
                              num_tma_multicast_a);
@@ -216,7 +220,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                     empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                     full_barriers[s]->arrive();
                 }
-            }, 0);
+            }, false, 0);
         }
 
         // To safely deconstruct distributed shared barriers, we need another round of empty waits
@@ -257,12 +261,12 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
         // Accumulation for WGMMA or CUDA promotion
-        constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
+        constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
         DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
         float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
 
         // Empty barrier arrival
-        auto empty_barrier_arrive = [&](int s) {
+        auto empty_barrier_arrive = [&](uint32_t s) {
            if constexpr (kNumTMAMulticast == 1) {
                lane_idx == 0 ? empty_barriers[s]->arrive() : void();
            } else {
@@ -272,13 +276,14 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         };
 
         // Launch MMAs
-        launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
-            constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-            constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-            DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
+        launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
+            constexpr bool kSkipComputation = std::is_same_v<decltype(skip_type), SkipComputation>;
+            constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
+            constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 :
+                (kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K);
 
             #pragma unroll
-            for (int s = 0; s < kNumInnerStages; ++ s) {
+            for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
                 // Read B scales
                 float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
                 // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
@@ -300,18 +305,18 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
                 // Commit WGMMA instructions
                 #pragma unroll
-                for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                     warpgroup_fence_operand(accum[i]);
                 warpgroup_arrive();
                 #pragma unroll
-                for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                     auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
                     auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
                     WGMMA::wgmma(desc_a, desc_b, accum, k);
                 }
                 warpgroup_commit_batch();
                 #pragma unroll
-                for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                     warpgroup_fence_operand(accum[i]);
                 warpgroup_wait<0>();
 
@@ -328,7 +333,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
                 auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                 #pragma unroll
-                for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                     // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
                     bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
                     shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
@@ -345,7 +350,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                 full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                 empty_barrier_arrive(s);
             }
-        }, num_former_iters);
+        }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);
 
         // TMA checks
         constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
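The skip flag computed here is per math warpgroup: `math_wg_idx * WGMMA::M` is the first row the warpgroup owns inside the block, and `is_computation_valid` (added in `scheduler.cuh` below) reports whether that row still falls inside the group's valid M range. Note from the context lines above that a skipping warpgroup does not simply return: with `kNumInnerStages` collapsed to 0, every stage falls into the trailing wait loop, so the warpgroup still waits on `full_barriers` and arrives on `empty_barriers`, keeping the TMA producer pipeline in lockstep. A small worked example under assumed shapes:

    #include <cstdio>
    #include <cstdint>

    // Assumed shapes: BLOCK_M = 128, WGMMA::M = 64, first m-block of a
    // masked group that has only 40 valid rows.
    int main() {
        constexpr uint32_t kBlockM = 128, kWgmmaM = 64;
        const uint32_t masked_m = 40;  // valid rows reported by `grouped_layout`
        for (uint32_t wg = 0; wg < kBlockM / kWgmmaM; ++ wg) {
            const bool skip = wg * kWgmmaM >= masked_m;  // mirrors the predicate above
            std::printf("warpgroup %u, rows [%u, %u): %s\n",
                        wg, wg * kWgmmaM, (wg + 1) * kWgmmaM,
                        skip ? "skip WGMMA, still service barriers" : "compute");
        }
    }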
@@ -355,7 +360,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
         DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
                          "Unaligned TMA store or too many TMA store instructions");
         DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
-        DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
+        DG_STATIC_ASSERT(static_cast<uint32_t>(kSwizzleDMode > 0) + static_cast<uint32_t>(BLOCK_N_PADDING > 0) <= 1,
                          "Swizzling and padding are not compatible");
 
         // Wait last TMA store to be finished
@@ -375,7 +380,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                 uint8_t* smem_ptr = nullptr;
                 if constexpr (kSwizzleDMode > 0) {
                     // Calculate the swizzling atom offset and in-atom offset
-                    constexpr int kNumBankGroupBytes = 16;
+                    constexpr uint32_t kNumBankGroupBytes = 16;
                     auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
 
                     // Calculate the index of the bank group to be written in the atom
@@ -436,4 +441,4 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
 
 };  // namespace deep_gemm
 
-#pragma clang diagnostic pop
+#pragma clang diagnostic pop

deep_gemm/include/deep_gemm/scheduler.cuh

Lines changed: 14 additions & 3 deletions
@@ -34,7 +34,7 @@ struct Scheduler {
     // Only used for masked layout
     uint32_t curr_group_idx, curr_cumsum;
 
-    __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
+    __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m,
                                                   int* grouped_layout = nullptr) {
         num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
         if constexpr (kGemmType == GemmType::Normal) {
@@ -48,6 +48,17 @@ struct Scheduler {
         }
     }
 
+    // ReSharper disable once CppNotAllPathsReturnValue
+    __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
+        if constexpr (kGemmType == GemmType::Normal) {
+            return true;
+        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
+            return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
+        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+            return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
+        }
+    }
+
     __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
         if (num_blocks_in_group == 1)
             return false;
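Reading the three `constexpr` branches gives the layout conventions this helper assumes: for `GroupedContiguous`, `grouped_layout` holds one entry per row and a negative entry marks a padding row; for `GroupedMasked`, it holds one entry per group, the count of that group's valid rows. A host-side sketch of the same checks, as hypothetical standalone functions with an assumed block size:

    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t BLOCK_M = 128;  // assumed block size

    // Contiguous layout: one entry per row, negative means padding
    bool valid_contiguous(const int* grouped_layout, uint32_t m_block_idx, uint32_t m_offset) {
        return grouped_layout[m_offset + m_block_idx * BLOCK_M] >= 0;
    }

    // Masked layout: one entry per group, the valid row count of that group
    bool valid_masked(const int* grouped_layout, uint32_t curr_group_idx,
                      uint32_t m_block_idx, uint32_t m_offset) {
        return m_offset + m_block_idx * BLOCK_M
             < static_cast<uint32_t>(grouped_layout[curr_group_idx]);
    }

    int main() {
        const int masked_layout[2] = {200, 40};  // groups with 200 and 40 valid rows
        // Row 64 of group 1's first block lies past its 40 valid rows -> invalid
        std::printf("%d\n", valid_masked(masked_layout, 1, 0, 64));  // prints 0
    }

The `// ReSharper disable once CppNotAllPathsReturnValue` annotation is presumably there because the `if constexpr` chain covers every `GemmType`, but the static analyzer cannot prove that all paths return.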
@@ -65,7 +76,7 @@ struct Scheduler {
         }
     }
 
-    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx,
+    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx,
                                                            uint32_t& m_block_idx, uint32_t& n_block_idx) {
         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
 
@@ -100,7 +111,7 @@ struct Scheduler {
     }
 
     template <bool kIgnoreGroupedForGroupedContiguous=true>
-    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
+    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size,
                                                        const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
         if constexpr (kGemmType == GemmType::Normal) {
             return block_idx * block_size;

deep_gemm/jit/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def flags() -> List[str]:
                 '--ptxas-options=--register-usage-level=10' +
                 (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''),
                 # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
-                '--diag-suppress=39,161,174,177,940']
+                '--diag-suppress=39,161,174,177,186,940']
 
     @staticmethod
     def include_dirs() -> List[str]:
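The newly suppressed diagnostic 186 is, as far as I recall NVCC's (EDG's) numbering, "pointless comparison of unsigned integer with zero"; treat that mapping as an assumption. It plausibly starts firing with this commit because loop counters and bounds moved from `int` to `uint32_t`: once the compiler specializes `num_former_iters` to the constant 0, a predicate such as `i < num_former_iters` becomes an always-false unsigned comparison. A tiny repro sketch:

    #include <cstdint>

    // With `n` specialized to 0, `i < n` on unsigned operands is always false,
    // which NVCC's front end flags (assumed to be diagnostic #186).
    bool predicate(uint32_t i) {
        constexpr uint32_t n = 0;
        return i < n;  // pointless comparison of unsigned integer with zero
    }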
