Skip to content

[QST] GemmUniversal is slower than GemmSplitKParallel when M and N are small and K is large #1586

@hychiang-git

Description

@hychiang-git

Hello,

I read this issue:

  • kernel::GemmUniversal with mode GemmUniversalMode::kGemmSplitKParallel will be equivalent to kernel::GemmSplitKParallel. The difference comes to fore for the device::-scoped kernels, wherein device::GemmSplitKParallel calls a reduction kernel and device::GemmUniversal does not. However, it is recommended that you use device::GemmUniversal rather than device::GemmSplitKParallel, as the former is more-frequently tested.

Originally posted by @jackkosaian in #702 (comment)

However, I tested two implementations and found that GemmUniversal is much slower than GemmSplitKParallel when M and N are small and K is large, for example, M=64, N=64, K=4096.

GemmSplitKParallel: 0.011651 ms
UniversalGemmStreamK: 0.083712 ms

How could I configure the GemmUniversal to reproduce the speed of GemmSplitKParallel for M=64, N=64, K=4096? Thanks!

I profile GEMMs with Cutlass v3.4.1 on an A5000 GPU.

Here is my testing code.

#include <iostream>

#include "cuda_runtime.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"

// copy from https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu
void run_UniversalGemmStreamK(int m, int n, int k, int n_iter) {

    using ElementAccumulator = float;                   // <- data type of accumulator
    using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
    using ElementInputA = cutlass::half_t;              // <- data type of elements in input matrix A
    using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
    using ElementOutput = float;                        // <- data type of elements in output matrix D

    using LayoutInputA = cutlass::layout::RowMajor;
    using LayoutInputB = cutlass::layout::ColumnMajor;
    using LayoutOutput = cutlass::layout::RowMajor;

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    // Initialize tensors using CUTLASS helper functions
    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(),
        1,
        ElementInputA(4),
        ElementInputA(-4),
        0);  // <- Fill matrix A on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(),
        1,
        ElementInputB(4),
        ElementInputB(-4),
        0);  // <- Fill matrix B on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_c.host_view(),
        1,
        ElementOutput(4),
        ElementOutput(-4),
        0);  // <- Fill matrix C on host with uniform-distribution random data
    cutlass::reference::host::TensorFill(
        tensor_d.host_view());  // <- fill matrix D on host with zeros

    // Copy data from host to GPU
    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();

    using MMAOp = cutlass::arch::OpClassTensorOp;
    using SmArch = cutlass::arch::Sm80;
    using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
    using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
    using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
        ElementOutput,
        128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator,
        ElementComputeEpilogue>;

  constexpr int NumStages  = 4; 
  // StreamK device GEMM implementation type
  using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal<
      ElementInputA, LayoutInputA,
      ElementInputB, LayoutInputB,
      ElementOutput, LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ShapeMMAThreadBlock,
      ShapeMMAWarp,
      ShapeMMAOp,
      EpilogueOp,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- Only difference
      NumStages,
      128 / cutlass::sizeof_bits<ElementInputA>::value,
      128 / cutlass::sizeof_bits<ElementInputB>::value>;

  // Initialize alpha and beta for dot product computation
  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  typename DeviceGemmStreamK::Arguments arguments{
    cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // kGemmSplitKParallel mode
    problem_size,                     // problem_size
    16,                               // batch count / splitk slices
    {alpha, beta},                     // epilogue parameters
    tensor_a.device_data(),                   // ptr_A
    tensor_b.device_data(),                   // ptr_B
    tensor_c.device_data(),                   // ptr_C
    tensor_d.device_data(),                   // ptr_D
    problem_size.mk().product(),      // batch_stride_A
    problem_size.nk().product(),      // batch_stride_B
    problem_size.mn().product(),      // batch_stride_C
    problem_size.mn().product(),      // batch_stride_D
    tensor_a.layout().stride(0),              // stride_a
    tensor_b.layout().stride(0),              // stride_b
    tensor_c.layout().stride(0),              // stride_c
    tensor_d.layout().stride(0),              // stride_d
    0};

  DeviceGemmStreamK gemm_op;
  size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // warmup
  cutlass::Status status;
  status = gemm_op(arguments, workspace.get());
  CUTLASS_CHECK(status);
  // time
  cudaEventRecord(start);
  for (int i=0; i<n_iter; i++) {
    status = gemm_op(arguments, workspace.get());
    CUTLASS_CHECK(status);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  // print
  float milliseconds = 0;
  cudaEventElapsedTime(&milliseconds, start, stop);
  printf("UniversalGemmStreamK: %f ms\n", milliseconds/n_iter);
}

