Hello, I have encountered a "static assertion failed" error when running gemm on a swizzled layout. I would very much appreciate any help.
#include <cutlass/cutlass.h>
#include <cute/layout.hpp>
#include <cute/tensor.hpp>
#include <cute/arch/mma_sm80.hpp>
using namespace cute;
// Minimal reproducer kernel. It builds a tiled MMA from the
// SM80_16x8x16_F16F16F16F16_TN atom, wraps the dynamic shared-memory buffer in
// three tensors (A, B, C), derives per-thread fragments from them, and issues
// one gemm call on the first slice of the A and B fragments.
// The kernel takes no arguments and stores no results.
// NOTE(review): this is the code the quoted nvcc diagnostic refers to; with the
// swizzled LayoutB below it is reported not to compile.
__global__ void test_kernel() {
// Declare tiled_mma
// One atom, a <_1, _1, _1> atom layout, and a <_16, _16, _16> tile.
auto tiled_mma = make_tiled_mma(
SM80_16x8x16_F16F16F16F16_TN{},
Layout<Shape<_1, _1, _1>>{},
Tile<_16, _16, _16>{}
);
// Slice of the tiled MMA selected by this thread's threadIdx.x.
ThrMMA thr_mma = tiled_mma.get_slice(threadIdx.x);
// Declare shared memory tensor layouts
// A: 128 x 32, strides (32, 1).
using LayoutA = Layout<Shape<Int<128>, Int<32>>, Stride<Int<32>, Int<1>>>;
// B: 128 x 32, strides (1, 128), composed with Swizzle<3, 3, 4>.
// This composition is the only difference from the plain layout commented out
// below, which is reported to compile.
using LayoutB = decltype(composition(Swizzle<3, 3, 4>{}, Layout<Shape<Int<128>, Int<32>>, Stride<Int<1>, Int<128>>>{}));
// using LayoutB = Layout<Shape<Int<128>, Int<32>>, Stride<Int<1>, Int<128>>>; // If using this, the code compiles
// C: 128 x 128, strides (128, 1).
using LayoutC = Layout<Shape<Int<128>, Int<128>>, Stride<Int<128>, Int<1>>>;
// Declare shared memory tensors
// All three tensors are built on the same `smem` base pointer, so they alias
// the same dynamic shared-memory buffer.
// NOTE(review): `smem` is declared as char while the atom name indicates F16
// operands — confirm the element type is intentional for this reproducer.
extern __shared__ char smem[];
Tensor sA = make_tensor(make_smem_ptr(smem), LayoutA{});
Tensor sB = make_tensor(make_smem_ptr(smem), LayoutB{});
Tensor sC = make_tensor(make_smem_ptr(smem), LayoutC{});
// Declare register-file fragments
// NOTE(review): per the quoted diagnostic, the B fragment passed to gemm has a
// ComposedLayout carrying Swizzle<1, 0, -2>, and the B operand that reaches
// mma_unpack has a dynamic `int` stride. Presumably the swizzle carried over
// from sB is what trips the size(rB) == Int<RegNumB> static assertion — TODO
// confirm against the CuTe sources.
Tensor mma_rA = thr_mma.partition_fragment_A(sA);
Tensor mma_rB = thr_mma.partition_fragment_B(sB);
Tensor mma_rC = thr_mma.partition_fragment_C(sC);
// Try to run GEMM
// Uses index 0 of the third mode of the A and B fragments; mma_rC is passed as
// both the second and the last tensor argument.
gemm(tiled_mma, mma_rC, mma_rA(_, _, _0{}), mma_rB(_, _, _0{}), mma_rC);
}
// Host entry point: launches the reproducer kernel once and exits.
int main() {
    // Launch configuration: one block of 32 threads with 32768 bytes of
    // dynamic shared memory.
    constexpr int kBlockCount = 1;
    constexpr int kThreadsPerBlock = 32;
    constexpr int kDynamicSmemBytes = 32768;

    test_kernel<<<kBlockCount, kThreadsPerBlock, kDynamicSmemBytes>>>();
    return 0;
}
nvcc playground.cu -o playground.exe -I/home/intlsy/research/mirage/mirage/deps/cutlass/include --expt-relaxed-constexpr -use_fast_math -Xcompiler=-O2 -Xcompiler=-march=native -diag-suppress 2361 -std=c++20 -arch=native -lcublas
/home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/atom/mma_traits.hpp(139): error: static assertion failed
static_assert(decltype(size(rB) == Int<RegNumB>{})::value);
^
detected during:
instantiation of "void cute::mma_unpack(const cute::MMA_Traits<MMA_Op, MMA_Args...> &, cute::Tensor<TD, DLayout> &, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) [with MMA_Op=cute::SM80_16x8x16_F16F16F16F16_TN, MMA_Args=<>, TD=cute::ViewEngine<cutlass::half_t *>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<int, cute::_2>>>, TC=cute::ViewEngine<const cutlass::half_t *>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>]" at line 105 of /home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/atom/mma_atom.hpp
instantiation of "void cute::MMA_Atom<cute::MMA_Traits<Args...>>::call(cute::Tensor<TD, DLayout> &, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) const [with Args=<cute::SM80_16x8x16_F16F16F16F16_TN>, TD=cute::ViewEngine<cutlass::half_t *>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<int, cute::_2>>>, TC=cute::ViewEngine<const cutlass::half_t *>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>]" at line 197 of /home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/algorithm/gemm.hpp
instantiation of "void cute::gemm(const cute::MMA_Atom<MMA> &, cute::Tensor<TD, DLayout> &, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) [with MMA=cute::SM80_16x8x16_F16F16F16F16_TN, TD=cute::ViewEngine<cutlass::half_t *>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<int, cute::_2>>>, TC=cute::ViewEngine<const cutlass::half_t *>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, <unnamed>=(void *)nullptr]" at line 148 of /home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/algorithm/gemm.hpp
instantiation of "void cute::gemm(const cute::MMA_Atom<MMA> &, cute::Tensor<TD, DLayout> &&, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) [with MMA=cute::SM80_16x8x16_F16F16F16F16_TN, TD=cute::ViewEngine<cutlass::half_t *>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<int, cute::_2>>>, TC=cute::ViewEngine<const cutlass::half_t *>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>]" at line 382 of /home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/algorithm/gemm.hpp
instantiation of "void cute::gemm(const cute::MMA_Atom<MMA> &, cute::Tensor<TD, DLayout> &, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) [with MMA=cute::SM80_16x8x16_F16F16F16F16_TN, TD=cute::ArrayEngine<cutlass::half_t, 512>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>, cute::C<8>, cute::C<16>>, cute::tuple<cute::tuple<cute::_1, cute::_2>, cute::_4, cute::_32>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::_8>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>, cute::_16>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::ComposedLayout<cute::Swizzle<1, 0, -2>, cute::C<0U>, cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>, cute::_16>, cute::tuple<cute::tuple<cute::_1, cute::_2>, cute::_4>>>, TC=cute::ArrayEngine<cutlass::half_t, 512>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>, cute::C<8>, cute::C<16>>, cute::tuple<cute::tuple<cute::_1, cute::_2>, cute::_4, cute::_32>>, <unnamed>=(void *)nullptr]" at line 35 of playground.cu
I don't quite understand the reason. Any help would be appreciated.
What is your question?
Hello, I have encountered a "static assertion failed" error when running gemm on a swizzled layout. I would very much appreciate any help.
Here is my code:
When compiling the code, it results in the following error:
If I remove the `swizzle` in `LayoutB`, the code compiles.
And here is the output of `print_layout(LayoutB{})`, which I think is correct: link
I don't quite understand the reason. Any help would be appreciated.