Skip to content

Commit ede914e

Browse files
Brandon Musicclaude
andcommitted
fix: K=64 block-scaled GEMM TMA layout fixes for SM120
Two fixes enabling K=64 block-scaled MoE GEMM on SM120 (99KB SMEM): 1. copy_traits_sm90_tma.hpp: Handle zero-stride basis elements in fill_tma_gmem_shape_stride. When a basis element is constant-zero (broadcast dimension for SFVectorSize), basis_get returns the entire tuple instead of a scalar. Detect is_constant<0> and set shape=1, stride=0 directly. 2. sm120_blockscaled_mma_builder.inl: Clamp Blk_SF to min(K/SFVectorSize, Blk_SF) and fold the effective block into kBasicBlockShape when the tile K is too small for the default block size. This keeps outer K dimensions trivial so TMA can construct valid descriptors. For K=64 with SFVectorSize=32: K/SFVectorSize=2 < Blk_SF=4, which previously created a zero-size dimension in the scale factor SMEM layout, triggering "TMA requires CTA_Tile and SLayout top-level size equivalence." Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f3fde58 commit ede914e

2 files changed

Lines changed: 42 additions & 14 deletions

File tree

include/cute/atom/copy_traits_sm90_tma.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -869,8 +869,14 @@ fill_tma_gmem_shape_stride(Tensor<GEngine,GLayout> const& gtensor, /
869869
if constexpr (tma_i_rank == 1) {
870870
// Trivial contribution of this gmem mode to this tma mode
871871
auto ej = unwrap(get<i>(tma_gbasis_stride));
872-
gmem_prob_shape[i] = basis_get(ej, gmem_shape);
873-
gmem_prob_stride[i] = basis_get(ej, gmem_stride);
872+
if constexpr (cute::is_constant<0, decltype(ej)>::value) {
873+
// Zero-stride basis: broadcast dimension (e.g. SFVectorSize), shape=1, stride=0
874+
gmem_prob_shape[i] = 1;
875+
gmem_prob_stride[i] = 0;
876+
} else {
877+
gmem_prob_shape[i] = basis_get(ej, gmem_shape);
878+
gmem_prob_stride[i] = basis_get(ej, gmem_stride);
879+
}
874880
} else {
875881
// Apply a recurrence to each gmem mode that contributes to this tma mode
876882
for_each(get<i>(tma_gbasis_stride), [&](auto ej) {

include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/***************************************************************************************************
2-
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
* SPDX-License-Identifier: BSD-3-Clause
44
*
55
* Redistribution and use in source and binary forms, with or without
@@ -177,31 +177,53 @@ struct CollectiveBuilder<
177177
using SmemCopyAtomsB = decltype(cute::make_tuple(SmemCopyAtomB{}, SmemCopyAtomSFB{}));
178178

179179
// Construct SMEM layout for SF
180-
// A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix).
181-
// 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size
180+
// A single indivisible block will hold Blk_SF (4) scale factors of 128 rows/columns (A/B matrix).
181+
// 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size
182182
using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN;
183-
using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF;
184-
using Blk_Elems = decltype(Blk_MN{} * Blk_SF{});
183+
using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF;
184+
185+
// For tiles where K/SFVectorSize < Blk_SF (e.g. K=64 with SFVectorSize=32 gives only 2 SF
186+
// values along K, but Blk_SF=4), clamp the effective block size to avoid TMA layout issues.
187+
// When EffBlk_SF < Blk_SF AND MMA_NSF < EffBlk_SF, we fold EffBlk_SF into the kBasicBlock
188+
// so that the outer K shape is trivial (all 1s) and gets collapsed by TMA, avoiding nested
189+
// tuple types that can't convert to uint64_t in fill_tma_gmem_shape_stride.
190+
static constexpr int NumSFAlongK = size<2>(TileShape_MNK{}) / SFVectorSize;
191+
using EffBlk_SF = Int<(NumSFAlongK < Blk_SF{}) ? NumSFAlongK : int(Blk_SF{})>;
192+
using EffBlk_Elems = decltype(Blk_MN{} * EffBlk_SF{});
193+
194+
// Determine if we need to fold EffBlk_SF into the basic block to keep TMA layouts flat.
195+
// This is needed when EffBlk_SF > MMA_NSF (i.e. the outer K shape would be non-trivial).
196+
static constexpr bool FoldSFIntoBasicBlock = (NumSFAlongK < Blk_SF{}) && (int(EffBlk_SF{}) > MMA_NSF);
185197

186198
// Basic storage block for new Scaling Factor Layouts
187199
using mnBasicBlockShape = Shape<_32,_4>;
188200
using mnBasicBlockStride = Stride<_16,_4>;
189-
using kBasicBlockShape = Shape<Int<SFVectorSize>, Int<MMA_NSF>>;
201+
// When folding: kBasicBlock absorbs EffBlk_SF, making outer K shape all-1 (trivially collapsed by TMA)
202+
// When not folding: original kBasicBlock with MMA_NSF
203+
using kBasicBlockShape = cute::conditional_t<FoldSFIntoBasicBlock,
204+
Shape<Int<SFVectorSize>, EffBlk_SF>,
205+
Shape<Int<SFVectorSize>, Int<MMA_NSF>>>;
190206
using kBasicBlockStride = Stride<_0, _1>;
191-
207+
208+
// Outer K shape: when folded, both dimensions are 1 (trivial); when not folded, original formula
209+
using OuterK0 = cute::conditional_t<FoldSFIntoBasicBlock, _1, decltype(EffBlk_SF{}/Int<MMA_NSF>{})>;
210+
using OuterK1 = decltype(size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / EffBlk_SF{});
211+
// Outer K stride: first element is EffBlk_SF when folded (doesn't matter since shape=1), MMA_NSF when not
212+
using OuterKS0 = cute::conditional_t<FoldSFIntoBasicBlock, EffBlk_SF, Int<MMA_NSF>>;
213+
192214
using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{}));
193-
using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{}));
215+
using sSF_strideMN = decltype(prepend( EffBlk_Elems{}, mnBasicBlockStride{}));
194216
using sSFA_strideM = sSF_strideMN;
195-
using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
196-
197-
using sSFA_strideK = decltype(prepend(make_stride( Int<MMA_NSF>{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
217+
using sSF_shapeK = decltype(prepend(make_shape( OuterK0{}, OuterK1{}), kBasicBlockShape{}));
218+
219+
using sSFA_strideK = decltype(prepend(make_stride( OuterKS0{}, size<0>(TileShape_MNK{}) / Blk_MN{} * EffBlk_Elems{}), kBasicBlockStride{}));
198220
using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{}));
199221
using sSFA_stride = decltype(make_stride(sSFA_strideM{}, sSFA_strideK{}));
200222
using SmemLayoutAtomSFA = decltype(make_layout( sSFA_shape{}, sSFA_stride{}));
201223

202224
using sSFB_shapeN = decltype(prepend(size<1>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{}));
203225
using sSFB_strideN = sSF_strideMN;
204-
using sSFB_strideK = decltype(prepend(make_stride(Int<MMA_NSF>{}, size<1>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
226+
using sSFB_strideK = decltype(prepend(make_stride(OuterKS0{}, size<1>(TileShape_MNK{}) / Blk_MN{} * EffBlk_Elems{}), kBasicBlockStride{}));
205227
using sSFB_shape = decltype(make_shape( sSFB_shapeN{}, sSF_shapeK{}));
206228
using sSFB_stride = decltype(make_stride(sSFB_strideN{}, sSFB_strideK{}));
207229
using SmemLayoutAtomSFB = decltype(make_layout( sSFB_shape{}, sSFB_stride{}));

0 commit comments

Comments
 (0)