// copy from https://github.com/NVIDIA/cutlass/blob/main/examples/06_splitK_gemm/splitk_gemm.cu
void run_GemmSplitKParallel(int m, int n, int k, int n_iter) {

    using ElementAccumulator = float;                   // <- data type of accumulator
    using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
    using ElementInputA = cutlass::half_t;              // <- data type of elements in input matrix A
    using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
    using ElementOutput = float;                        // <- data type of elements in output matrix D

    using LayoutInputA = cutlass::layout::RowMajor;
    using LayoutInputB = cutlass::layout::ColumnMajor;
    using LayoutOutput = cutlass::layout::RowMajor;

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    // Initialize tensors using CUTLASS helper functions
    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

    // Fill input and output matrices on host using CUTLASS helper functions
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(),
        1,
        ElementInputA(4),
        ElementInputA(-4),
        0);  // <- Fill matrix A on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(),
        1,
        ElementInputB(4),
        ElementInputB(-4),
        0);  // <- Fill matrix B on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_c.host_view(),
        1,
        ElementOutput(4),
        ElementOutput(-4),
        0);  // <- Fill matrix C on host with uniform-distribution random data
    cutlass::reference::host::TensorFill(
        tensor_d.host_view());  // <- fill matrix D on host with zeros

    // Copy data from host to GPU
    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();

    using MMAOp = cutlass::arch::OpClassTensorOp;
    using SmArch = cutlass::arch::Sm80;
    using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
    using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
    using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

    // This code section describes ?
    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
        ElementOutput,
        128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator,
        ElementComputeEpilogue>;

    using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
                                                        LayoutInputA,
                                                        ElementInputB,
                                                        LayoutInputB,
                                                        ElementOutput,
                                                        LayoutOutput,
                                                        ElementAccumulator,
                                                        MMAOp,
                                                        SmArch,
                                                        ShapeMMAThreadBlock,
                                                        ShapeMMAWarp,
                                                        ShapeMMAOp,
                                                        EpilogueOp>;

  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  // split K dimension into 16 partitions
  int split_k_slices = 16;
  typename Gemm::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                                     tensor_a.device_ref(),  // <- reference to matrix A on device
                                     tensor_b.device_ref(),  // <- reference to matrix B on device
                                     tensor_c.device_ref(),  // <- reference to matrix C on device
                                     tensor_d.device_ref(),  // <- reference to matrix D on device
                                     {alpha, beta},          // <- tuple of alpha and beta
                                     split_k_slices};        // <- k-dimension split factor
  Gemm gemm_op;
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // warmup
  cutlass::Status status;
  status = gemm_op(arguments, workspace.get());
  CUTLASS_CHECK(status);
  // time
  cudaEventRecord(start);
  for (int i=0; i<n_iter; i++) {
    status = gemm_op(arguments, workspace.get());
    CUTLASS_CHECK(status);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  // print
  float milliseconds = 0;
  cudaEventElapsedTime(&milliseconds, start, stop);
  printf("GemmSplitKParallel: %f ms\n", milliseconds/n_iter);
}


int main() {

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  // Define problem size
  const int length_m = 64;
  const int length_n = 64;
  const int length_k = 4096;
  const int n_iter = 100;
  run_GemmSplitKParallel(length_m, length_n, length_k, n_iter);
  run_UniversalGemmStreamK(length_m, length_n, length_k, n_iter);
  return 0;
}

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions