Commit ad8f739

romanmeta authored and pytorchmergebot committed
[inductor] Decompose mm/addmm to pointwise mul when K==1 (#175825)
Summary:

When K == 1, the matrix multiplication (M, 1) @ (1, N) is an outer product. Instead of launching a full GEMM kernel, we decompose it into a broadcasted pointwise multiply at the ATen decomposition level, which is more efficient for this memory-bound case.

This is a reland of D94097622 with two fixes:

- Skip the decomposition when M == 1 or N == 1, to avoid the broadcast multiply's output strides not matching mm's output strides.
- Remove the `as_strided` stride fixup that was causing issues with Helion (SympifyError on symbolic shapes).

The M == 1 / N == 1 guard also applies to the existing CPU K == 1 decomposition path.

**aten.mm** — TritonBench, K=1 shapes, median of 3 runs:

| Shape (M, N, K) | B200 base (us) | B200 test (us) | B200 Speedup | H100 base (us) | H100 test (us) | H100 Speedup |
|---|---|---|---|---|---|---|
| (100, 100, 1) | 12.3 | 11.3 | 1.09x | 9.76 | 8.64 | 1.13x |
| (150, 150, 1) | 12.3 | 11.2 | 1.10x | 9.82 | 8.70 | 1.13x |
| (200, 200, 1) | 12.3 | 11.3 | 1.09x | 9.95 | 8.80 | 1.13x |
| (256, 256, 1) | 12.3 | 11.3 | 1.09x | 9.76 | 8.70 | 1.12x |
| (512, 512, 1) | 12.3 | 11.2 | 1.10x | 9.92 | 8.80 | 1.13x |
| (1024, 1024, 1) | 14.3 | 13.2 | 1.09x | 11.39 | 9.44 | 1.21x |
| (2048, 2048, 1) | 20.5 | 15.3 | **1.34x** | 16.19 | 12.83 | **1.26x** |
| (4096, 4096, 1) | 35.8 | 26.8 | **1.33x** | 37.98 | 29.12 | **1.30x** |
| (8192, 8192, 1) | 96.3 | 68.6 | **1.40x** | 120.48 | 89.12 | **1.35x** |
| (16384, 16384, 1) | 329.8 | 234.5 | **1.41x** | 387.42 | 249.54 | **1.55x** |
| (4608, 20, 1) | 13.2 | 11.3 | 1.17x | 10.02 | 8.86 | 1.13x |
| (4608, 32, 1) | 13.2 | 11.3 | 1.17x | 9.95 | 8.86 | 1.12x |
| (4608, 128, 1) | 13.2 | 11.4 | 1.17x | 10.94 | 8.99 | 1.22x |
| (4608, 256, 1) | 14.3 | 13.2 | 1.09x | 12.22 | 9.50 | **1.29x** |
| (4608, 512, 1) | 17.4 | 13.3 | **1.31x** | 14.02 | 10.59 | **1.32x** |
| (4608, 1024, 1) | 20.5 | 15.3 | **1.34x** | 17.06 | 13.18 | **1.29x** |
| (1024, 4096, 1) | 20.5 | 15.3 | **1.34x** | 16.80 | 13.25 | **1.27x** |
| (4096, 1024, 1) | 20.5 | 15.3 | **1.34x** | 16.22 | 12.51 | **1.30x** |

Geomean speedup: B200 **1.21x**, H100 **1.22x**; 0 regressions.

**aten.addmm** — TritonBench, K=1 shapes, median of 3 runs:

| Shape (M, N, K) | B200 base (us) | B200 test (us) | B200 Speedup | H100 base (us) | H100 test (us) | H100 Speedup |
|---|---|---|---|---|---|---|
| (100, 100, 1) | 12.3 | 12.3 | 1.00x | 9.76 | 9.06 | 1.08x |
| (150, 150, 1) | 12.4 | 12.3 | 1.01x | 10.08 | 9.18 | 1.10x |
| (200, 200, 1) | 12.4 | 12.3 | 1.00x | 9.98 | 9.31 | 1.07x |
| (256, 256, 1) | 12.3 | 12.3 | 1.00x | 9.86 | 9.38 | 1.05x |
| (512, 512, 1) | 13.3 | 13.2 | 1.01x | 10.37 | 9.73 | 1.07x |
| (1024, 1024, 1) | 15.3 | 13.3 | 1.15x | 12.32 | 11.20 | 1.10x |
| (2048, 2048, 1) | 23.6 | 18.5 | **1.27x** | 19.01 | 16.19 | **1.17x** |
| (4096, 4096, 1) | 56.3 | 33.8 | **1.66x** | 58.72 | 45.60 | **1.29x** |
| (8192, 8192, 1) | 172.2 | 102.3 | **1.68x** | 166.75 | 148.45 | 1.12x |
| (16384, 16384, 1) | 665.8 | 359.5 | **1.85x** | 638.66 | 503.23 | **1.27x** |
| (4608, 20, 1) | 13.2 | 12.3 | 1.07x | 10.21 | 9.47 | 1.08x |
| (4608, 32, 1) | 13.2 | 12.4 | 1.06x | 10.11 | 9.47 | 1.07x |
| (4608, 128, 1) | 13.3 | 13.2 | 1.00x | 11.68 | 10.27 | 1.14x |
| (4608, 256, 1) | 15.3 | 13.4 | 1.14x | 13.28 | 11.55 | 1.15x |
| (4608, 512, 1) | 18.6 | 15.4 | 1.20x | 15.87 | 13.63 | 1.16x |
| (4608, 1024, 1) | 25.5 | 19.4 | **1.31x** | 21.02 | 17.92 | **1.17x** |
| (1024, 4096, 1) | 23.5 | 18.5 | **1.27x** | 18.94 | 16.38 | 1.16x |
| (4096, 1024, 1) | 23.5 | 18.5 | **1.27x** | 18.98 | 16.29 | 1.17x |

Geomean speedup: B200 **1.19x**, H100 **1.13x**; 0 regressions.

diff-train-skip-merge

Test Plan:

```
PYTORCH_TEST_REMOTE_GPU=1 buck2 test //caffe2/test/inductor:test_mmdecomp_cuda \
  -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 \
  -c fbcode.enable_gpu_sections=true mode/opt
Pass 30. Fail 0.

PYTORCH_TEST_REMOTE_GPU=1 buck2 test //caffe2/test/inductor:test_mmdecomp \
  -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 \
  -c fbcode.enable_gpu_sections=true mode/opt
Pass 29. Fail 0.

PYTORCH_TEST_REMOTE_GPU=1 buck2 test //caffe2/test/inductor:fxir_backend \
  -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 \
  -c fbcode.enable_gpu_sections=true mode/opt
Pass 76. Fail 0.
```

Reviewed By: PaulZhang12

Differential Revision: D94437532

Pull Request resolved: #175825

Approved by: https://github.com/PaulZhang12
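The algebraic identity behind the decomposition can be checked in eager mode: for `mat1` of shape (M, 1) and `mat2` of shape (1, N), the matmul is an outer product, so a broadcasted elementwise multiply produces the same result without a GEMM. The shapes below are illustrative, not taken from the benchmark.

```python
import torch

M, N = 128, 64
mat1 = torch.randn(M, 1)   # K == 1
mat2 = torch.randn(1, N)

gemm = mat1 @ mat2         # (M, N) via matmul: gemm[i, j] = mat1[i, 0] * mat2[0, j]
pointwise = mat1 * mat2    # (M, N) via broadcasting: same elementwise product

assert torch.allclose(gemm, pointwise)

# addmm follows the same pattern the diff uses:
# alpha * (mat1 * mat2) + beta * bias
bias = torch.randn(M, N)
alpha, beta = 2.0, 0.5
assert torch.allclose(
    torch.addmm(bias, mat1, mat2, alpha=alpha, beta=beta),
    alpha * (mat1 * mat2) + beta * bias,
)
```

Since no reduction over K is performed, the pointwise form is exact (not just approximately equal) and purely memory-bound, which is why the large shapes in the tables above benefit the most.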
1 parent 2ee3377 commit ad8f739

3 files changed: 24 additions & 8 deletions

test/inductor/test_fxir_backend.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -394,7 +394,7 @@ def foo(x, y):

         # Expect separate forward and backward graphs.
         (forward_gm, backward_gm) = self._compile_and_check(
-            foo, (x, y), expected_num_triton_kernels=3
+            foo, (x, y), expected_num_triton_kernels=4
         )

     def test_custom_compiler(self):
```

test/inductor/test_memory.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -37,10 +37,10 @@ class Foo(torch.nn.Module):

     def __init__(self):
         super().__init__()
-        self.w1 = torch.nn.Parameter(torch.ones(1, 10))
-        self.w2 = torch.nn.Parameter(torch.ones(1, 1))
+        self.w1 = torch.nn.Parameter(torch.ones(2, 10))
+        self.w2 = torch.nn.Parameter(torch.ones(2, 2))
         self.w3 = torch.nn.Parameter(torch.ones(10, 1))
-        self.w4 = torch.nn.Parameter(torch.ones(1, 10))
+        self.w4 = torch.nn.Parameter(torch.ones(2, 10))

     def forward(self, x):
         t1 = torch.matmul(x, self.w1)
@@ -61,7 +61,7 @@ def setUp(self):

         self.model = Foo().to(GPU_TYPE)
         M = 4096 if torch.version.hip is not None else 2048
-        self.inputs = torch.ones((M, 1), device=GPU_TYPE)
+        self.inputs = torch.ones((M, 2), device=GPU_TYPE)
         self.orig_reorder_method = memory.reorder_for_peak_memory

     @mock.patch.object(config, "reorder_for_peak_memory", True)
```

torch/_inductor/decomposition.py

Lines changed: 19 additions & 3 deletions

```diff
@@ -349,6 +349,16 @@ def addmm(
     beta: torch.types.Number = 1,
     alpha: torch.types.Number = 1,
 ) -> torch.Tensor:
+    if mat1.device.type not in ["cpu", "mps"]:
+        if (
+            statically_known_true(mat1.size(-1) == 1)
+            and statically_known_true(mat1.size(0) != 1)
+            and statically_known_true(mat2.size(1) != 1)
+        ):
+            counters["inductor"]["decompose_addmm"] += 1
+            out = mat1 * mat2
+            return alpha * out + beta * self
     if self.device.type == "cpu":
         if statically_known_true(mat1.size(0) == 1) and statically_known_true(
             mat2.size(-1) == 1
@@ -386,16 +396,22 @@ def mm(
         input2.shape[1] == 1
     ):
         return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
-    if self.device.type == "cpu":
-        if (
-            statically_known_true(self.size(-1) == 1)
+    # Non-CPU/MPS: always decompose. CPU: only for small tensors.
+    if (
+        statically_known_true(self.size(-1) == 1)
+        and statically_known_true(self.size(0) != 1)
+        and statically_known_true(input2.size(1) != 1)
+    ):
+        if self.device.type not in ["cpu", "mps"] or (
+            self.device.type == "cpu"
             and statically_known_true(self.size(0) > 0)
             and statically_known_true(input2.size(0) == 1)
             and (self.dtype == input2.dtype)
             and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
         ):
             counters["inductor"]["decompose_mm"] += 1
             return self * input2
+    if self.device.type == "cpu":
         if statically_known_true(self.size(0) == 1) and statically_known_true(
             input2.size(-1) == 1
         ):
```
