Commit ad8f739
[inductor] Decompose mm/addmm to pointwise mul when K==1 (#175825)
Summary:
When K == 1, matrix multiplication (M, 1) @ (1, N) is an outer product.
Instead of launching a full GEMM kernel, we decompose it into a broadcasted
pointwise multiply at the ATen decomposition level, which is more efficient
for this memory-bound case.
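The equivalence described above can be checked with a minimal sketch. The helper names `mm_k1_as_mul` and `addmm_k1_as_mul` are hypothetical and are not the functions touched by this PR; the sketch only illustrates the math on CPU tensors, not the actual decomposition registered in Inductor.

```python
import torch


def mm_k1_as_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (M, 1) @ (1, N) as a broadcasted pointwise multiply."""
    assert a.dim() == 2 and b.dim() == 2
    assert a.shape[1] == 1 and b.shape[0] == 1, "only valid when K == 1"
    return a * b  # (M, 1) * (1, N) broadcasts to (M, N)


def addmm_k1_as_mul(bias, a, b, beta=1.0, alpha=1.0):
    """Hypothetical helper: addmm with K == 1 as pointwise ops only."""
    return beta * bias + alpha * mm_k1_as_mul(a, b)


torch.manual_seed(0)
a = torch.randn(4, 1)   # M = 4, K = 1
b = torch.randn(1, 3)   # K = 1, N = 3
bias = torch.randn(3)   # broadcastable to (M, N)

# Compare against the reference GEMM-based ops.
mm_matches = torch.allclose(mm_k1_as_mul(a, b), torch.mm(a, b), atol=1e-6)
addmm_matches = torch.allclose(
    addmm_k1_as_mul(bias, a, b, beta=0.5, alpha=2.0),
    torch.addmm(bias, a, b, beta=0.5, alpha=2.0),
    atol=1e-6,
)
```

Each output element needs only one multiply, so the cost is dominated by reading the inputs and writing the (M, N) result, which is why the case is memory-bound.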
This is a reland of D94097622 with two fixes:
- Skip the decomposition when M==1 or N==1, because in that case the output
strides of the broadcast multiply may not match the strides mm produces.
- Remove `as_strided` stride fixup that was causing issues with Helion
(SympifyError on symbolic shapes).
The M==1/N==1 guard also applies to the existing CPU K==1 decomposition path.
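As a rough sketch, the condition described by the two fixes can be written as a single predicate. The function name `should_decompose_k1` is hypothetical and does not come from the PR; the sketch assumes plain integer shapes and ignores symbolic shapes.

```python
def should_decompose_k1(m: int, n: int, k: int) -> bool:
    """Hypothetical predicate for the guard described in the summary.

    Decompose only the true outer-product case: K == 1 with both M and N
    larger than 1. When M == 1 or N == 1 the decomposition is skipped so
    the result keeps the strides that mm would produce.
    """
    return k == 1 and m != 1 and n != 1


# A few shapes from the benchmark tables plus the guarded edge cases.
decisions = {
    (4096, 4096, 1): should_decompose_k1(4096, 4096, 1),  # outer product
    (1, 4096, 1): should_decompose_k1(1, 4096, 1),        # M == 1, skipped
    (4096, 1, 1): should_decompose_k1(4096, 1, 1),        # N == 1, skipped
    (4096, 4096, 2): should_decompose_k1(4096, 4096, 2),  # K != 1, regular mm
}
```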
**aten.mm** — TritonBench, K=1 shapes, median of 3 runs:
| Shape (M, N, K) | B200 base (us) | B200 test (us) | B200 Speedup | H100 base (us) | H100 test (us) | H100 Speedup |
|---|---|---|---|---|---|---|
| (100, 100, 1) | 12.3 | 11.3 | 1.09x | 9.76 | 8.64 | 1.13x |
| (150, 150, 1) | 12.3 | 11.2 | 1.10x | 9.82 | 8.70 | 1.13x |
| (200, 200, 1) | 12.3 | 11.3 | 1.09x | 9.95 | 8.80 | 1.13x |
| (256, 256, 1) | 12.3 | 11.3 | 1.09x | 9.76 | 8.70 | 1.12x |
| (512, 512, 1) | 12.3 | 11.2 | 1.10x | 9.92 | 8.80 | 1.13x |
| (1024, 1024, 1) | 14.3 | 13.2 | 1.09x | 11.39 | 9.44 | 1.21x |
| (2048, 2048, 1) | 20.5 | 15.3 | **1.34x** | 16.19 | 12.83 | **1.26x** |
| (4096, 4096, 1) | 35.8 | 26.8 | **1.33x** | 37.98 | 29.12 | **1.30x** |
| (8192, 8192, 1) | 96.3 | 68.6 | **1.40x** | 120.48 | 89.12 | **1.35x** |
| (16384, 16384, 1) | 329.8 | 234.5 | **1.41x** | 387.42 | 249.54 | **1.55x** |
| (4608, 20, 1) | 13.2 | 11.3 | 1.17x | 10.02 | 8.86 | 1.13x |
| (4608, 32, 1) | 13.2 | 11.3 | 1.17x | 9.95 | 8.86 | 1.12x |
| (4608, 128, 1) | 13.2 | 11.4 | 1.17x | 10.94 | 8.99 | 1.22x |
| (4608, 256, 1) | 14.3 | 13.2 | 1.09x | 12.22 | 9.50 | **1.29x** |
| (4608, 512, 1) | 17.4 | 13.3 | **1.31x** | 14.02 | 10.59 | **1.32x** |
| (4608, 1024, 1) | 20.5 | 15.3 | **1.34x** | 17.06 | 13.18 | **1.29x** |
| (1024, 4096, 1) | 20.5 | 15.3 | **1.34x** | 16.80 | 13.25 | **1.27x** |
| (4096, 1024, 1) | 20.5 | 15.3 | **1.34x** | 16.22 | 12.51 | **1.30x** |
Geomean speedup: B200 **1.21x**, H100 **1.22x**, 0 regressions.
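For reference, the geomean figure can be recomputed from the table. The sketch below uses only the B200 base and test columns of the aten.mm table above and assumes speedup is defined as base time divided by test time; the variable names are illustrative.

```python
import math

# (base_us, test_us) pairs from the B200 columns of the aten.mm table,
# in the same row order as the table.
b200_mm_times = [
    (12.3, 11.3), (12.3, 11.2), (12.3, 11.3), (12.3, 11.3), (12.3, 11.2),
    (14.3, 13.2), (20.5, 15.3), (35.8, 26.8), (96.3, 68.6), (329.8, 234.5),
    (13.2, 11.3), (13.2, 11.3), (13.2, 11.4), (14.3, 13.2), (17.4, 13.3),
    (20.5, 15.3), (20.5, 15.3), (20.5, 15.3),
]

# Speedup per shape: baseline latency divided by latency with the change.
speedups = [base / test for base, test in b200_mm_times]

# Geometric mean, computed in log space.
geomean = math.exp(sum(math.log(s) for s in speedups) / len(speedups))

# A regression is any shape where the change made the op slower.
regressions = sum(1 for s in speedups if s < 1.0)

print(f"geomean speedup: {geomean:.2f}x, regressions: {regressions}")
```

The geometric mean is the usual choice for averaging ratios, since it treats a 2x speedup and a 2x slowdown symmetrically.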
**aten.addmm** — TritonBench, K=1 shapes, median of 3 runs:
| Shape (M, N, K) | B200 base (us) | B200 test (us) | B200 Speedup | H100 base (us) | H100 test (us) | H100 Speedup |
|---|---|---|---|---|---|---|
| (100, 100, 1) | 12.3 | 12.3 | 1.00x | 9.76 | 9.06 | 1.08x |
| (150, 150, 1) | 12.4 | 12.3 | 1.01x | 10.08 | 9.18 | 1.10x |
| (200, 200, 1) | 12.4 | 12.3 | 1.00x | 9.98 | 9.31 | 1.07x |
| (256, 256, 1) | 12.3 | 12.3 | 1.00x | 9.86 | 9.38 | 1.05x |
| (512, 512, 1) | 13.3 | 13.2 | 1.01x | 10.37 | 9.73 | 1.07x |
| (1024, 1024, 1) | 15.3 | 13.3 | 1.15x | 12.32 | 11.20 | 1.10x |
| (2048, 2048, 1) | 23.6 | 18.5 | **1.27x** | 19.01 | 16.19 | **1.17x** |
| (4096, 4096, 1) | 56.3 | 33.8 | **1.66x** | 58.72 | 45.60 | **1.29x** |
| (8192, 8192, 1) | 172.2 | 102.3 | **1.68x** | 166.75 | 148.45 | 1.12x |
| (16384, 16384, 1) | 665.8 | 359.5 | **1.85x** | 638.66 | 503.23 | **1.27x** |
| (4608, 20, 1) | 13.2 | 12.3 | 1.07x | 10.21 | 9.47 | 1.08x |
| (4608, 32, 1) | 13.2 | 12.4 | 1.06x | 10.11 | 9.47 | 1.07x |
| (4608, 128, 1) | 13.3 | 13.2 | 1.00x | 11.68 | 10.27 | 1.14x |
| (4608, 256, 1) | 15.3 | 13.4 | 1.14x | 13.28 | 11.55 | 1.15x |
| (4608, 512, 1) | 18.6 | 15.4 | 1.20x | 15.87 | 13.63 | 1.16x |
| (4608, 1024, 1) | 25.5 | 19.4 | **1.31x** | 21.02 | 17.92 | **1.17x** |
| (1024, 4096, 1) | 23.5 | 18.5 | **1.27x** | 18.94 | 16.38 | 1.16x |
| (4096, 1024, 1) | 23.5 | 18.5 | **1.27x** | 18.98 | 16.29 | 1.17x |
Geomean speedup: B200 **1.19x**, H100 **1.13x**, 0 regressions.
diff-train-skip-merge
Test Plan:
```
PYTORCH_TEST_REMOTE_GPU=1 buck2 test //caffe2/test/inductor:test_mmdecomp_cuda \
-c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 \
-c fbcode.enable_gpu_sections=true mode/opt
Pass 30. Fail 0.
PYTORCH_TEST_REMOTE_GPU=1 buck2 test //caffe2/test/inductor:test_mmdecomp \
-c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 \
-c fbcode.enable_gpu_sections=true mode/opt
Pass 29. Fail 0.
PYTORCH_TEST_REMOTE_GPU=1 buck2 test //caffe2/test/inductor:fxir_backend \
-c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 \
-c fbcode.enable_gpu_sections=true mode/opt
Pass 76. Fail 0.
```
Reviewed By: PaulZhang12
Differential Revision: D94437532
Pull Request resolved: #175825
Approved by: https://github.com/PaulZhang12
1 parent 2ee3377 commit ad8f739
3 files changed
Lines changed: 24 additions & 8 deletions