     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
     PLATFORM_SUPPORTS_MX_GEMM,
+    PLATFORM_SUPPORTS_MXFP4_GEMM,
     PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
     SM100OrLater,
     SM120OrLater,
@@ -218,6 +219,7 @@ def scaled_mm_wrap(
     use_fast_accum=False,
     bias=None,
     wrap_v2=wrap,
+    out=None,
 ):
     if not wrap_v2:
         return torch._scaled_mm(
@@ -249,6 +251,7 @@ def scaled_mm_wrap(
         bias=bias,
         output_dtype=out_dtype,
         use_fast_accum=use_fast_accum,
+        out=out,
     )
     return out

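The two hunks above thread a new optional out= argument from scaled_mm_wrap's signature through to the underlying scaled-matmul call. As a rough standalone illustration (not part of the diff), a direct call could look like the sketch below; it assumes an FP8-capable CUDA device and that torch._scaled_mm accepts the out= keyword being forwarded here:

    import torch

    x = torch.full((16, 16), 0.5, device="cuda", dtype=torch.float8_e4m3fn)
    y = torch.full((16, 16), 0.5, device="cuda", dtype=torch.float8_e4m3fn).t()  # column-major operand
    scale = torch.tensor(1.0, device="cuda")  # per-tensor float32 scales
    out = torch.empty(16, 16, device="cuda", dtype=torch.bfloat16)

    res = torch._scaled_mm(x, y, scale_a=scale, scale_b=scale,
                           out_dtype=torch.bfloat16, out=out)
    assert res.data_ptr() == out.data_ptr()  # result lands in the caller-provided buffer
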
@@ -706,6 +709,23 @@ def test_float8_scale(self, device) -> None:
         out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
         self.assertEqual(out_fp8, out_fp8_s)

+    def test_float8_out_argument(self, device) -> None:
+        if not _device_supports_scaled_mm_fp8(device):
+            raise unittest.SkipTest(f8_msg)
+        size = (16, 16)
+        x = torch.full(size, .5, device=device, dtype=e4m3_type)
+        # hipblaslt does not yet support mixed e4m3/e5m2 inputs
+        y_type = e4m3_type if torch.version.hip else e5m2_type
+        y = torch.full(size, .5, device=device, dtype=y_type).t()
+
+        out = torch.empty(size, device=device, dtype=torch.bfloat16)
+
+        scale_one = torch.tensor(1.0, device=device)
+        out_fp8 = scaled_mm_wrap(x, y, scale_a=scale_one, scale_b=scale_one, out=out)
+
+        if out_fp8.data_ptr() != out.data_ptr():
+            raise AssertionError("out_fp8 and out must have the same data pointer")
+

     @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
     @parametrize("G", [1, 4, 16])
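The pointer comparison at the end of the new test is the defining contract of an out= overload: the op must write into, and return, the caller-provided buffer. The same invariant can be checked with the public torch.mm out= variant (standalone sketch, not part of the diff):

    import torch

    a = torch.randn(16, 16)
    b = torch.randn(16, 16)
    buf = torch.empty(16, 16)

    res = torch.mm(a, b, out=buf)
    # The out= variant writes into the provided tensor and returns it,
    # so both names alias the same storage.
    assert res.data_ptr() == buf.data_ptr()
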
@@ -716,9 +736,12 @@ def test_float8_scale(self, device) -> None:
     def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
         torch.manual_seed(42)

-        if format == "mxfp4" and SM120OrLater:
+        if (format == "mxfp4") and SM120OrLater and (not PLATFORM_SUPPORTS_MXFP4_GEMM):
             raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300")

+        if (format == "mxfp4") and (not PLATFORM_SUPPORTS_MXFP4_GEMM):
+            raise unittest.SkipTest("MXFP4 not supported on this platform - build with MSLK support")
+
         total_K = K  # Alias for clarity, communicating this consists of several groups along this dim
         input_group_end_offsets = generate_jagged_offs(
             G, total_K, multiple_of=32, device="cuda"
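Both skip conditions above key off the new PLATFORM_SUPPORTS_MXFP4_GEMM flag imported at the top of the diff; its definition is not shown here and presumably lives alongside the other capability flags in torch.testing._internal.common_cuda. Purely as a hypothetical sketch (helper name and logic are illustrative, not the real flag), such a check might combine device capability with a build-time feature, per the "build with MSLK support" message:

    import torch

    def _supports_mxfp4_gemm() -> bool:  # hypothetical helper, not PyTorch's real flag
        if not torch.cuda.is_available():
            return False
        major, _ = torch.cuda.get_device_capability()
        # The skip messages above say MXFP4 GEMM targets B200/B300-class (sm_100x) parts.
        return major == 10
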
@@ -786,8 +809,10 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
     def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
         torch.manual_seed(42)

-        if format == "mxfp4" and SM120OrLater:
+        if (format == "mxfp4") and SM120OrLater:
             raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300")
+        if (format == "mxfp4") and (not PLATFORM_SUPPORTS_MXFP4_GEMM):
+            raise unittest.SkipTest("MXFP4 not supported on this platform - build with MSLK support")

         # Simulate 2d-3d grouped gemm `out = input @ weight.t()`
         # 2D inputs with groups along M, 3D weights.
@@ -1894,7 +1919,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum,
             raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
         if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
             raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
-        if recipe == "mxfp4" and SM120OrLater:
+        if (recipe == "mxfp4") and SM120OrLater and (not PLATFORM_SUPPORTS_MXFP4_GEMM):
             raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300")

         device = "cuda"