-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Closed
Labels
Description
Hello,
I read this issue:
`kernel::GemmUniversal` with mode `GemmUniversalMode::kGemmSplitKParallel` will be equivalent to `kernel::GemmSplitKParallel`. The difference comes to the fore for the `device::`-scoped kernels, wherein `device::GemmSplitKParallel` calls a reduction kernel and `device::GemmUniversal` does not. However, it is recommended that you use `device::GemmUniversal` rather than `device::GemmSplitKParallel`, as the former is more frequently tested.
Originally posted by @jackkosaian in #702 (comment)
However, I tested two implementations and found that GemmUniversal is much slower than GemmSplitKParallel when M and N are small and K is large, for example, M=64, N=64, K=4096.
GemmSplitKParallel: 0.011651 ms
UniversalGemmStreamK: 0.083712 ms
How could I configure the GemmUniversal to reproduce the speed of GemmSplitKParallel for M=64, N=64, K=4096? Thanks!
I profiled the GEMMs with CUTLASS v3.4.1 on an A5000 GPU.
Here is my testing code.
#include <iostream>
#include "cuda_runtime.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// copy from https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu
/// Times cutlass::gemm::device::GemmUniversal instantiated with
/// ThreadblockSwizzleStreamK and launched in kGemmSplitKParallel mode.
///
/// Inputs A/B are half precision (A row-major, B column-major); C/D and the
/// accumulator are float, row-major.
///
/// \param m, n, k  Extents of the GEMM problem.
/// \param n_iter   Number of timed launches; the mean time per launch is printed.
///
/// Prints "UniversalGemmStreamK: <ms> ms" to stdout. CUTLASS and CUDA runtime
/// failures are reported through the CUTLASS_CHECK / CUDA_CHECK macros from
/// helper.h.
///
/// NOTE(review): per the discussion this benchmark is based on,
/// device::GemmUniversal does not launch a separate reduction kernel in
/// split-K-parallel mode, whereas device::GemmSplitKParallel does. The two
/// timings therefore may not cover the same amount of work, and D is not
/// validated here.
void run_UniversalGemmStreamK(int m, int n, int k, int n_iter) {
  using ElementAccumulator = float;                   // <- data type of accumulator
  using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
  using ElementInputA = cutlass::half_t;              // <- data type of elements in input matrix A
  using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
  using ElementOutput = float;                        // <- data type of elements in output matrix D

  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size(m, n, k);

  // Initialize tensors using CUTLASS helper functions
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      ElementInputA(4),
      ElementInputA(-4),
      0);  // <- Fill matrix A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      ElementInputB(4),
      ElementInputB(-4),
      0);  // <- Fill matrix B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      ElementOutput(4),
      ElementOutput(-4),
      0);  // <- Fill matrix C on host with uniform-distribution random data
  cutlass::reference::host::TensorFill(
      tensor_d.host_view());  // <- fill matrix D on host with zeros

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();

  using MMAOp = cutlass::arch::OpClassTensorOp;
  using SmArch = cutlass::arch::Sm80;
  using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
  using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; the second parameter is the number of
  // output elements that fit in 128 bits.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementComputeEpilogue>;

  constexpr int NumStages = 4;

  // StreamK device GEMM implementation type
  using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal<
      ElementInputA, LayoutInputA,
      ElementInputB, LayoutInputB,
      ElementOutput, LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ShapeMMAThreadBlock,
      ShapeMMAWarp,
      ShapeMMAOp,
      EpilogueOp,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,  // <-- Only difference
      NumStages,
      128 / cutlass::sizeof_bits<ElementInputA>::value,
      128 / cutlass::sizeof_bits<ElementInputB>::value>;

  // Initialize alpha and beta for dot product computation
  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  // split K dimension into 16 partitions (named to match run_GemmSplitKParallel)
  int split_k_slices = 16;

  typename DeviceGemmStreamK::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // kGemmSplitKParallel mode
      problem_size,                                           // problem_size
      split_k_slices,                                         // batch count / splitk slices
      {alpha, beta},                                          // epilogue parameters
      tensor_a.device_data(),                                 // ptr_A
      tensor_b.device_data(),                                 // ptr_B
      tensor_c.device_data(),                                 // ptr_C
      tensor_d.device_data(),                                 // ptr_D
      problem_size.mk().product(),                            // batch_stride_A
      problem_size.nk().product(),                            // batch_stride_B
      problem_size.mn().product(),                            // batch_stride_C
      problem_size.mn().product(),                            // batch_stride_D
      tensor_a.layout().stride(0),                            // stride_a
      tensor_b.layout().stride(0),                            // stride_b
      tensor_c.layout().stride(0),                            // stride_c
      tensor_d.layout().stride(0),                            // stride_d
      // NOTE(review): trailing argument is unlabeled in the original;
      // presumably the stream-K `avail_sms` setting — confirm against
      // GemmUniversal::Arguments.
      0};

  DeviceGemmStreamK gemm_op;
  size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cudaEvent_t start, stop;
  CUDA_CHECK(cudaEventCreate(&start));
  CUDA_CHECK(cudaEventCreate(&stop));

  // warmup
  cutlass::Status status;
  status = gemm_op(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // time
  CUDA_CHECK(cudaEventRecord(start));
  for (int i = 0; i < n_iter; i++) {
    status = gemm_op(arguments, workspace.get());
    CUTLASS_CHECK(status);
  }
  CUDA_CHECK(cudaEventRecord(stop));
  // Checking the synchronize result also surfaces asynchronous kernel
  // failures from the timed launches instead of printing a bogus time.
  CUDA_CHECK(cudaEventSynchronize(stop));

  // print
  float milliseconds = 0;
  CUDA_CHECK(cudaEventElapsedTime(&milliseconds, start, stop));
  printf("UniversalGemmStreamK: %f ms\n", milliseconds / n_iter);

  // Release the timing events; the original leaked them.
  CUDA_CHECK(cudaEventDestroy(start));
  CUDA_CHECK(cudaEventDestroy(stop));
}
// copy from https://github.com/NVIDIA/cutlass/blob/main/examples/06_splitK_gemm/splitk_gemm.cu
/// Times cutlass::gemm::device::GemmSplitKParallel with the K dimension split
/// into 16 slices.
///
/// Inputs A/B are half precision (A row-major, B column-major); C/D and the
/// accumulator are float, row-major.
///
/// \param m, n, k  Extents of the GEMM problem.
/// \param n_iter   Number of timed launches; the mean time per launch is printed.
///
/// Prints "GemmSplitKParallel: <ms> ms" to stdout. CUTLASS and CUDA runtime
/// failures are reported through the CUTLASS_CHECK / CUDA_CHECK macros from
/// helper.h.
void run_GemmSplitKParallel(int m, int n, int k, int n_iter) {
  using ElementAccumulator = float;                   // <- data type of accumulator
  using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
  using ElementInputA = cutlass::half_t;              // <- data type of elements in input matrix A
  using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
  using ElementOutput = float;                        // <- data type of elements in output matrix D

  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size(m, n, k);

  // Initialize tensors using CUTLASS helper functions
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

  // Fill input and output matrices on host using CUTLASS helper functions
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      ElementInputA(4),
      ElementInputA(-4),
      0);  // <- Fill matrix A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      ElementInputB(4),
      ElementInputB(-4),
      0);  // <- Fill matrix B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      ElementOutput(4),
      ElementOutput(-4),
      0);  // <- Fill matrix C on host with uniform-distribution random data
  cutlass::reference::host::TensorFill(
      tensor_d.host_view());  // <- fill matrix D on host with zeros

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();

  using MMAOp = cutlass::arch::OpClassTensorOp;
  using SmArch = cutlass::arch::Sm80;
  using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
  using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

  // Linear-combination epilogue; the second parameter is the number of
  // output elements that fit in 128 bits.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementComputeEpilogue>;

  using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
                                                         LayoutInputA,
                                                         ElementInputB,
                                                         LayoutInputB,
                                                         ElementOutput,
                                                         LayoutOutput,
                                                         ElementAccumulator,
                                                         MMAOp,
                                                         SmArch,
                                                         ShapeMMAThreadBlock,
                                                         ShapeMMAWarp,
                                                         ShapeMMAOp,
                                                         EpilogueOp>;

  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  // split K dimension into 16 partitions
  int split_k_slices = 16;

  typename Gemm::Arguments arguments{problem_size,           // <- problem size of matrix multiplication
                                     tensor_a.device_ref(),  // <- reference to matrix A on device
                                     tensor_b.device_ref(),  // <- reference to matrix B on device
                                     tensor_c.device_ref(),  // <- reference to matrix C on device
                                     tensor_d.device_ref(),  // <- reference to matrix D on device
                                     {alpha, beta},          // <- tuple of alpha and beta
                                     split_k_slices};        // <- k-dimension split factor

  Gemm gemm_op;
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cudaEvent_t start, stop;
  CUDA_CHECK(cudaEventCreate(&start));
  CUDA_CHECK(cudaEventCreate(&stop));

  // warmup
  cutlass::Status status;
  status = gemm_op(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // time
  CUDA_CHECK(cudaEventRecord(start));
  for (int i = 0; i < n_iter; i++) {
    status = gemm_op(arguments, workspace.get());
    CUTLASS_CHECK(status);
  }
  CUDA_CHECK(cudaEventRecord(stop));
  // Checking the synchronize result also surfaces asynchronous kernel
  // failures from the timed launches instead of printing a bogus time.
  CUDA_CHECK(cudaEventSynchronize(stop));

  // print
  float milliseconds = 0;
  CUDA_CHECK(cudaEventElapsedTime(&milliseconds, start, stop));
  printf("GemmSplitKParallel: %f ms\n", milliseconds / n_iter);

  // Release the timing events; the original leaked them.
  CUDA_CHECK(cudaEventDestroy(start));
  CUDA_CHECK(cudaEventDestroy(stop));
}
/// Entry point: verifies that device 0 can be queried, then runs both
/// benchmarks on a fixed 64 x 64 x 4096 problem, 100 timed launches each.
/// Returns -1 if the device query fails, 0 otherwise.
int main() {
  // Query device 0 first so a CUDA runtime problem is reported before any
  // benchmark work starts. The properties themselves are not used further.
  cudaDeviceProp device_props;
  const cudaError_t query_status = cudaGetDeviceProperties(&device_props, 0);
  if (cudaSuccess != query_status) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(query_status) << std::endl;
    return -1;
  }

  // Problem extents (M, N, K) and the number of timed launches per benchmark.
  constexpr int kM = 64;
  constexpr int kN = 64;
  constexpr int kK = 4096;
  constexpr int kTimedIterations = 100;

  // Same order as before: split-K-parallel first, then the universal GEMM.
  run_GemmSplitKParallel(kM, kN, kK, kTimedIterations);
  run_UniversalGemmStreamK(kM, kN, kK, kTimedIterations);

  return 0;
}
Reactions are currently unavailable