Add support for mixed 4-bit/8-bit data types GEMM #1413
|
More to come here: support for |
|
Added more tests. |
Yep. But it's in pretty much all the tests in test/unit/gemm/device; it seems we've just been copying it around... It would be best to remove all of them in a separate PR.
|
Added generator support for S8/S4 and S4/S8. AFAIK, implementing generator support for a given operation is not specifically documented, so I want to clarify the steps I've taken here. Basically, I've copied code from
I did the verification as @manishucsd suggested here: As mentioned above, I did the build with all the relevant kernels included, and then I verified that Overall, this PR now contains everything that I intended to do for |
|
Hi @alexsamardzic, thanks for working on this. Just wanted to clarify, will this kernel support int4 grouped per channel weight quantization + int8 per token dynamic activation quantization? |
This kernel is just an int4/int8 GEMM, producing an int32 (or int8) result. Quantization is not going to be supported by CUTLASS directly, but it could be implemented using an EVT epilogue. In particular, I'm trying to get this feature into CUTLASS mainly in order to have this particular operation supported in PyTorch, with its use alongside quantization as the primary motivation. |
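For illustration, here is a minimal host-side sketch of what an S8 x S4 GEMM with S32 accumulation computes. It is not code from this PR: the function name `reference_gemm_s8_s4` is hypothetical, A is assumed row-major and B column-major (as in the profiler runs later in this thread), and the 4-bit values are kept unpacked in `int8_t` for clarity.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Host reference for D = A * B with 32-bit integer accumulation.
// A: M x K, row-major, 8-bit signed values.
// B: K x N, column-major, 4-bit signed values in [-8, 7], stored
//    unpacked (one value per int8_t) to keep the sketch readable.
// D: M x N, row-major, 32-bit signed values.
std::vector<std::int32_t> reference_gemm_s8_s4(
    const std::vector<std::int8_t>& a,
    const std::vector<std::int8_t>& b,
    int m, int n, int k) {
  assert(a.size() == static_cast<std::size_t>(m) * k);
  assert(b.size() == static_cast<std::size_t>(k) * n);
  std::vector<std::int32_t> d(static_cast<std::size_t>(m) * n, 0);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      std::int32_t acc = 0;
      for (int p = 0; p < k; ++p) {
        // Column-major B: element (p, j) lives at j * k + p.
        std::int8_t bv = b[static_cast<std::size_t>(j) * k + p];
        assert(bv >= -8 && bv <= 7);  // must be a valid 4-bit signed value
        // Row-major A: element (i, p) lives at i * k + p.
        acc += static_cast<std::int32_t>(a[static_cast<std::size_t>(i) * k + p]) * bv;
      }
      d[static_cast<std::size_t>(i) * n + j] = acc;
    }
  }
  return d;
}
```

A reference like this can be useful for checking a small problem by hand before comparing against the kernel output.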
|
@manishucsd, @hwu36: Would it be possible for someone to review this PR (and eventually #1350 too)? These should not be controversial, are needed by PyTorch, and for this one I'd like to proceed with another PR to add other 4-bit/8-bit integer combinations that make sense. |
|
working on it now. |
Great job! How can I integrate this PR with PyTorch? Is there any example code available? @alexsamardzic |
The primary motivation for this PR is to have this combination of operands supported by PyTorch, so the integration should be coming soon. |
I'm a beginner with CUTLASS, and I have no idea how to run this GEMM on s4/s8 data that I have constructed myself. |
These changes are not for Hopper, but for the Ampere architecture. The code to run an s4/s8 GEMM would be the same as for any other GEMM, for example s8/s8, except that when the GEMM template is instantiated, the data types and other arguments should be specified accordingly. For some examples of this, see |
|
On a quick look, your strides may be wrong. |
Thank you for your prompt reply. I don't know much about this parameter, and I can't find many references. Could you give me some more details? Thank you very much. |
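As a rough illustration of the stride question, the sketch below computes leading dimensions for a row-major A, a column-major B, and a row-major output. The names `LeadingDims`, `leading_dims_row_col_row`, and `buffer_bytes` are hypothetical, and the sketch assumes that strides are counted in elements rather than bytes, so a packed 4-bit operand uses the same stride value as an 8-bit one while occupying half as many bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Leading dimensions for A (M x K, row-major), B (K x N, column-major)
// and D (M x N, row-major).
// Assumption: strides are expressed in elements, not in bytes.
struct LeadingDims {
  std::int64_t lda;  // distance between consecutive rows of A
  std::int64_t ldb;  // distance between consecutive columns of B
  std::int64_t ldd;  // distance between consecutive rows of D
};

LeadingDims leading_dims_row_col_row(int m, int n, int k) {
  assert(m > 0 && n > 0 && k > 0);
  return LeadingDims{k, k, n};
}

// Bytes needed to store `elements` values of `bits_per_element` bits each,
// densely packed (for example, two 4-bit values per byte).
std::size_t buffer_bytes(std::size_t elements, int bits_per_element) {
  return (elements * static_cast<std::size_t>(bits_per_element) + 7) / 8;
}
```

For the small problem size mentioned later in this thread (m=32, n=64, k=512), this gives lda = ldb = 512 and ldd = 64, with 8192 bytes for a packed s4 A operand.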
I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data into s4 values before the GEMM? |
No, s4 values should be packed, two values per byte. |
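A minimal host-side sketch of such packing is shown below. The helper names `pack_s4` and `unpack_s4` are hypothetical, and the nibble order (even-indexed element in the low 4 bits) is an assumption that should be checked against the library's sub-byte storage convention.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack signed 4-bit values (each in [-8, 7]) two per byte.
// Assumption: the even-indexed element goes into the low nibble.
std::vector<std::uint8_t> pack_s4(const std::vector<std::int8_t>& values) {
  assert(values.size() % 2 == 0);
  std::vector<std::uint8_t> packed(values.size() / 2);
  for (std::size_t i = 0; i < packed.size(); ++i) {
    std::int8_t lo = values[2 * i];
    std::int8_t hi = values[2 * i + 1];
    assert(lo >= -8 && lo <= 7 && hi >= -8 && hi <= 7);
    // Keep only the low 4 bits (two's complement) of each value.
    unsigned lo_bits = static_cast<unsigned>(lo) & 0x0Fu;
    unsigned hi_bits = static_cast<unsigned>(hi) & 0x0Fu;
    packed[i] = static_cast<std::uint8_t>(lo_bits | (hi_bits << 4));
  }
  return packed;
}

// Read element `index` back from the packed buffer, sign-extending to int8.
std::int8_t unpack_s4(const std::vector<std::uint8_t>& packed, std::size_t index) {
  int byte = packed[index / 2];
  int nibble = (index % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
  // (nibble ^ 8) - 8 maps 0..7 to 0..7 and 8..15 to -8..-1.
  return static_cast<std::int8_t>((nibble ^ 0x08) - 0x08);
}
```

The packed buffer can then be passed to the GEMM as-is, with the element type set to the 4-bit type.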
Thanks for your help! I get the correct result now, but I have another question: |
If matrix |
Thanks, I’m trying this, but it’s not going well currently. That is: |
Well, that's not element-wise multiplication with |
I got errors with error message:
The complete log is here. |
Are you using CUTLASS main, or the branch from this PR? |
I'm using this PR branch:
// ok
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA, // ElementA
cutlass::layout::RowMajor, // LayoutA
ElementB, // ElementB
cutlass::layout::ColumnMajor, // LayoutB
ElementOutput, // ElementOutput
cutlass::layout::RowMajor, // LayoutOutput
ElementAccumulator, // ElementAccumulator
cutlass::arch::OpClassTensorOp, // tag indicating Tensor Cores
cutlass::arch::Sm80, // tag indicating target GPU compute architecture
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
32, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
But when I try to use
// get errors
using EVTKernelStreamK =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, LayoutC, AlignmentC,
ElementAccumulator,
ElementCompute,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
cutlass::arch::OpMultiplyAdd,
EVTEpilogueStages
>::GemmKernel; // where is the key I think
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;
I don’t know much about |
Can you post your full code here? |
|
This code uses PyTorch, can you post a reproducible example that uses CUTLASS only? |
Hi @alexsamardzic, I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu#L114 you can add this example, then compile and run it: when But when I am really at a loss and would greatly appreciate any guidance or help you can provide. Thank you very much in advance for your time and assistance! |
Replace |
Works for me. Thanks! |
Good. Remember that CUTLASS is a heavily templated library, but only a small number of all the possible template argument combinations actually work together, so one cannot just paste pieces of code from different sources and expect them to work. |
Yea, that was a mistake. |
Going through relevant examples, as well as unit tests, in the CUTLASS source tree is probably still the best way to start. |
|
Hi @alexsamardzic, in fact, I am working on a GEMM+de-quantization fusion kernel for W4A8 based on your PR, similar to the W8A8 kernel for PyTorch here (GEMM+de-quantization), which also uses EVT. I have completed most of the work and used EVT to implement the de-quantization. Can you please have a look at what the possible issues might be? |
I'm sorry, but this has nothing to do with this particular PR, and unfortunately I don't have cycles to help you with this. You need to be sure that you understand building EVT epilogues, as well as specifying corresponding arguments. Then, in your position I would start with a simple epilogue that is just storing values from the accumulator into the output tensor. If results match expected ones, then I would add nodes into the EVT epilogue that do the multiplication, one by one, and would keep comparing results with the expected ones. When there is mismatch, you should know where to look for the fix. |
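The incremental approach described above can be mirrored on the host with a staged reference. The sketch below is only an illustration under assumptions: the names `Stage`, `staged_reference`, and `first_mismatch` are hypothetical, and the de-quantization is assumed to be a per-row scale followed by a per-column scale applied to the S32 accumulator.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Which epilogue nodes are modelled; add one stage at a time while debugging.
enum class Stage { kAccumulatorOnly, kRowScale, kRowAndColumnScale };

// Expected output (row-major, M x N) for the given stage.
// Assumed de-quantization: out(i, j) = acc(i, j) * row_scale[i] * col_scale[j].
std::vector<float> staged_reference(const std::vector<std::int32_t>& acc,
                                    const std::vector<float>& row_scale,
                                    const std::vector<float>& col_scale,
                                    int m, int n, Stage stage) {
  assert(acc.size() == static_cast<std::size_t>(m) * n);
  assert(row_scale.size() == static_cast<std::size_t>(m));
  assert(col_scale.size() == static_cast<std::size_t>(n));
  std::vector<float> out(acc.size());
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      std::size_t idx = static_cast<std::size_t>(i) * n + j;
      float v = static_cast<float>(acc[idx]);
      if (stage != Stage::kAccumulatorOnly) v *= row_scale[i];
      if (stage == Stage::kRowAndColumnScale) v *= col_scale[j];
      out[idx] = v;
    }
  }
  return out;
}

// Index of the first element whose absolute difference exceeds tol, or -1.
long first_mismatch(const std::vector<float>& got,
                    const std::vector<float>& expected, float tol) {
  assert(got.size() == expected.size());
  for (std::size_t i = 0; i < got.size(); ++i) {
    if (std::fabs(got[i] - expected[i]) > tol) return static_cast<long>(i);
  }
  return -1;
}
```

Comparing the kernel output against each stage in turn shows which added node first introduces a mismatch.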
|
Hi @alexsamardzic, thanks for your help, I got it working. When profiling a single GEMM, do you think the performance of |
The GEMM actually performed is the same: |
manishucsd
left a comment
Thank you for working on this. Apologies for a delayed review. LGTM.
Over to NVIDIA/CUTLASS (cc: @hwu36 ) for merging this.
How did you come up with this TileDescription list for S8 x S4? I guess you carried these over from S8 x S8. Please make sure all of these pass verification. You can follow steps similar to here to instantiate all the tile shapes listed here by using -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8". By default, the build process only instantiates the 128x128 tile shape.
Can you please run the verification and profiling on --m=3456 --n=4096 --k=2048 on an A100?
Please compile using -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8,s4,s8" to also have s4 x s4 and s8 x s8 kernels in the runs.
The tile selection is described in a comment above; also, as mentioned in this comment, I did the verification. I will repeat the verification procedure, together with profiling, and report the outcome here.
There was a build issue after rebasing on the latest main: basically, OpMultiplyAddSaturate for MmaTensorOpPolicy in the specialization of struct DefaultMmaTensorOp (in include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h) seems to be obligatory now, as the build fails if OpMultiplyAdd is used. The branch is updated accordingly.
I've configured the build using -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8,s4,s8" CMake options, and then verified that cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 unit test passes. Then, I did profiler runs as follows:
./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s8:row --B=s8:column >& s8_s8.txt
./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s4:row --B=s8:column >& s4_s8.txt
./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s8:row --B=s4:column >& s8_s4.txt
The corresponding profiler outputs are here:
s8_s8.txt
s4_s8.txt
s8_s4.txt
The disposition values for mixed data types cases with s8 accumulator are still incorrect. Also, the timings are somewhat slower than for corresponding s8xs8 cases (with the same configurations: tile sizes etc.).
Thank you for running and sharing these results.
The accumulator for all of these runs should be S32, as shown at the bottom of the output in csv format with accum type = S32. The Incorrect disposition with mixed input happens only for S8 output, i.e., when the accumulators are S32 but the output is downcast to S8.
We do not see incorrect results for S8xS8 with S32 accumulators and S8 output, can you pick one row of incorrect run from
(elementD/elementC type) <= (elementA type) x(elementB type) + (accum type)
S8 <= S8 x S4 + S32
and compare the same kernel configuration against
S8 <= S8 x S8 + S32
to find where the difference is?
I believe it has to do with the initialization of the operands during profiling, or with the S32-to-S8 quantization inside the kernel epilogue.
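To make the S32-to-S8 step concrete, here is a scalar host-side sketch of a linear-combination epilogue followed by a saturating downcast. It is an assumption-laden model rather than the kernel's actual epilogue: the names `saturate_to_s8`, `epilogue_s32_to_s8`, and `downcast_examples` are hypothetical, and saturation (instead of wrap-around) is assumed for the conversion.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Clamp a value into the int8 range [-128, 127] (saturating conversion).
std::int8_t saturate_to_s8(std::int64_t x) {
  return static_cast<std::int8_t>(
      std::min<std::int64_t>(127, std::max<std::int64_t>(-128, x)));
}

// Hypothetical scalar model of the epilogue:
//   D = saturate(alpha * acc + beta * c)
// with an S32 accumulator, an S8 source C, and integer alpha and beta.
// A 64-bit intermediate is used so the sketch itself cannot overflow.
std::int8_t epilogue_s32_to_s8(std::int32_t acc, std::int8_t c,
                               std::int32_t alpha, std::int32_t beta) {
  std::int64_t v = static_cast<std::int64_t>(alpha) * acc +
                   static_cast<std::int64_t>(beta) * c;
  return saturate_to_s8(v);
}

// Small usage example: sums outside the int8 range are clamped.
void downcast_examples() {
  assert(epilogue_s32_to_s8(210, 0, 1, 0) == 127);
  assert(epilogue_s32_to_s8(-500, 0, 1, 0) == -128);
  assert(epilogue_s32_to_s8(10, -3, 1, 1) == 7);
}
```

Since sums accumulated over a large K can fall outside [-128, 127], the kernel and the reference agree on S8 output only if both apply the same conversion.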
Also, you can just upload the csv that can be produced by adding --output=filename.csv to the profiler runs
Here is what I found so far regarding incorrect cases:
First, I made the following change in the code generating inputs, in order to generate the same inputs for the profiler in the S4xS8 and S8xS8 cases:
diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu
index 2cbfa5d2..7b488fe8 100644
--- a/tools/profiler/src/device_context.cu
+++ b/tools/profiler/src/device_context.cu
@@ -105,7 +105,7 @@ DeviceAllocation *DeviceContext::allocate_tensor(
data_distribution.set_uniform(-1, 1, 0);
break;
case library::NumericTypeID::kS4:
- data_distribution.set_uniform(-2, 2, 0);
+ data_distribution.set_uniform(-3, 3, 0);
break;
case library::NumericTypeID::kU2:
data_distribution.set_uniform(0, 2, 0);
I used the following profiler runs to compare the S4xS8 and S8xS8 cases (BTW, I found that a smaller problem size that still reproduces the problem is --m=32 --n=64 --k=512):
cutlass_profiler --operation=gemm --gemm_kind=universal --m=3456 --n=4096 --k=2048 --A=s8:row --B=s8:column --C=s8:column --D=s8:column --accum=s32 --cta_m=256 --cta_n=128 --cta_k=64 --stages=3 --save-workspace=always
cutlass_profiler --operation=gemm --gemm_kind=universal --m=3456 --n=4096 --k=2048 --A=s4:row --B=s8:column --C=s8:column --D=s8:column --accum=s32 --cta_m=256 --cta_n=128 --cta_k=64 --stages=3 --save-workspace=always
By comparing the saved .mat files, I verified that the input matrices A, B and C are the same, and also that the output matrices D are the same. What actually differs are the Reference matrices, which means that the reference results calculated for the S4xS8 case are wrong. If I understood it correctly, cuBLAS is used for reference calculations, so I'll check what's going on there...
I am not sure that cuBLAS is called for the reference check here. The output should show which references are called. You can use --verification-providers=cublas,host,device to run them all. Is there a host reference for this? You should check in /tools/library/src/reference.
Indeed, the device provider is actually used for the reference check here. I posted an update with the fix in the reference calculations, so for most cases with S8 output, cutlass_profiler now reports success. However, there are still a couple of cases where incorrect is reported; I'm looking into this...
Pushed another update: the problem with the remaining incorrect cases was that I hadn't copied the C operand alignment update from the S8xS8 case in the generator code. Everything is now reported as passed by the profiler; the output files are attached below. I believe this one should be ready for merging now.
@alexsamardzic thank you for digging it through. LGTM!
@hwu36, @thakkarV, @IonThruster, can you please help merge it?
|
while we are at this, i think we can improve the int4->int8 upcasting. now we use 11 instructions to upcast 8 elements. quite a lot. we used a look-up-table method to do int4->fp8 upcasting (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/numeric_conversion.h#L2983-L3027), I think we may be able to use the same here. @alexsamardzic , do you want to give it a try? i am setting up now so it won't take me months to merge your code. cc @rhenry-nv |
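To illustrate the idea of a look-up table for int4->int8 upcasting, here is a scalar host-side sketch that converts 8 packed elements at a time and compares against an arithmetic sign extension. It is neither the linked CUTLASS code nor the patch discussed in this thread, which operate on GPU registers. The names `kS4ToS8`, `upcast_lut`, and `upcast_arithmetic` are hypothetical, and element 0 is assumed to sit in the lowest 4 bits of the word.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// 16-entry table: the index is the raw 4-bit pattern, the value is the
// corresponding sign-extended 8-bit integer.
constexpr std::array<std::int8_t, 16> kS4ToS8 = {
    0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1};

// Upcast 8 packed 4-bit values (one 32-bit word) by table look-up.
// Assumption: element i occupies bits [4 * i, 4 * i + 3].
std::array<std::int8_t, 8> upcast_lut(std::uint32_t packed) {
  std::array<std::int8_t, 8> out{};
  for (int i = 0; i < 8; ++i) {
    out[i] = kS4ToS8[(packed >> (4 * i)) & 0xFu];
  }
  return out;
}

// Same conversion done arithmetically, for comparison.
std::array<std::int8_t, 8> upcast_arithmetic(std::uint32_t packed) {
  std::array<std::int8_t, 8> out{};
  for (int i = 0; i < 8; ++i) {
    int nibble = static_cast<int>((packed >> (4 * i)) & 0xFu);
    // (nibble ^ 8) - 8 maps 0..7 to 0..7 and 8..15 to -8..-1.
    out[i] = static_cast<std::int8_t>((nibble ^ 8) - 8);
  }
  return out;
}

// Small usage example: both methods agree on an arbitrary word.
void upcast_example() {
  assert(upcast_lut(0x8F07A5C3u) == upcast_arithmetic(0x8F07A5C3u));
}
```

Which variant is faster on the GPU depends on the instructions each one compiles to, so it has to be measured with the profiler.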
Sure. Below is a patch to implement the look-up table method for int4->int8 (pretty much the same as existing int4->fp8 code), and also the profiler outputs for original and patched version. It seems that the look-up table method is slower. I ran the profiler in both cases as follows: and here are mentioned files: I was the least happy about the conversion code in this PR, but this is the best I was able to come up with... |
* Add support for mixed 4-bit/8-bit data types GEMM * fix ( and ) --------- Co-authored-by: Aleksandar Samardžić <asamardzic@matf.bg.ac.rs> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>

