Commit 4c4df17

slayton58 authored and pytorchmergebot committed
Add optional out argument to F.scaled_mm (#174395)

Summary:

* Add `out=` argument to `F.scaled_mm` and a basic test.
* Properly guard MXFP4 tests where the build has CUDA but not MSLK; gracefully refuse to run instead of hard-failing.

Test Plan:

```
pytest -v -k "test_float8_out_argument" test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: #174395
Approved by: https://github.com/danielvegamyhre
1 parent d0ea7fa · commit 4c4df17

4 files changed

Lines changed: 43 additions & 3 deletions

test/test_scaled_matmul_cuda.py

Lines changed: 28 additions & 3 deletions
```diff
@@ -26,6 +26,7 @@
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
     PLATFORM_SUPPORTS_MX_GEMM,
+    PLATFORM_SUPPORTS_MXFP4_GEMM,
     PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
     SM100OrLater,
     SM120OrLater,
@@ -218,6 +219,7 @@ def scaled_mm_wrap(
     use_fast_accum=False,
     bias=None,
     wrap_v2=wrap,
+    out=None,
 ):
     if not wrap_v2:
         return torch._scaled_mm(
@@ -249,6 +251,7 @@ def scaled_mm_wrap(
             bias=bias,
             output_dtype=out_dtype,
             use_fast_accum=use_fast_accum,
+            out=out,
         )
         return out

@@ -706,6 +709,23 @@ def test_float8_scale(self, device) -> None:
         out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
         self.assertEqual(out_fp8, out_fp8_s)

+    def test_float8_out_argument(self, device) -> None:
+        if not _device_supports_scaled_mm_fp8(device):
+            raise unittest.SkipTest(f8_msg)
+        size = (16, 16)
+        x = torch.full(size, .5, device=device, dtype=e4m3_type)
+        # hipblaslt does not yet support mixed e4m3_type input
+        y_type = e4m3_type if torch.version.hip else e5m2_type
+        y = torch.full(size, .5, device=device, dtype=y_type).t()
+
+        out = torch.empty(size, device=device, dtype=torch.bfloat16)
+
+        scale_one = torch.tensor(1.0, device=device)
+        out_fp8 = scaled_mm_wrap(x, y, scale_a=scale_one, scale_b=scale_one, out=out)
+
+        if out_fp8.data_ptr() != out.data_ptr():
+            raise AssertionError("out_fp8 and out must have the same data pointers")
+

     @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
     @parametrize("G", [1, 4, 16])
@@ -716,9 +736,12 @@ def test_float8_scale(self, device) -> None:
     def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
         torch.manual_seed(42)

-        if format == "mxfp4" and SM120OrLater:
+        if (format == "mxfp4") and SM120OrLater and (not PLATFORM_SUPPORTS_MXFP4_GEMM):
             raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300")

+        if (format == "mxfp4") and (not PLATFORM_SUPPORTS_MXFP4_GEMM):
+            raise unittest.SkipTest("MXFP4 not supported on this platform - build with MSLK support")
+
         total_K = K  # Alias for clarity, communicating this consists of several groups along this dim
         input_group_end_offsets = generate_jagged_offs(
             G, total_K, multiple_of=32, device="cuda"
@@ -786,8 +809,10 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
     def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
         torch.manual_seed(42)

-        if format == "mxfp4" and SM120OrLater:
+        if (format == "mxfp4") and SM120OrLater:
             raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300")
+        if (format == "mxfp4") and (not PLATFORM_SUPPORTS_MXFP4_GEMM):
+            raise unittest.SkipTest("MXFP4 not supported on this platform - build with MSLK support")

         # Simulate 2d-3d grouped gemm `out = input @ weight.t()`
         # 2D inputs with groups along M, 3D weights.
@@ -1894,7 +1919,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum,
             raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
         if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
             raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
-        if recipe == "mxfp4" and SM120OrLater:
+        if (recipe == "mxfp4") and SM120OrLater or (not PLATFORM_SUPPORTS_MXFP4_GEMM):
             raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300")

         device = "cuda"
```

torch/nn/functional.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -6755,6 +6755,8 @@ def scaled_mm(
     output_dtype: torch.dtype | None = torch.bfloat16,
     contraction_dim: list[int] | tuple[int, ...] = (),
     use_fast_accum: bool = False,
+    *,
+    out: Optional[torch.Tensor] = None,
 ) -> Tensor:
     r"""
     scaled_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, swizzle_b, bias, output_dtype,
@@ -6774,6 +6776,7 @@ def scaled_mm(
         output_dtype: dtype used for the output tensor
         contraction_dim: describe which dimensions are :math:`K` in the matmul.
         use_fast_accum: enable/disable tensor-core fast accumulation (Hopper-GPUs only)
+        out: User-provided output tensor
     """

     def expand_single_value(v: _Any | list[_Any] | None) -> list[_Any]:
@@ -6821,6 +6824,7 @@ def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
         output_dtype,
         contraction_dim,
         use_fast_accum,
+        out=out,
     )

     return out
```
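For context, the buffer-reuse contract the new test exercises looks like the following minimal sketch. It is not code from this PR: it is written against the lower-level `torch._scaled_mm` (assuming its existing `out=` overload) on an FP8-capable CUDA device; `F.scaled_mm` now forwards `out=` to the same place.

```python
# Minimal sketch of the out-buffer pattern, not code from this PR.
# Assumes an FP8-capable CUDA device and torch._scaled_mm's out= overload.
import torch

size = (16, 16)
x = torch.full(size, 0.5, device="cuda", dtype=torch.float8_e4m3fn)
# Column-major second operand, as in the new test.
y = torch.full(size, 0.5, device="cuda", dtype=torch.float8_e5m2).t()
scale = torch.tensor(1.0, device="cuda")

# Preallocate the result; the kernel writes into it instead of allocating.
out = torch.empty(size, device="cuda", dtype=torch.bfloat16)
result = torch._scaled_mm(
    x, y, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16, out=out
)
assert result.data_ptr() == out.data_ptr()  # same storage, no new allocation
```

Reusing a caller-owned buffer this way avoids a fresh allocation per call, which matters in tight inference loops and when capturing CUDA graphs.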

torch/nn/functional.pyi.in

Lines changed: 2 additions & 0 deletions
```diff
@@ -729,6 +729,8 @@ def scaled_mm(
     output_dtype: _dtype = ...,
     contraction_dim: list[int] | tuple[int, ...] = (),
     use_fast_accum: bool = False,
+    *,
+    out: Tensor | None = None,
 ) -> Tensor: ...

 __all__ += ["scaled_mm"]
```

torch/testing/_internal/common_cuda.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -191,7 +191,16 @@ def evaluate_platform_supports_mxfp8_grouped_gemm():
         return built_with_mslk and IS_SM100
     return False

+def evaluate_platform_supports_mxfp4_gemm():
+    if torch.cuda.is_available():
+        built_with_mslk = "USE_MSLK" in torch.__config__.show()
+        return bool(torch.version.hip) or built_with_mslk
+
+    return False
+
+
 PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm())
+PLATFORM_SUPPORTS_MXFP4_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp4_gemm())
 PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
 PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm())
 PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm())
```
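The new flag slots into the existing `LazyVal` capability-gate pattern: per the diff, MXFP4 counts as supported on ROCm builds or on CUDA builds compiled with MSLK. As a sketch (the class and test names below are hypothetical, not from this PR), a test consumes it through the same `skipIf` guard the test file already applies for MXFP8 grouped GEMM:

```python
# Sketch: gating a test on the new capability flag, following the
# skipIf pattern used for PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM above.
# TestMXFP4 / test_mxfp4_roundtrip are hypothetical names.
import unittest

import torch
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MXFP4_GEMM


class TestMXFP4(unittest.TestCase):
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MXFP4_GEMM,
        "MXFP4 not supported on this platform - build with MSLK support",
    )
    def test_mxfp4_roundtrip(self):
        # Reaching here implies a ROCm build or a CUDA build with MSLK.
        self.assertTrue(torch.cuda.is_available())
```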
