|
1 | 1 | /*************************************************************************************************** |
2 | | - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
3 | 3 | * SPDX-License-Identifier: BSD-3-Clause |
4 | 4 | * |
5 | 5 | * Redistribution and use in source and binary forms, with or without |
@@ -177,31 +177,53 @@ struct CollectiveBuilder< |
177 | 177 | using SmemCopyAtomsB = decltype(cute::make_tuple(SmemCopyAtomB{}, SmemCopyAtomSFB{})); |
178 | 178 |
|
179 | 179 | // Construct SMEM layout for SF |
180 | | - // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). |
181 | | - // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size |
| 180 | + // A single indivisible block will hold Blk_SF (4) scale factors of 128 rows/columns (A/B matrix). |
| 181 | + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size |
182 | 182 | using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; |
183 | | - using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; |
184 | | - using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); |
| 183 | + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; |
| 184 | + |
| 185 | + // For tiles where K/SFVectorSize < Blk_SF (e.g. K=64 with SFVectorSize=32 gives only 2 SF |
| 186 | + // values along K, but Blk_SF=4), clamp the effective block size to avoid TMA layout issues. |
| 187 | + // When EffBlk_SF < Blk_SF AND MMA_NSF < EffBlk_SF, we fold EffBlk_SF into the kBasicBlock |
| 188 | + // so that the outer K shape is trivial (all 1s) and gets collapsed by TMA, avoiding nested |
| 189 | + // tuple types that can't convert to uint64_t in fill_tma_gmem_shape_stride. |
| 190 | + static constexpr int NumSFAlongK = size<2>(TileShape_MNK{}) / SFVectorSize; |
| 191 | + using EffBlk_SF = Int<(NumSFAlongK < Blk_SF{}) ? NumSFAlongK : int(Blk_SF{})>; |
| 192 | + using EffBlk_Elems = decltype(Blk_MN{} * EffBlk_SF{}); |
| 193 | + |
| 194 | + // Determine if we need to fold EffBlk_SF into the basic block to keep TMA layouts flat. |
| 195 | + // This is needed when EffBlk_SF > MMA_NSF (i.e. the outer K shape would be non-trivial). |
| 196 | + static constexpr bool FoldSFIntoBasicBlock = (NumSFAlongK < Blk_SF{}) && (int(EffBlk_SF{}) > MMA_NSF); |
185 | 197 |
|
186 | 198 | // Basic storage block for new Scaling Factor Layouts |
187 | 199 | using mnBasicBlockShape = Shape<_32,_4>; |
188 | 200 | using mnBasicBlockStride = Stride<_16,_4>; |
189 | | - using kBasicBlockShape = Shape<Int<SFVectorSize>, Int<MMA_NSF>>; |
| 201 | + // When folding: kBasicBlock absorbs EffBlk_SF, making outer K shape all-1 (trivially collapsed by TMA) |
| 202 | + // When not folding: original kBasicBlock with MMA_NSF |
| 203 | + using kBasicBlockShape = cute::conditional_t<FoldSFIntoBasicBlock, |
| 204 | + Shape<Int<SFVectorSize>, EffBlk_SF>, |
| 205 | + Shape<Int<SFVectorSize>, Int<MMA_NSF>>>; |
190 | 206 | using kBasicBlockStride = Stride<_0, _1>; |
191 | | - |
| 207 | + |
| 208 | + // Outer K shape: when folded, both dimensions are 1 (trivial); when not folded, original formula |
| 209 | + using OuterK0 = cute::conditional_t<FoldSFIntoBasicBlock, _1, decltype(EffBlk_SF{}/Int<MMA_NSF>{})>; |
| 210 | + using OuterK1 = decltype(size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / EffBlk_SF{}); |
| 211 | + // Outer K stride: first element is EffBlk_SF when folded (doesn't matter since shape=1), MMA_NSF when not |
| 212 | + using OuterKS0 = cute::conditional_t<FoldSFIntoBasicBlock, EffBlk_SF, Int<MMA_NSF>>; |
| 213 | + |
192 | 214 | using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); |
193 | | - using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{})); |
| 215 | + using sSF_strideMN = decltype(prepend( EffBlk_Elems{}, mnBasicBlockStride{})); |
194 | 216 | using sSFA_strideM = sSF_strideMN; |
195 | | - using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{})); |
196 | | - |
197 | | - using sSFA_strideK = decltype(prepend(make_stride( Int<MMA_NSF>{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); |
| 217 | + using sSF_shapeK = decltype(prepend(make_shape( OuterK0{}, OuterK1{}), kBasicBlockShape{})); |
| 218 | + |
| 219 | + using sSFA_strideK = decltype(prepend(make_stride( OuterKS0{}, size<0>(TileShape_MNK{}) / Blk_MN{} * EffBlk_Elems{}), kBasicBlockStride{})); |
198 | 220 | using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{})); |
199 | 221 | using sSFA_stride = decltype(make_stride(sSFA_strideM{}, sSFA_strideK{})); |
200 | 222 | using SmemLayoutAtomSFA = decltype(make_layout( sSFA_shape{}, sSFA_stride{})); |
201 | 223 |
|
202 | 224 | using sSFB_shapeN = decltype(prepend(size<1>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); |
203 | 225 | using sSFB_strideN = sSF_strideMN; |
204 | | - using sSFB_strideK = decltype(prepend(make_stride(Int<MMA_NSF>{}, size<1>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); |
| 226 | + using sSFB_strideK = decltype(prepend(make_stride(OuterKS0{}, size<1>(TileShape_MNK{}) / Blk_MN{} * EffBlk_Elems{}), kBasicBlockStride{})); |
205 | 227 | using sSFB_shape = decltype(make_shape( sSFB_shapeN{}, sSF_shapeK{})); |
206 | 228 | using sSFB_stride = decltype(make_stride(sSFB_strideN{}, sSFB_strideK{})); |
207 | 229 | using SmemLayoutAtomSFB = decltype(make_layout( sSFB_shape{}, sSFB_stride{})); |
|
0 commit comments