[CUTLASS] [CUDA] SM100 GroupMM #156203

Closed
AaronWang04 wants to merge 32 commits into pytorch:main from AaronWang04:sm100groupgemm

Conversation

Contributor
@AaronWang04 commented Jun 17, 2025

Closes #156202

This PR adds Blackwell (SM100) support for GroupMM.

Most of the code used for SM90 can be reused; the kernel schedule has to be changed in accordance with https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html
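
To sketch the shape of the change (placeholder type names only, not the real CUTLASS schedule types; those come from the doc above): the grouped-GEMM setup is shared between the two archs and only the schedule selected per ArchTag differs.

// Rough sketch of the pattern only; the schedule types below are placeholders,
// not the actual CUTLASS names (see the Blackwell functionality doc for those).
#include <type_traits>

struct Sm90Arch {};
struct Sm100Arch {};
struct Sm90PtrArraySchedule {};    // placeholder for the SM90 grouped-GEMM schedule
struct Sm100PtrArraySchedule {};   // placeholder for the SM100 grouped-GEMM schedule

// The rest of the grouped-GEMM builder is shared; only this alias changes per arch.
template <typename ArchTag>
using KernelSchedule = std::conditional_t<
    std::is_same_v<ArchTag, Sm90Arch>,
    Sm90PtrArraySchedule,
    Sm100PtrArraySchedule>;

static_assert(std::is_same_v<KernelSchedule<Sm100Arch>, Sm100PtrArraySchedule>, "");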

I did some preliminary benchmarking of H200 vs B200.

Script

import torch
print(torch.__file__)
device = torch.device("cuda")
dtype = torch.bfloat16

shapes = [
    (16, 128000, 7168, 7168),
    (128, 1, 2048, 7168)
]

for batch, M, N, K in shapes:
    a = torch.randn(batch, M, K, device=device, dtype=dtype)
    b = torch.randn(batch, N, K, device=device, dtype=dtype)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warmup iterations (not timed)
    for i in range(5): c = torch._grouped_mm(a, b)

    num_iter = 50
    start_event.record()
    # timed iterations, bracketed by CUDA events on the current stream
    for i in range(num_iter): c = torch._grouped_mm(a, b)
    end_event.record()

    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iter
    print(f"batch: {batch}\tM: {M}\tN: {N}\tK: {K}")
    print(f"Time per Iteration:\t {avg_time_ms:.4f} ms")

On H200

batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 298.6668 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 4.1462 ms

On B200

batch: 16       M: 128000       N: 7168 K: 7168
Time per Iteration:      190.7458 ms
batch: 128      M: 1    N: 2048 K: 7168
Time per Iteration:      3.0680 ms
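
As a rough sanity check, back-of-the-envelope throughput for the first shape, assuming the dense FLOP count 2*B*M*N*K (a sketch only, just arithmetic on the per-iteration times above):

// Back-of-the-envelope TFLOP/s from the per-iteration times above (assumed dense FLOP count).
#include <cstdio>

int main() {
  const double flops = 2.0 * 16 * 128000.0 * 7168.0 * 7168.0;  // ~2.1e14 FLOPs per iteration
  const double h200_ms = 298.6668;
  const double b200_ms = 190.7458;
  std::printf("H200: %.0f TFLOP/s\n", flops / (h200_ms * 1e-3) / 1e12);  // ~705
  std::printf("B200: %.0f TFLOP/s\n", flops / (b200_ms * 1e-3) / 1e12);  // ~1103
}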

nsys nvprof

root@16930b42ffc6:/workspace/pytorch# nsys nvprof python gemm_test.py
WARNING: python and any of its children processes will be profiled.

Collecting data...
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 192.6420 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 1.2255 ms
Generating '/tmp/nsys-report-6a53.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
SKIPPED: /workspace/pytorch/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)   Max (ns)    StdDev (ns)                 Name               
 --------  ---------------  ---------  ------------  ------------  --------  -----------  ------------  ---------------------------------
     98.9      10586895744          2  5293447872.0  5293447872.0  73786464  10513109280  7381715954.2  cudaDeviceSynchronize            
      1.0        104084608          5    20816921.6    33552480.0    100800     34786208    18048125.3  cudaMalloc                       
      0.1          5694304          4     1423576.0     1416656.0   1258560      1602432      181668.1  cudaGetDeviceProperties_v2_v12000
      0.1          5430496        130       41773.0        4560.0      2496      3854368      345761.8  cudaLaunchKernel                 
      0.0           587584        110        5341.7        4992.0      4224        16992        1482.0  cudaLaunchKernelExC_v11060       
      0.0           119200        660         180.6         128.0        96         4128         206.7  cudaGetDriverEntryPoint_v11030   
      0.0            68352        660         103.6          64.0        32         4928         224.6  cuTensorMapEncodeTiled           
      0.0            34976         49         713.8         224.0       160         6720        1343.4  cudaStreamIsCapturing_v10000     
      0.0            32992          4        8248.0        7456.0      4128        13952        4804.4  cudaEventRecord                  
      0.0            16928          4        4232.0        3600.0      1728         8000        2764.7  cudaEventQuery                   
      0.0            16288          4        4072.0        3568.0      1952         7200        2396.1  cudaEventCreateWithFlags         
      0.0            13632          4        3408.0        2672.0       544         7744        3408.7  cudaEventDestroy                 
      0.0             1056          1        1056.0        1056.0      1056         1056           0.0  cuModuleGetLoadingMode           

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     99.0      10549232845         55  191804233.5  192944479.0  165746368  203645313    5353204.3  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.6         67327135         55    1224129.7    1330656.0     924320    1364928     182180.4  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.3         34854783         20    1742739.1    1597856.0      10080    3899616     818421.2  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.0           354880        110       3226.2       3296.0       1920       4160        554.4  void at::cuda::detail::prepare_grouped_gemm_data<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass:…

The kernel names are truncated in the nvprof output, so I pasted the full names from Nsight Systems: the small shape picks the 1SM kernel (SM100_MMA_F16BF16_SS atom, 128x256x64 tile) and the large shape picks the 2SM kernel (SM100_MMA_F16BF16_2x1SM_SS atom, 256x256x64 tile).

small kernel 1SM
100.0%	1.286 ms	1	1.286 ms	1.286 ms	1.286 ms	1.286 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)3, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)128, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)

large kernel 2SM
100.0%	194.178 ms	1	194.178 ms	194.178 ms	194.178 ms	194.178 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)5, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)256>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)256, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)

cc @ptrblck @msaroufim @eqy @jerryzh168

@pytorch-bot
pytorch-bot bot commented Jun 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156203

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit f4daf74 with merge base 414ad47:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

const bool sm10x = properties != nullptr && properties->major == 10;

if(sm10x){
bf16bf16_grouped_gemm_impl_sm90_sm100<
Collaborator

this will build sm100 kernels even on an sm90-only build; can you refactor it so that only the needed kernels are built? I believe ScaledMM has examples

Contributor Author
@AaronWang04 Jun 18, 2025

I think ScaledMM is also doing something similar at the top level; is there another part of the code that I should reference?

const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
const bool sm9x = properties != nullptr && properties->major == 9;
const bool sm10x = properties != nullptr && properties->major == 10;
const bool sm12x = properties != nullptr && properties->major == 12;
if (!(sm89 || sm9x || sm10x || sm12x)) {
  TORCH_CHECK(
      false, "Rowwise scaling is not currently supported on your device");
}
if (sm9x) {
  dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
      /*ArchTag=*/cutlass::arch::Sm90,
      Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else if (sm10x) {
  dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
      /*ArchTag=*/cutlass::arch::Sm100,
      Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else if (sm12x) {
  // sm12x doesn't have multicast feature
  handle_transposition<
      /*ClusterShape=*/cute::Shape<cute::_1, cute::_1, cute::_1>,
      /*Transposed=*/std::false_type,
      /*ArchTag=*/cutlass::arch::Sm120,
      Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
  dispatch_fp8_rowwise_kernel_sm89<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
}

The dispatcher here selects a kernel with if constexpr, but I think both paths will still be hit during compilation:

if constexpr (std::is_same_v<ArchTag, cutlass::arch::Sm90>) {
  return f8f8bf16_rowwise_impl<
      /*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
      ClusterShape,
      Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
} else {
  return f8f8bf16_rowwise_impl_sm100_sm120<
      ArchTag,
      /*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
      ClusterShape,
      Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
}

@AaronWang04
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jun 18, 2025
@AaronWang04
Contributor Author

@pytorchbot label module: cuda

@pytorch-bot
pytorch-bot bot commented Jun 18, 2025

Didn't find following labels among repository labels: module:,cuda

@AaronWang04
Contributor Author

@pytorchbot label "module: cuda"

@pytorch-bot pytorch-bot bot added the module: cuda Related to torch.cuda, and CUDA support in general label Jun 18, 2025
}

static bool _scaled_mm_allowed_device(bool sm90_only=false) {
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
Contributor Author

this part is a bit messy, but it should get better once scaled grouped mm support is added for sm100 as well

Contributor

at this point I think it would be okay to pass in a set of versions or a range; this started out simple when we only really allowed one device, but it has grown


@eqy eqy marked this pull request as ready for review June 18, 2025 23:24
@eqy eqy requested review from eqy and syed-ahmed as code owners June 18, 2025 23:24
@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jun 20, 2025
Comment on lines +1082 to +1087
    if (sm90_only && sm100_only){
      return dprops->major == 9 || dprops->major == 10;
    } else if (sm90_only) {
      return dprops->major == 9;
    } else if(sm100_only){
      return dprops->major == 10;
Collaborator

Suggested change:
-    if (sm90_only && sm100_only){
-      return dprops->major == 9 || dprops->major == 10;
-    } else if (sm90_only) {
-      return dprops->major == 9;
-    } else if(sm100_only){
-      return dprops->major == 10;
+    if (sm90_only) {
+      return dprops->major == 9;
+    }
+    if(sm100_only){
+      return dprops->major == 10;
+    }

This will give the same results and is a bit cleaner; the next branch should also be if and not else.

Contributor Author
@AaronWang04 Jun 24, 2025

if the current device is sm100, wouldn't that break the check, since dprops->major == 9 is false and we'd return early? (both sm90_only and sm100_only are true for GroupMM; the naming of "_only" is a bit dubious in this case)

Anyway, I refactored the code like this, which is a bit cleaner:

    if (sm90_only || sm100_only) {
      return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
    } else {
      return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
    }
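
A self-contained sketch of that predicate with a few spot checks (the free function and its signature here are illustrative; the real check reads dprops from the current CUDA device):

// Illustrative standalone version of the refactored check; names are hypothetical.
#include <cassert>

static bool scaled_mm_allowed(int major, int minor,
                              bool sm90_only = false, bool sm100_only = false) {
  if (sm90_only || sm100_only) {
    return (sm90_only && major == 9) || (sm100_only && major == 10);
  }
  return major >= 9 || (major == 8 && minor == 9);
}

int main() {
  // GroupMM passes both flags, so SM90 and SM100 are accepted and SM8x is rejected.
  assert(scaled_mm_allowed(9, 0,  /*sm90_only=*/true, /*sm100_only=*/true));
  assert(scaled_mm_allowed(10, 0, /*sm90_only=*/true, /*sm100_only=*/true));
  assert(!scaled_mm_allowed(8, 9, /*sm90_only=*/true, /*sm100_only=*/true));
  // With no flags, sm89 and newer are allowed.
  assert(scaled_mm_allowed(8, 9));
  assert(!scaled_mm_allowed(8, 6));
}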

Collaborator
@syed-ahmed left a comment

Overall LGTM! Left a few nits and comments.

}

} // namespace at::cuda::detail
} // namespace at::cuda::detail No newline at end of file
Collaborator

nit: is there an extra line here?

@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@xfailIfSM120OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90 and SM100")
Collaborator

nit: maybe write "Grouped gemm supported only on SM90 and SM100"

typename ArchTag,
bool a_row_major,
bool b_row_major,
bool Pong,
Collaborator

nit: the name of this template parameter should match the one you added above for Schedule. I think it'll improve readability.

Collaborator

That is, instead of Pong, use PONGOr2SM.

cutlass::arch::Sm100,
a_row_major,
b_row_major,
/*2SM*/ false,
Collaborator

nit: change the comment /*2SM*/ to /*PONGOr2SM*/ for readability.

/*2SM*/ false,
cute::_128,
cute::_256,
cute::_64>(mat_a, mat_b, offs, bias, out);
Collaborator

Where are these tile shapes from? Also how did you derive cute::_64? Just wanted to double check with the example here: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu#L127.

Also could you please do a quick nsys nvprof python my_benchmark.py and post the result to verify that indeed a 1SM or 2SM version of the cutlass kernel is being picked (for small and large respectively).

Contributor Author

This is the same shape as in that example. There they do Int<128 / sizeof(ElementA)>, which equals 64 since our kernel is for bf16.

I posted the nvprof output, but the cutlass kernel names are too long there to show whether a kernel is 1SM or 2SM, so I pasted the full kernel names copied from Nsight Systems. The small and large shapes pick the 1SM and 2SM kernels as expected.
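
For reference, a tiny compile-time check of that K-tile derivation (include paths assumed to be the ones from the CUTLASS repo):

// Minimal check of the K-tile derivation: for bf16, sizeof(ElementA) == 2,
// so Int<128 / sizeof(ElementA)> is Int<64>, i.e. the cute::_64 in the dispatch above.
#include <cute/numeric/integral_constant.hpp>
#include <cutlass/bfloat16.h>

using ElementA = cutlass::bfloat16_t;
using TileK    = cute::Int<128 / sizeof(ElementA)>;

static_assert(TileK::value == 64, "128-byte K tile => 64 bf16 elements");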

Collaborator

Thanks! Small request to add it as a brief comment next to the cute::_64.

@drisspg drisspg self-requested a review June 25, 2025 05:43
@drisspg
Contributor

drisspg commented Jun 26, 2025

will do a review in an hour

@AaronWang04
Contributor Author

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 27, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 2 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job

Failing merge rule: Core Maintainers

@AaronWang04
Contributor Author

@pytorchbot rebase

@pytorch-bot
pytorch-bot bot commented Jun 27, 2025

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@syed-ahmed
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

Successfully rebased sm100groupgemm onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout sm100groupgemm && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Jun 27, 2025
@eqy
Collaborator

eqy commented Jun 28, 2025

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 28, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/trunk, Merged, module: cuda, open source, topic: not user facing, triaged


Development

Successfully merging this pull request may close these issues.

Upgrade torch._grouped_mm to SM100+

8 participants