Review notes for this patch (free text kept above the first diff header,
where a commit message would normally sit).

Summary: switches the "s8s4" CUTLASS linear path to a 4-bit activation
operand (ElementA = cutlass::int4b_t) alongside the 4-bit weights, and
updates the test to feed packed 4-bit activations.

* test/test_s8s4_linear_cutlass.py:
  activations are now quantized with 4 bits
  (group_quantize_tensor_symmetric(input_2d, 4, size_k, dtype)).
  Pairs of values are combined into one element: odd positions along the
  sliced dim go to the high nibble, even positions go to the low nibble.
  NOTE(review): the slicing is `[:, :, 1::2]` / `[:, :, 0::2]`. It therefore
  assumes a 3-D input (so dim 2 is the K dim) and an even size_k; otherwise
  the two halves differ in shape. TODO confirm for every size_mnk case.
  NOTE(review): names still say s8 (`input_s8`, `input_2d_s8`, the test and
  kernel names) although the values are now 4-bit.

* s8s4_linear_cutlass.cu, kernel:
  a new `Operator` alias selects cutlass::arch::OpMultiplyAddSaturate in one
  case and cutlass::arch::OpMultiplyAddMixedInputUpcast otherwise. It is then
  passed to the GEMM kernel in place of the hard-coded upcast operator.
  Presumably the first case is "A and B are both int4". The std::is_same
  arguments are missing in this copy (see the last note).

* s8s4_linear_cutlass.cu, dispatch:
  under `if constexpr` (same presumed int4 x int4 condition), dispatch uses a
  single configuration:
    - InstructionShape 16x8x64
    - ThreadblockShape 128x256x128
    - WarpShape 64x64x128
    - 4 stages
  The tensor_a.size(0) <= 16 / <= 32 small-M heuristic (InstructionShape
  16x8x32) is kept only in the `else` branch.
  NOTE(review): small-M inputs on the int4 path now always use the large
  tile. Benchmark before landing.

* s8s4_linear_cutlass.cu, validation:
  NOTE(review): the datatype/layout checks and the size checks in
  s8s4_linear_cutlass are wrapped in /* ... */. As a result the entry point
  no longer validates its arguments before launching the kernel.
  One disabled check is `input_2d.size(1) == 2 * weight.size(1)`, which
  presumably stops holding once activations are packed too. That check
  should be made element-type aware rather than disabled.

* s8s4_linear_cutlass.cu, element types:
  NOTE(review): `using ElementA = int8_t;` is left as commented-out code
  and ElementA is fixed to cutlass::int4b_t. The int8-activation
  instantiation therefore appears unreachable from this entry point.
  Confirm intent, then delete the dead code or dispatch on the input type.

* This copy of the patch:
  NOTE(review): the template argument lists after std::conditional_t /
  std::is_same appear stripped (`std::conditional_t::value`,
  `if constexpr (std::is_same::value && std::is_same::value)`). The text as
  shown will not compile; check against the original commit.

diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py index 15b4c26..500d389 100644 --- a/test/test_s8s4_linear_cutlass.py +++ b/test/test_s8s4_linear_cutlass.py @@ -44,11 +44,12 @@ def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias): input_2d = input.view(-1, input.shape[-1]) input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric( - input_2d, 8, size_k, dtype + input_2d, 4, size_k, dtype ) assert torch.all(input_2d_zeros == 0) input_s8 = input_2d_s8.reshape(input.shape) input_scales = input_2d_scales.reshape(input.shape[:-1]) + input_s4 = ((input_s8[:, :, 1::2] & 0xF) << 4) | (input_s8[:, :, 0::2] & 0xF) weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric( weight, 4, size_n, dtype @@ -70,7 +71,7 @@ def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias): output_ref += bias output_ref = output_ref.reshape(input.shape[:-1] + (size_n,)) - fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias) + fn_inputs = (input_s4, input_scales, weight_s4, weight_scales, bias) try: output = s8s4_linear_cutlass(*fn_inputs) except NotImplementedError as e: diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu index 2daefb7..faf774b 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -79,6 +79,11 @@ void s8s4_linear_kernel_cutlass( using ThreadblockSwizzle = cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + using Operator = + std::conditional_t::value && + std::is_same::value, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAddMixedInputUpcast>; constexpr auto NumEVTEpilogueStages = 1; using TensorAScaleTileThreadMap = @@ -189,7 +194,7 @@ void s8s4_linear_kernel_cutlass( EVTOutput, ThreadblockSwizzle, NumStages, - cutlass::arch::OpMultiplyAddMixedInputUpcast, + Operator, 
NumEVTEpilogueStages >::GemmKernel; @@ -299,33 +304,13 @@ s8s4_linear_cutlass_dispatch_shapes( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - // A minimal heuristic to improve performance for small number of // inputs cases. - if (tensor_a.size(0) <= 16) { - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; - constexpr auto NumStages = 6; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else if (tensor_a.size(0) <= 32) { - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; - constexpr auto NumStages = 5; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + if constexpr (std::is_same::value && + std::is_same::value) { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; constexpr auto NumStages = 4; s8s4_linear_kernel_cutlass< ElementA, ElementAScale, ElementB, ElementBScale, ElementC, @@ -333,6 +318,39 @@ s8s4_linear_cutlass_dispatch_shapes( ThreadblockShape, WarpShape, 
InstructionShape, NumStages, use_tensor_c>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + } else { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + if (tensor_a.size(0) <= 16) { + using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; + constexpr auto NumStages = 6; + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else if (tensor_a.size(0) <= 32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + constexpr auto NumStages = 5; + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + constexpr auto NumStages = 4; + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } } } @@ -358,6 +376,7 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, __func__, " : Supported only on GPUs with compute capability " "8.x"); + /* // Validate datatypes of arguments. 
TORCH_CHECK(input.dtype() == at::kChar, __func__, " : The input datatype ", input.dtype(), @@ -411,6 +430,7 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, __func__, " : Expected bias argument to be strided, got ", "layout ", bias.layout()); } + */ // Squash the input tensor to 2D tensor. const auto input_sizes = input.sizes().vec(); @@ -419,6 +439,7 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, const auto input_scale_1d = input_scale.reshape({-1}); const auto weight_scale_1d = weight_scale.reshape({-1}); + /* // Validate sizes of arguments. TORCH_CHECK(input_2d.size(1) == 2 * weight.size(1), __func__, " : Expected input argument to have ", @@ -456,6 +477,7 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, TORCH_CHECK(bias_strides[0] == 1, __func__, " : Expected bias argument to be contiguous"); } + */ // Introduce alias names for arguments, according to the CUTLASS // naming conventions. @@ -469,9 +491,13 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, at::Tensor tensor_d = tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); + /* using ElementA = int8_t; + */ + using ElementA = cutlass::int4b_t; using ElementB = cutlass::int4b_t; using ElementAccumulator = int32_t; + AT_DISPATCH_SWITCH( input_scale.scalar_type(), "s8s4_linear_cutlass",