[QST] Static assertion failed when using swizzled layout with gemm #1766

@interestingLSY

What is your question?

Hello, I get a "static assertion failed" error when calling gemm on a fragment partitioned from a tensor with a swizzled layout. I would very much appreciate an answer.

Here is my code:

#include <cutlass/cutlass.h>
#include <cute/layout.hpp>
#include <cute/tensor.hpp>
#include <cute/arch/mma_sm80.hpp>
using namespace cute;

__global__ void test_kernel() {
	// Declare tiled_mma
	auto tiled_mma = make_tiled_mma(
		SM80_16x8x16_F16F16F16F16_TN{},
		Layout<Shape<_1, _1, _1>>{},
		Tile<_16, _16, _16>{}
	);
	ThrMMA thr_mma = tiled_mma.get_slice(threadIdx.x);

	// Declare shared memory tensor layouts
	using LayoutA = Layout<Shape<Int<128>, Int<32>>, Stride<Int<32>, Int<1>>>;
	using LayoutB = decltype(composition(Swizzle<3, 3, 4>{}, Layout<Shape<Int<128>, Int<32>>, Stride<Int<1>, Int<128>>>{}));
	// using LayoutB = Layout<Shape<Int<128>, Int<32>>, Stride<Int<1>, Int<128>>>;	// If using this, the code compiles
	using LayoutC = Layout<Shape<Int<128>, Int<128>>, Stride<Int<128>, Int<1>>>;

	// Declare shared memory tensors
	extern __shared__ char smem[];
	Tensor sA = make_tensor(make_smem_ptr(smem), LayoutA{});
	Tensor sB = make_tensor(make_smem_ptr(smem), LayoutB{});
	Tensor sC = make_tensor(make_smem_ptr(smem), LayoutC{});

	// Declare register-file fragments
	Tensor mma_rA = thr_mma.partition_fragment_A(sA);
	Tensor mma_rB = thr_mma.partition_fragment_B(sB);
	Tensor mma_rC = thr_mma.partition_fragment_C(sC);

	// Try to run GEMM
	gemm(tiled_mma, mma_rC, mma_rA(_, _, _0{}), mma_rB(_, _, _0{}), mma_rC);
}

int main() {
	test_kernel<<<1, 32, 32768>>>();
}

Compiling the code produces the following error:

nvcc playground.cu -o playground.exe -I/home/intlsy/research/mirage/mirage/deps/cutlass/include --expt-relaxed-constexpr -use_fast_math -Xcompiler=-O2 -Xcompiler=-march=native -diag-suppress 2361 -std=c++20 -arch=native -lcublas

/home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/atom/mma_traits.hpp(139): error: static assertion failed
    static_assert(decltype(size(rB) == Int<RegNumB>{})::value);
    ^
          detected during:
            instantiation of "void cute::mma_unpack(const cute::MMA_Traits<MMA_Op, MMA_Args...> &, cute::Tensor<TD, DLayout> &, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) [with MMA_Op=cute::SM80_16x8x16_F16F16F16F16_TN, MMA_Args=<>, TD=cute::ViewEngine<cutlass::half_t *>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<int, cute::_2>>>, TC=cute::ViewEngine<const cutlass::half_t *>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>]" at line 105 of /home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/atom/mma_atom.hpp
            instantiation of "void cute::MMA_Atom<cute::MMA_Traits<Args...>>::call(cute::Tensor<TD, DLayout> &, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) const [with Args=<cute::SM80_16x8x16_F16F16F16F16_TN>, TD=cute::ViewEngine<cutlass::half_t *>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<int, cute::_2>>>, TC=cute::ViewEngine<const cutlass::half_t *>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>]" at line 197 of /home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/algorithm/gemm.hpp
            instantiation of "void cute::gemm(const cute::MMA_Atom<MMA> &, cute::Tensor<TD, DLayout> &, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) [with MMA=cute::SM80_16x8x16_F16F16F16F16_TN, TD=cute::ViewEngine<cutlass::half_t *>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<int, cute::_2>>>, TC=cute::ViewEngine<const cutlass::half_t *>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, <unnamed>=(void *)nullptr]" at line 148 of /home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/algorithm/gemm.hpp
            instantiation of "void cute::gemm(const cute::MMA_Atom<MMA> &, cute::Tensor<TD, DLayout> &&, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) [with MMA=cute::SM80_16x8x16_F16F16F16F16_TN, TD=cute::ViewEngine<cutlass::half_t *>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<int, cute::_2>>>, TC=cute::ViewEngine<const cutlass::half_t *>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>>, cute::tuple<cute::tuple<cute::_1, cute::_2>>>]" at line 382 of /home/intlsy/research/mirage/mirage/deps/cutlass/include/cute/algorithm/gemm.hpp
            instantiation of "void cute::gemm(const cute::MMA_Atom<MMA> &, cute::Tensor<TD, DLayout> &, const cute::Tensor<TA, ALayout> &, const cute::Tensor<TB, BLayout> &, const cute::Tensor<TC, CLayout> &) [with MMA=cute::SM80_16x8x16_F16F16F16F16_TN, TD=cute::ArrayEngine<cutlass::half_t, 512>, DLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>, cute::C<8>, cute::C<16>>, cute::tuple<cute::tuple<cute::_1, cute::_2>, cute::_4, cute::_32>>, TA=cute::ViewEngine<cutlass::half_t *>, ALayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::_8>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::_4>, cute::_16>>, TB=cute::ViewEngine<cutlass::half_t *>, BLayout=cute::ComposedLayout<cute::Swizzle<1, 0, -2>, cute::C<0U>, cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>, cute::_16>, cute::tuple<cute::tuple<cute::_1, cute::_2>, cute::_4>>>, TC=cute::ArrayEngine<cutlass::half_t, 512>, CLayout=cute::Layout<cute::tuple<cute::tuple<cute::_2, cute::_2>, cute::C<8>, cute::C<16>>, cute::tuple<cute::tuple<cute::_1, cute::_2>, cute::_4, cute::_32>>, <unnamed>=(void *)nullptr]" at line 35 of playground.cu

If I remove the swizzle in LayoutB, the code compiles.

And here is the output of print_layout(LayoutB{}), which I think is correct: link

I don't quite understand the reason. Any help would be appreciated.
