@@ -315,7 +315,7 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist)
315315 def test_grouped_gemm_2d_2d (self , strided , a_row_major , b_row_major , use_torch_compile ):
316316 device = "cuda"
317317 dtype = torch .bfloat16
318- m , n , k , n_groups = 16 , 32 , 64 , 4
318+ m , n , k , n_groups = 16 , 32 , 64 , 4 # all sizes have to be divisible by 16
319319 if a_row_major :
320320 a = torch .randn (m , k * n_groups + k * int (strided ), device = device , dtype = dtype )[:, :k * n_groups ]
321321 else :
@@ -382,9 +382,6 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, use_torch_c
382382 b_contig = b if b_row_major else b .transpose (- 2 , - 1 )
383383 self .assertTrue (b_contig .is_contiguous () is not strided )
384384 for check_zero_size in (False , True ):
385- if check_zero_size and n_groups <= 1 :
386- continue
387-
388385 a .grad = None
389386 b .grad = None
390387 offs = torch .arange (m , n_groups * m + 1 , m , device = "cuda" , dtype = torch .int32 )
@@ -487,9 +484,6 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, use_torch_c
487484 b_contig = b if b_row_major else b .transpose (- 2 , - 1 )
488485 self .assertTrue (b_contig .is_contiguous () is not strided )
489486 for check_zero_size in (False , True ):
490- if check_zero_size and n_groups <= 1 :
491- continue
492-
493487 offs = torch .arange (n , n_groups * n + 1 , n , device = "cuda" , dtype = torch .int32 )
494488 if check_zero_size :
495489 offs [0 ] = offs [1 ]
@@ -1651,27 +1645,17 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
16511645 for a , b , ascale , bscale , out in zip (alist , blist , ascalelist , bscalelist , outlist ):
16521646 out_ref = torch ._scaled_mm (a , b .t (), ascale .view (- 1 , 1 ), bscale .view (1 , - 1 ),
16531647 out_dtype = torch .bfloat16 , use_fast_accum = use_fast_accum )
1654- self .assertEqual (out , out_ref , atol = 5e-2 , rtol = 5e-4 )
1655-
1656- # Testing only _scaled_grouped_mm() with multiple shapes, as
1657- # _scaled_mm() already has more combinations of parameters than
1658- # _scaled_grouped_mm(), for supporting more than one input layout
1659- # combination.
1648+ self .assertEqual (out , out_ref , atol = 8e-2 , rtol = 8e-4 )
16601649
16611650 @unittest .skipIf (TEST_WITH_ROCM , "ROCm doesn't support CUTLASS" )
16621651 @xfailIfSM100OrLater
16631652 @unittest .skipIf (not SM90OrLater , "Grouped gemm supported on SM90" )
1664- @parametrize (
1665- "n_groups, m, n, k" ,
1666- [(2 , 1 , 16 , 16 ),
1667- (4 , 16 , 16 , 16 )],
1668- name_fn = lambda n_groups , m , n , k : f"{ n_groups } _{ m } _{ n } _{ k } " ,
1669- )
16701653 @parametrize ("fast_accum" , [False , True ])
16711654 @parametrize ("strided" , [False , True ])
16721655 @parametrize ("use_torch_compile" , [False , True ])
1673- def test_scaled_grouped_gemm_2d_2d (self , n_groups , m , n , k , fast_accum , strided , use_torch_compile ):
1656+ def test_scaled_grouped_gemm_2d_2d (self , fast_accum , strided , use_torch_compile ):
16741657 device = "cuda"
1658+ m , n , k , n_groups = 16 , 32 , 64 , 4 # all sizes have to be divisible by 16
16751659 a = torch .randn (m , k * n_groups + k * int (strided ), device = device ).to (torch .float8_e4m3fn )[:, :k * n_groups ]
16761660 b = torch .randn (n , k * n_groups + k * int (strided ), device = device ).to (torch .float8_e4m3fn )[:, :k * n_groups ]
16771661 scale_a = torch .rand (m * n_groups , device = device , dtype = torch .float32 )
@@ -1701,26 +1685,18 @@ def test_scaled_grouped_gemm_2d_2d(self, n_groups, m, n, k, fast_accum, strided,
17011685 @unittest .skipIf (TEST_WITH_ROCM , "ROCm doesn't support CUTLASS" )
17021686 @xfailIfSM100OrLater
17031687 @unittest .skipIf (not SM90OrLater , "Grouped gemm supported on SM90" )
1704- @parametrize (
1705- "n_groups, m, n, k" ,
1706- [(2 , 1 , 16 , 16 ),
1707- (4 , 16 , 16 , 16 )],
1708- name_fn = lambda n_groups , m , n , k : f"{ n_groups } _{ m } _{ n } _{ k } " ,
1709- )
17101688 @parametrize ("fast_accum" , [False , True ])
17111689 @parametrize ("strided" , [False , True ])
17121690 @parametrize ("use_torch_compile" , [False , True ])
1713- def test_scaled_grouped_gemm_2d_3d (self , n_groups , m , n , k , fast_accum , strided , use_torch_compile ):
1691+ def test_scaled_grouped_gemm_2d_3d (self , fast_accum , strided , use_torch_compile ):
17141692 device = "cuda"
17151693 s_int = int (strided )
1694+ m , n , k , n_groups = 16 , 32 , 64 , 4
17161695 a = torch .randn (m * n_groups , k * (1 + s_int ), device = device ).to (torch .float8_e4m3fn )[:, :k ]
17171696 b = torch .randn (n_groups * (1 + s_int ), n , k * (1 + s_int ), device = device ).to (torch .float8_e4m3fn )[::(1 + s_int ), :, :k ]
17181697 self .assertTrue (a .is_contiguous () is not strided )
17191698 self .assertTrue (b .is_contiguous () is not strided )
17201699 for check_zero_size in (True , False ):
1721- if check_zero_size and n_groups <= 1 :
1722- continue
1723-
17241700 offs = torch .arange (m , n_groups * m + 1 , m , device = "cuda" , dtype = torch .int32 )
17251701 if check_zero_size :
17261702 offs [0 ] = offs [1 ]
@@ -1751,18 +1727,13 @@ def test_scaled_grouped_gemm_2d_3d(self, n_groups, m, n, k, fast_accum, strided,
17511727 @unittest .skipIf (TEST_WITH_ROCM , "ROCm doesn't support CUTLASS" )
17521728 @xfailIfSM100OrLater
17531729 @unittest .skipIf (not SM90OrLater , "Grouped gemm supported on SM90" )
1754- @parametrize (
1755- "n_groups, m, n, k" ,
1756- [(2 , 1 , 16 , 16 ),
1757- (4 , 16 , 16 , 16 )],
1758- name_fn = lambda n_groups , m , n , k : f"{ n_groups } _{ m } _{ n } _{ k } " ,
1759- )
17601730 @parametrize ("fast_accum" , [False , True ])
17611731 @parametrize ("strided" , [False , True ])
17621732 @parametrize ("use_torch_compile" , [False , True ])
1763- def test_scaled_grouped_gemm_3d_3d (self , n_groups , m , n , k , fast_accum , strided , use_torch_compile ):
1733+ def test_scaled_grouped_gemm_3d_3d (self , fast_accum , strided , use_torch_compile ):
17641734 device = "cuda"
17651735 s_int = int (strided )
1736+ m , n , k , n_groups = 16 , 32 , 64 , 4
17661737 a = torch .randn (n_groups * (1 + s_int ), m , k * (1 + s_int ), device = device ).to (torch .float8_e4m3fn )[::(1 + s_int ), :, :k ]
17671738 b = torch .randn (n_groups * (1 + s_int ), n , k * (1 + s_int ), device = device ).to (torch .float8_e4m3fn )[::(1 + s_int ), :, :k ]
17681739 self .assertTrue (a .is_contiguous () is not strided )
@@ -1786,28 +1757,20 @@ def test_scaled_grouped_gemm_3d_3d(self, n_groups, m, n, k, fast_accum, strided,
17861757 @unittest .skipIf (TEST_WITH_ROCM , "ROCm doesn't support CUTLASS" )
17871758 @xfailIfSM100OrLater
17881759 @unittest .skipIf (not SM90OrLater , "Grouped gemm supported on SM90" )
1789- @parametrize (
1790- "n_groups, m, n, k" ,
1791- [(2 , 1 , 16 , 16 ),
1792- (4 , 16 , 16 , 16 )],
1793- name_fn = lambda n_groups , m , n , k : f"{ n_groups } _{ m } _{ n } _{ k } " ,
1794- )
17951760 @parametrize ("fast_accum" , [False , True ])
17961761 @parametrize ("strided" , [False , True ])
17971762 @parametrize ("use_torch_compile" , [False , True ])
1798- def test_scaled_grouped_gemm_3d_2d (self , n_groups , m , n , k , fast_accum , strided , use_torch_compile ):
1763+ def test_scaled_grouped_gemm_3d_2d (self , fast_accum , strided , use_torch_compile ):
17991764 device = "cuda"
18001765 s_int = int (strided )
1766+ m , n , k , n_groups = 16 , 32 , 64 , 4
18011767 a = torch .randn (n_groups * (1 + s_int ), m , k * (1 + s_int ), device = device ).to (torch .float8_e4m3fn )[::(1 + s_int ), :, :k ]
18021768 b = torch .randn (n * n_groups , k * (1 + s_int ), device = device ).to (torch .float8_e4m3fn )[:, :k ]
18031769 self .assertTrue (a .is_contiguous () is not strided )
18041770 self .assertTrue (b .is_contiguous () is not strided )
18051771 scale_a = torch .rand (n_groups * m , device = "cuda" , dtype = torch .float32 ).view (n_groups , m )
18061772 scale_b = torch .rand (n_groups * n , device = "cuda" , dtype = torch .float32 )
18071773 for check_zero_size in (True , False ):
1808- if check_zero_size and n_groups <= 1 :
1809- continue
1810-
18111774 offs = torch .arange (n , n_groups * n + 1 , n , device = "cuda" , dtype = torch .int32 )
18121775 if check_zero_size :
18131776 offs [0 ] = offs [1 ]
0 commit comments