[CUTLASS] [CUDA] SM100 GroupMM #156203
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156203
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit f4daf74 with merge base 414ad47. UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
aten/src/ATen/native/cuda/GroupMM.cu
Outdated
```cpp
const bool sm10x = properties != nullptr && properties->major == 10;

if (sm10x) {
  bf16bf16_grouped_gemm_impl_sm90_sm100<
```
This will build sm100 kernels even on an sm90-only build; can you refactor it so that only the needed kernels are built? I believe ScaledMM has examples.
I think ScaledMM is also doing something similar at the top level; is there another part of the code I should reference?
pytorch/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
Lines 947 to 973 in 3dabc35
The dispatcher here calls kernels with `if constexpr`, but I think both paths will still be instantiated during compilation:
pytorch/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
Lines 731 to 742 in 3dabc35
@pytorchbot label "topic: not user facing"

@pytorchbot label module: cuda

Didn't find following labels among repository labels: module:,cuda

@pytorchbot label "module: cuda"
```diff
 }
 
-static bool _scaled_mm_allowed_device(bool sm90_only=false) {
+static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
```
This part is a bit messy, but it should be better once scaled grouped mm support is added for sm100 as well.
At this point I think it would be okay to pass in a set of versions or a range; this started out simple when we only allowed one device, but it has grown.
aten/src/ATen/native/cuda/Blas.cpp
Outdated
```cpp
if (sm90_only && sm100_only){
  return dprops->major == 9 || dprops->major == 10;
} else if (sm90_only) {
  return dprops->major == 9;
} else if(sm100_only){
  return dprops->major == 10;
```
Suggested change:
```diff
-if (sm90_only && sm100_only){
-  return dprops->major == 9 || dprops->major == 10;
-} else if (sm90_only) {
-  return dprops->major == 9;
-} else if(sm100_only){
-  return dprops->major == 10;
+if (sm90_only) {
+  return dprops->major == 9;
+}
+if (sm100_only) {
+  return dprops->major == 10;
+}
```
This will give the same results and is a bit cleaner; the next branch should also be `if`, not `else`.
If the current device is sm100, wouldn't that break the check, since `dprops->major == 9` is false? (Both sm90_only and sm100_only are true for GroupMM; the naming of "_only" is a bit dubious in this case.)
Anyways, I refactored the code like this, which is a bit cleaner:
```cpp
if (sm90_only || sm100_only) {
  return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
} else {
  return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
}
```
syed-ahmed left a comment
Overall LGTM! Left a few nits and comments.
aten/src/ATen/native/cuda/GroupMM.cu
Outdated
```diff
 }
 
 
-} // namespace at::cuda::detail
+} // namespace at::cuda::detail
\ No newline at end of file
```
nit: is there an extra line here?
test/test_matmul_cuda.py
Outdated
```diff
-@xfailIfSM100OrLater
-@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
+@xfailIfSM120OrLater
+@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90 and SM100")
```
nit: maybe write "Grouped gemm supported only on SM90 and SM100"
aten/src/ATen/native/cuda/GroupMM.cu
Outdated
```cpp
typename ArchTag,
bool a_row_major,
bool b_row_major,
bool Pong,
```
nit: the name of this template parameter should match the one you added above for Schedule. I think it'll improve readability.
That is, instead of Pong, use PONGOr2SM.
aten/src/ATen/native/cuda/GroupMM.cu
Outdated
```cpp
cutlass::arch::Sm100,
a_row_major,
b_row_major,
/*2SM*/ false,
```
nit: change the comment `/*2SM*/` to `/*PONGOr2SM*/` for readability.
aten/src/ATen/native/cuda/GroupMM.cu
Outdated
```cpp
/*2SM*/ false,
cute::_128,
cute::_256,
cute::_64>(mat_a, mat_b, offs, bias, out);
```
Where are these tile shapes from? Also, how did you derive `cute::_64`? Just wanted to double-check against the example here: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu#L127.
Also, could you please run a quick `nsys nvprof python my_benchmark.py` and post the result to verify that the 1SM or 2SM version of the cutlass kernel is indeed being picked (for small and large shapes, respectively)?
This is the same shape as that example. In it they do `Int<128 / sizeof(ElementA)>`, which equals 64 since our kernel is for bf16.
I posted the nvprof output, but the cutlass kernel names are too long to show whether it is 1SM vs 2SM, so I pasted the kernel names copied from Nsight Systems. The small and large variants are being picked properly.
Thanks! Small request to add it as a brief comment next to the cute::_64.
will do a review in an hour

@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 mandatory check(s) failed. The first few are:
Dig deeper by viewing the failures on hud
@pytorchbot rebase

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@pytorchbot rebase

Successfully rebased (3c84ccf to b39e91e)
@pytorchmergebot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Closes #156202
This PR adds Blackwell support for GroupMM.
Most of the code used for SM90 can be reused; the kernel schedule has to be changed in accordance with https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html
Did some preliminary benchmarking of H200 vs B200
Script
On H200
B200
nsys nvprof
The kernel names are too long to be shown via nvprof; I pasted this from Nsight Systems.
cc @ptrblck @msaroufim @eqy @jerryzh168