Improvements for: Groupwise scaling along M for FP8 gemm #2095
hwu36 merged 7 commits into NVIDIA:main from
Conversation
@LucasWilkinson, we upstreamed our change to the groupwise scaling kernels. There are some conflicts in this PR that need to be resolved. Our change is mainly:
Force-pushed from db87722 to 7f541db
Apologies for the delay; the PR has been updated. Currently I am still vectorizing the loads of B scales along N (like
Are there any problems when transposing A and transposing B?
Currently this assumes full tiles in N and K, so if you use this for inference, where activations may have partial tiles, and transpose the problem to Y^T = WX^T, it may report not implementable. I think I'm going to update this, since ideally in vLLM we'd like to transpose it to use smaller tensor-core instructions; we do lose vectorization on the loads then, though.
include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
...p8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h
Maybe still use ScalePromotionInterval here, and move size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}) to the can_implement check?
Hmm, I'm not sure I see ScalePromotionInterval. What would be the motivation for not having this determined at compile time? It seems unnecessarily burdensome to have the user set mma_promotion_interval manually.
In any case, defining this as a constexpr near the top would be better for readability:
static constexpr int ScalePromotionInterval = size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{});
and then using that here?
@hwu36, so this will be 4 for TileShapeK = 128 and InstructionShape = 32, which is the original case; for TileShapeK = 64 it will be 2. Will that be unsupported?
Force-pushed from edd90be to 2a9256f
Why is this restriction only for M and not for N? dim-M usually maps to the batch count, while dim-N will be the model dimension, a nice multiple of 2, correct?
If this is an A_row * B_col groupwise GEMM, it is sometimes required that we transpose and swap, creating an underlying GEMM of B_row * A_row, swapping M <-> N. This is typically helpful for (a) mixed-input BF16*F8, which doesn't apply here, and (b) small M, say 64, where we can swap and transpose to run a better tile. I have seen that give more performance for small M.
Does vectorizing scale_copy_b versus not vectorizing give any performance improvement? If not, I would suggest that we keep this kernel symmetric in M and N to allow the user to apply the swap-and-transpose trick to it.
I was mostly just trying to keep it as close to the original as possible to minimize the chances of perf regressions, but I agree this is much less confusing. And I think we will want to transpose in vLLM in order to use smaller instructions for smaller batch sizes.
pushed an update that enables partial tiles in N
Can you make sure that this copy_if is issued by only 32 threads? The thread layout of shape 32 (created above) won't be tiled over the entire tile by make_tiled_copy; please just confirm using a simple printf.
Ran
if ((!blockIdx.x && !blockIdx.y && !blockIdx.z)) printf("%d ", threadIdx.x);
if (thread0()) printf("\n");
just before
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
and got:
...
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
...
I think we should be good 👍
Should the TMA-related tensor constructions be inside lane_predicate as before? There's no need for all the threads to construct these, even in this implementation.
I'm not sure; I didn't think this was a big deal, since if you look at the 3.6.0 diff that improved the mixed-input GEMM (we were told 3.6 had perf improvements for mixed input) in include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp, you can see it was updated to have all the threads compute the TMA tensors. I'm not sure what the recommended approach is, or whether this particular change had any impact. Would appreciate some guidance!
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Force-pushed from 1dc4ebd to 460b938
H100 benchmarks, This PR vs. Main:
if (options.k % size<2>(TileShape{}) != 0) {
  std::cout << "Skipping (k size: " << options.k << " is not divisible by TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
}

Hi @LucasWilkinson ,
* fix blockwise fp8 kernels
* wip, < 128 not working
* fix < 128
* reduce diff
* review comments
* support partial n blocks
* fix build errors

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Various improvements to "Groupwise scaling along M" (#2037), namely to address #2087; context: vllm-project/vllm#11868 (comment)
Improvements:
This PR moves to a layout of (i.e. standard M-major):
making it much easier to integrate into inference libraries.
These improvements were part of vLLM's adoption of this kernel, https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (PR: vllm-project/vllm#11868), which is in current wide-scale use. Our goal is to rely on the CUTLASS implementation, but that is currently not possible given the issues above.