Skip to content

perf: Port TRT-LLM SM120/SM121 FP4 CUTLASS GEMM optimizations. Add PDL#3026

Merged
bkryu merged 6 commits intoflashinfer-ai:mainfrom
bkryu:sm120_mm_fp4_opt
Apr 10, 2026
Merged

perf: Port TRT-LLM SM120/SM121 FP4 CUTLASS GEMM optimizations. Add PDL#3026
bkryu merged 6 commits intoflashinfer-ai:mainfrom
bkryu:sm120_mm_fp4_opt

Conversation

@bkryu
Copy link
Copy Markdown
Collaborator

@bkryu bkryu commented Apr 9, 2026

📌 Description

Summary

Details

TRT-LLM kernel parameter port

Updated the SM120 FP4 GEMM kernel template (fp4_gemm_template_sm120.h) to match TRT-LLM's optimized configuration:

Parameter Before After
Mainloop schedule KernelScheduleAuto KernelTmaWarpSpecializedCooperative
Stage count StageCount<2> (fixed) StageCountAutoCarveout<sizeof(EpilogueSharedStorage)>
Tile scheduler void (data-parallel) StaticPersistentScheduler
Epilogue schedule EpilogueScheduleAuto TmaWarpSpecialized
Epilogue OpClass OpClassBlockScaledTensorOp OpClassTensorOp
  1. KernelScheduleAutoKernelTmaWarpSpecializedCooperative — Removes auto-resolution ambiguity. Explicitly selects the cooperative warp-specialized mainloop where dedicated warps handle TMA loads while others run MMA.
  2. StageCount<2>StageCountAutoCarveout<sizeof(EpilogueSharedStorage)> — Instead of a hardcoded 2-stage pipeline, dynamically computes how many stages fit in shared memory after reserving space for the epilogue. More stages = better latency hiding of TMA loads behind MMA compute.
  3. void (scheduler) → StaticPersistentScheduler — The old void scheduler launches one CTA per output tile (data-parallel). The persistent scheduler launches fewer CTAs that loop over multiple tiles, reducing kernel launch overhead — most impactful at small M where the kernel is short.
  4. EpilogueScheduleAutoTmaWarpSpecialized — Explicitly selects TMA-based epilogue with warp specialization for output writes, rather than relying on auto-resolution.
  5. OpClassBlockScaledTensorOpOpClassTensorOp (in epilogue builder only) — The epilogue doesn't need the block-scaled op class (that's only for the mainloop MMA). Using OpClassTensorOp matches what TRT-LLM uses and avoids potential misrouting in the epilogue collective builder.

The persistent scheduler reduces launch overhead (most impactful for small M), dynamic stage carveout adapts pipeline depth to available smem, and explicit cooperative warp specialization avoids auto-resolution ambiguity.

PDL enablement

Changed enablePDL=falseenablePDL=true in runFp4GemmImpl. The CUTLASS_ENABLE_GDC_FOR_SM100=1 compile flag is already set (since PR #2780), and both SM100 FP4 GEMM and SM120 MXFP8 GEMM already run with PDL enabled. The false was a stale leftover.

Performance Numbers on RTX 5090 (SM120) and DGX Spark (SM121)

Performance changes most relevant to SM120. Very minor on Spark

Click to view non-autotune backend=cutlass data
M N K RTX 5090 Before (us) RTX 5090 After (us) RTX 5090 Speedup Spark Before (us) Spark After (us) Spark Speedup
1 512 7168 33.44 24.21 1.38x 40.51 32.56 1.24x
1 896 1024 7.41 6.26 1.18x 11.49 10.93 1.05x
1 896 5120 25.41 18.22 1.39x 39.22 36.40 1.08x
1 1024 7168 34.30 24.11 1.42x 57.07 56.47 1.01x
1 1280 8192 39.25 27.49 1.43x 77.92 68.90 1.13x
1 1792 5120 26.43 19.22 1.38x 63.62 60.80 1.05x
1 2560 8192 40.54 29.12 1.39x 117.75 118.08 1.00x
1 3584 5120 27.81 21.89 1.27x 99.25 102.31 0.97x
1 4608 7168 39.87 32.24 1.24x 156.40 165.52 0.94x
1 5120 640 7.02 6.56 1.07x 25.98 27.81 0.93x
1 5120 1024 9.52 8.10 1.18x 34.85 36.08 0.97x
1 5120 1280 10.78 9.09 1.19x 44.58 44.08 1.01x
1 5120 2048 15.28 12.13 1.26x 61.92 66.31 0.93x
1 5120 2560 18.67 14.69 1.27x 74.34 79.49 0.94x
1 5120 4096 26.08 18.72 1.39x 112.26 114.30 0.98x
1 5120 5120 32.96 23.81 1.38x 134.56 128.22 1.05x
1 5120 8192 44.27 33.86 1.31x 185.91 186.32 1.00x
1 5120 16384 76.08 58.78 1.29x 363.43 350.37 1.04x
1 7168 256 5.25 4.83 1.09x 16.83 17.01 0.99x
1 7168 512 7.15 6.83 1.05x 29.23 30.48 0.96x
1 7168 4608 33.82 25.71 1.32x 164.31 160.51 1.02x
1 7168 5120 36.64 27.81 1.32x 172.45 168.47 1.02x
1 8192 1024 10.86 9.76 1.11x 52.06 53.04 0.98x
1 8192 2048 17.25 14.74 1.17x 93.04 97.04 0.96x
1 8192 3584 28.30 21.94 1.29x 146.85 145.09 1.01x
1 8192 4096 30.96 23.94 1.29x 157.25 156.91 1.00x
1 8192 7168 47.34 38.13 1.24x 237.52 225.07 1.06x
1 8192 8192 50.40 42.67 1.18x 275.88 257.28 1.07x
1 8192 14336 83.90 70.38 1.19x 434.72 392.92 1.11x
1 8192 28672 160.99 135.05 1.19x 829.34 744.18 1.11x
1 9216 7168 50.96 45.76 1.11x 245.39 247.07 0.99x
1 10240 8192 55.50 50.74 1.09x 333.86 295.54 1.13x
4 512 7168 33.33 24.29 1.37x 40.15 32.53 1.23x
4 896 1024 7.38 6.21 1.19x 12.19 10.86 1.12x
4 1024 7168 34.35 24.37 1.41x 60.08 57.09 1.05x
4 4608 7168 39.84 32.24 1.24x 161.06 156.94 1.03x
4 7168 256 5.12 5.04 1.02x 16.62 16.42 1.01x
4 7168 512 7.15 6.91 1.03x 28.90 30.21 0.96x
4 7168 2304 18.38 14.93 1.23x 94.98 98.29 0.97x
4 7168 4608 33.97 25.07 1.35x 165.94 164.06 1.01x
4 9216 7168 50.80 45.89 1.11x 253.52 252.55 1.00x
8 896 5120 25.62 18.45 1.39x 38.56 37.52 1.03x
8 1280 8192 39.02 27.62 1.41x 79.87 70.10 1.14x
8 1792 5120 26.32 19.17 1.37x 63.01 59.20 1.06x
8 2560 8192 40.42 28.91 1.40x 120.05 118.18 1.02x
8 3584 5120 27.55 21.82 1.26x 102.66 101.36 1.01x
8 5120 640 7.14 6.34 1.13x 25.58 27.07 0.95x
8 5120 1024 9.44 8.00 1.18x 35.81 36.62 0.98x
8 5120 1280 11.12 9.23 1.20x 42.69 44.38 0.96x
8 5120 2048 15.06 11.76 1.28x 63.92 67.30 0.95x
8 5120 2560 18.54 14.14 1.31x 73.60 78.85 0.93x
8 5120 4096 26.27 18.70 1.40x 115.06 112.67 1.02x
8 5120 5120 33.33 23.62 1.41x 132.08 132.91 0.99x
8 5120 8192 43.94 34.35 1.28x 191.14 182.59 1.05x
8 5120 16384 76.32 58.82 1.30x 368.77 302.40 1.22x
8 7168 5120 36.64 26.14 1.40x 172.30 175.75 0.98x
8 8192 1024 10.80 9.55 1.13x 50.45 53.57 0.94x
8 8192 2048 17.01 14.40 1.18x 93.47 95.66 0.98x
8 8192 3584 28.21 21.26 1.33x 144.23 147.04 0.98x
8 8192 4096 30.94 23.38 1.32x 154.69 151.47 1.02x
8 8192 7168 47.23 37.18 1.27x 244.32 223.83 1.09x
8 8192 8192 50.08 42.03 1.19x 280.23 267.97 1.05x
8 8192 14336 83.30 69.89 1.19x 418.08 403.96 1.03x
8 8192 28672 159.25 132.94 1.20x 829.61 726.61 1.14x
8 10240 8192 54.66 51.71 1.06x 328.87 301.07 1.09x
16 512 7168 33.34 24.21 1.38x 40.51 33.52 1.21x
16 896 1024 7.49 6.37 1.18x 11.89 11.41 1.04x
16 1024 7168 34.54 24.51 1.41x 60.06 56.43 1.06x
16 4608 7168 39.82 32.34 1.23x 163.44 166.00 0.98x
16 7168 256 5.22 5.26 0.99x 17.66 17.79 0.99x
16 7168 512 6.96 6.99 1.00x 28.75 29.84 0.96x
16 7168 2304 18.53 14.45 1.28x 94.67 97.73 0.97x
16 7168 4608 34.05 24.14 1.41x 163.09 158.32 1.03x
16 9216 7168 49.70 42.75 1.16x 253.16 237.79 1.06x
64 512 7168 33.18 24.14 1.37x 42.99 36.46 1.18x
64 896 1024 7.76 6.42 1.21x 12.22 11.07 1.10x
64 896 5120 25.36 18.24 1.39x 40.54 38.53 1.05x
64 1280 8192 38.89 27.33 1.42x 79.94 71.98 1.11x
64 1792 5120 26.22 19.20 1.37x 66.29 61.71 1.07x
64 2560 8192 40.29 28.72 1.40x 120.77 116.37 1.04x
64 3584 5120 27.98 22.18 1.26x 101.94 104.29 0.98x
64 4608 7168 39.97 32.61 1.23x 160.07 162.56 0.98x
64 5120 640 7.06 6.48 1.09x 28.54 30.43 0.94x
64 5120 1024 9.52 7.78 1.22x 38.27 39.98 0.96x
64 5120 1280 10.91 8.75 1.25x 44.70 46.42 0.96x
64 5120 2048 15.17 11.44 1.33x 63.87 68.66 0.93x
64 5120 2560 18.88 14.08 1.34x 76.58 81.49 0.94x
64 5120 4096 26.24 19.07 1.38x 111.94 114.82 0.97x
64 5120 5120 33.06 23.41 1.41x 133.20 130.75 1.02x
64 5120 8192 44.27 34.56 1.28x 185.28 185.75 1.00x
64 5120 16384 76.03 58.94 1.29x 359.46 345.33 1.04x
64 7168 256 5.28 5.12 1.03x 22.58 20.67 1.09x
64 7168 512 7.10 6.61 1.08x 34.05 32.66 1.04x
64 7168 2304 18.37 13.71 1.34x 97.34 96.88 1.00x
64 7168 4608 34.11 24.24 1.41x 165.67 161.65 1.02x
64 7168 5120 36.83 25.38 1.45x 172.74 168.59 1.02x
64 8192 1024 10.91 9.02 1.21x 53.84 58.08 0.93x
64 8192 2048 17.01 12.93 1.32x 94.61 100.69 0.94x
64 8192 3584 28.40 18.94 1.50x 148.11 145.75 1.02x
64 8192 4096 31.44 21.38 1.47x 160.27 157.68 1.02x
64 8192 7168 47.04 34.53 1.36x 240.47 225.95 1.06x
64 8192 8192 49.84 35.57 1.40x 273.67 266.80 1.03x
64 8192 14336 83.20 67.71 1.23x 424.20 393.78 1.08x
64 8192 28672 159.34 128.72 1.24x 822.04 832.97 0.99x
64 9216 7168 49.31 35.46 1.39x 249.94 245.95 1.02x
64 10240 8192 52.86 38.69 1.37x 327.35 310.00 1.06x
256 512 7168 35.41 24.53 1.44x 55.46 48.11 1.15x
256 896 1024 7.74 6.66 1.16x 17.84 17.95 0.99x
256 1024 7168 35.87 24.75 1.45x 76.00 63.39 1.20x
256 4608 7168 42.16 34.30 1.23x 164.87 163.31 1.01x
256 7168 256 5.25 5.20 1.01x 41.76 42.86 0.97x
256 7168 512 7.26 7.09 1.02x 50.51 49.62 1.02x
256 7168 2304 18.99 16.06 1.18x 104.43 101.62 1.03x
256 7168 4608 34.98 27.58 1.27x 163.86 162.93 1.01x
256 9216 7168 50.67 40.33 1.26x 260.03 271.68 0.96x
512 896 5120 26.91 18.96 1.42x 60.56 54.32 1.11x
512 1280 8192 41.28 28.72 1.44x 105.22 90.66 1.16x
512 1792 5120 28.21 20.16 1.40x 97.89 91.62 1.07x
512 2560 8192 42.99 30.00 1.43x 152.29 141.06 1.08x
512 3584 5120 32.16 24.62 1.31x 135.03 122.02 1.11x
512 5120 640 8.16 7.98 1.02x 63.87 61.34 1.04x
512 5120 1024 10.48 10.29 1.02x 66.77 68.75 0.97x
512 5120 1280 12.90 11.87 1.09x 73.70 71.49 1.03x
512 5120 2048 16.37 14.54 1.13x 92.21 89.97 1.02x
512 5120 2560 19.70 16.67 1.18x 101.78 99.81 1.02x
512 5120 4096 28.40 23.65 1.20x 135.20 132.02 1.02x
512 5120 5120 35.23 30.22 1.17x 161.71 157.74 1.03x
512 5120 8192 47.73 43.36 1.10x 247.28 219.99 1.12x
512 5120 16384 87.04 83.89 1.04x 514.40 455.68 1.13x
512 7168 5120 55.14 48.03 1.15x 196.05 190.27 1.03x
512 8192 1024 17.60 15.31 1.15x 89.22 90.63 0.98x
512 8192 2048 27.54 23.38 1.18x 119.07 116.29 1.02x
512 8192 3584 42.91 34.70 1.24x 172.43 161.55 1.07x
512 8192 4096 46.62 39.22 1.19x 181.83 170.42 1.07x
512 8192 7168 74.46 70.08 1.06x 294.77 290.51 1.01x
512 8192 8192 81.47 80.03 1.02x 361.88 318.24 1.14x
512 8192 14336 142.05 127.65 1.11x 595.53 590.39 1.01x
512 8192 28672 264.89 245.31 1.08x 1418.62 1580.70 0.90x
512 10240 8192 86.21 85.76 1.01x 421.49 386.77 1.09x
1024 512 7168 36.18 25.02 1.45x 79.36 78.77 1.01x
1024 896 1024 8.27 6.86 1.21x 34.40 34.05 1.01x
1024 1024 7168 36.22 26.24 1.38x 116.93 115.79 1.01x
1024 4608 7168 71.78 70.90 1.01x 284.24 245.55 1.16x
1024 7168 256 13.07 12.82 1.02x 101.09 102.47 0.99x
1024 7168 512 18.40 18.18 1.01x 110.42 105.12 1.05x
1024 7168 4608 72.54 72.72 1.00x 294.99 255.68 1.15x
1024 9216 7168 133.60 128.05 1.04x 488.16 440.69 1.11x

RTX 5090 (SM120) geomean: 1.24x (147 shapes)
DGX Spark (SM121) geomean: 1.03x (147 shapes)

Click to view autotuned backend=cutlass data
M N K RTX 5090 Before (us) RTX 5090 After (us) RTX 5090 Speedup Spark Before (us) Spark After (us) Spark Speedup
1 512 7168 18.02 18.22 0.99x 33.23 33.76 0.98x
1 896 1024 7.66 6.29 1.22x 11.46 11.15 1.03x
1 896 5120 19.04 18.11 1.05x 38.30 32.75 1.17x
1 1024 7168 18.50 18.43 1.00x 53.86 49.82 1.08x
1 1280 8192 20.62 20.61 1.00x 78.66 58.50 1.34x
1 1792 5120 17.73 17.71 1.00x 50.85 50.85 1.00x
1 2560 8192 23.30 23.36 1.00x 119.12 102.22 1.17x
1 3584 5120 21.01 21.41 0.98x 89.30 89.87 0.99x
1 4608 7168 31.65 31.52 1.00x 144.53 146.66 0.99x
1 5120 640 7.39 7.01 1.05x 25.14 27.23 0.92x
1 5120 1024 9.22 7.84 1.18x 30.08 31.68 0.95x
1 5120 1280 11.41 9.44 1.21x 36.99 37.17 1.00x
1 5120 2048 11.87 11.84 1.00x 54.77 57.01 0.96x
1 5120 2560 14.03 13.95 1.01x 66.35 67.98 0.98x
1 5120 4096 20.51 20.24 1.01x 103.73 104.11 1.00x
1 5120 5120 25.06 25.06 1.00x 123.17 126.34 0.97x
1 5120 8192 37.23 37.17 1.00x 174.11 174.40 1.00x
1 5120 16384 48.45 51.06 0.95x 283.86 322.80 0.88x
1 7168 256 5.39 5.39 1.00x 15.90 15.71 1.01x
1 7168 512 7.09 6.48 1.09x 24.83 25.63 0.97x
1 7168 4608 26.02 26.13 1.00x 153.87 151.95 1.01x
1 7168 5120 27.74 27.76 1.00x 161.59 161.07 1.00x
1 8192 1024 10.78 9.25 1.17x 45.50 46.02 0.99x
1 8192 2048 14.43 14.74 0.98x 84.37 84.59 1.00x
1 8192 3584 21.06 20.85 1.01x 135.01 143.39 0.94x
1 8192 4096 22.69 22.58 1.00x 145.15 153.68 0.94x
1 8192 7168 41.78 41.63 1.00x 217.84 220.16 0.99x
1 8192 8192 39.95 39.60 1.01x 236.24 241.62 0.98x
1 8192 14336 68.69 68.91 1.00x 382.00 369.38 1.03x
1 8192 28672 107.68 115.58 0.93x 678.04 660.97 1.03x
1 9216 7168 39.46 39.70 0.99x 232.34 235.35 0.99x
1 10240 8192 46.58 46.62 1.00x 271.54 272.16 1.00x
4 512 7168 18.11 18.19 1.00x 33.54 32.96 1.02x
4 896 1024 7.54 6.34 1.19x 11.55 12.13 0.95x
4 1024 7168 18.53 18.56 1.00x 52.91 50.34 1.05x
4 4608 7168 28.35 28.35 1.00x 151.09 151.84 1.00x
4 7168 256 5.23 5.20 1.01x 13.65 15.89 0.86x
4 7168 512 7.10 6.90 1.03x 24.40 24.98 0.98x
4 7168 2304 14.67 14.50 1.01x 85.19 84.83 1.00x
4 7168 4608 25.12 25.29 0.99x 152.78 154.56 0.99x
4 9216 7168 43.81 44.06 0.99x 234.69 229.35 1.02x
8 896 5120 18.78 17.94 1.05x 38.53 32.77 1.18x
8 1280 8192 20.82 20.74 1.00x 79.50 69.73 1.14x
8 1792 5120 17.90 17.81 1.01x 59.34 51.36 1.16x
8 2560 8192 24.88 24.94 1.00x 114.83 117.66 0.98x
8 3584 5120 21.02 21.20 0.99x 101.44 99.06 1.02x
8 5120 640 7.33 6.58 1.11x 25.76 26.70 0.96x
8 5120 1024 9.39 8.06 1.16x 30.50 31.87 0.96x
8 5120 1280 11.57 9.74 1.19x 36.91 35.79 1.03x
8 5120 2048 11.87 11.65 1.02x 56.43 56.43 1.00x
8 5120 2560 14.02 13.90 1.01x 68.37 68.45 1.00x
8 5120 4096 20.16 20.13 1.00x 103.52 103.87 1.00x
8 5120 5120 24.75 24.94 0.99x 121.55 119.50 1.02x
8 5120 8192 37.07 37.20 1.00x 173.04 172.53 1.00x
8 5120 16384 51.33 50.30 1.02x 283.96 326.53 0.87x
8 7168 5120 27.49 27.66 0.99x 161.65 159.49 1.01x
8 8192 1024 10.62 9.47 1.12x 44.70 46.62 0.96x
8 8192 2048 14.32 14.06 1.02x 84.16 86.10 0.98x
8 8192 3584 20.64 20.50 1.01x 131.86 139.46 0.95x
8 8192 4096 22.30 22.03 1.01x 145.59 150.77 0.97x
8 8192 7168 38.27 38.40 1.00x 214.35 218.35 0.98x
8 8192 8192 43.68 43.84 1.00x 243.68 241.41 1.01x
8 8192 14336 61.98 62.03 1.00x 367.09 369.95 0.99x
8 8192 28672 108.59 119.36 0.91x 655.06 666.76 0.98x
8 10240 8192 46.66 46.54 1.00x 272.71 274.53 0.99x
16 512 7168 18.35 18.16 1.01x 34.06 34.77 0.98x
16 896 1024 7.52 6.29 1.20x 11.55 11.46 1.01x
16 1024 7168 18.66 18.59 1.00x 56.22 52.14 1.08x
16 4608 7168 32.10 32.02 1.00x 145.23 144.00 1.01x
16 7168 256 5.31 5.20 1.02x 16.34 15.82 1.03x
16 7168 512 6.88 6.80 1.01x 25.22 24.83 1.02x
16 7168 2304 14.78 14.45 1.02x 86.53 88.06 0.98x
16 7168 4608 25.57 25.33 1.01x 155.12 150.93 1.03x
16 9216 7168 38.93 39.15 0.99x 233.67 236.93 0.99x
64 512 7168 18.22 18.37 0.99x 35.49 35.09 1.01x
64 896 1024 7.60 6.27 1.21x 12.72 12.06 1.05x
64 896 5120 19.07 17.73 1.08x 47.81 34.64 1.38x
64 1280 8192 20.66 20.42 1.01x 82.32 80.67 1.02x
64 1792 5120 17.95 19.47 0.92x 64.10 53.47 1.20x
64 2560 8192 21.90 21.76 1.01x 106.39 104.69 1.02x
64 3584 5120 21.01 20.88 1.01x 102.77 102.61 1.00x
64 4608 7168 28.14 25.92 1.09x 152.87 158.67 0.96x
64 5120 640 7.41 6.82 1.09x 27.22 29.89 0.91x
64 5120 1024 9.55 7.68 1.24x 37.41 35.33 1.06x
64 5120 1280 9.50 8.98 1.06x 43.71 46.82 0.93x
64 5120 2048 15.18 11.20 1.36x 58.35 59.60 0.98x
64 5120 2560 18.18 13.33 1.36x 68.53 71.06 0.96x
64 5120 4096 20.35 20.18 1.01x 104.83 107.23 0.98x
64 5120 5120 23.62 21.65 1.09x 125.73 127.04 0.99x
64 5120 8192 33.65 31.90 1.05x 175.30 177.41 0.99x
64 5120 16384 51.46 50.45 1.02x 283.51 338.10 0.84x
64 7168 256 5.38 5.36 1.00x 18.59 19.66 0.95x
64 7168 512 7.02 6.74 1.04x 33.78 29.65 1.14x
64 7168 2304 14.21 14.08 1.01x 106.10 90.93 1.17x
64 7168 4608 24.83 24.83 1.00x 163.09 155.30 1.05x
64 7168 5120 27.68 27.50 1.01x 161.51 161.57 1.00x
64 8192 1024 10.62 8.80 1.21x 54.75 52.88 1.04x
64 8192 2048 13.57 12.86 1.05x 88.58 88.56 1.00x
64 8192 3584 19.84 19.15 1.04x 134.35 154.62 0.87x
64 8192 4096 21.76 21.49 1.01x 143.79 154.64 0.93x
64 8192 7168 40.48 37.44 1.08x 217.87 223.06 0.98x
64 8192 8192 42.66 42.58 1.00x 235.81 243.09 0.97x
64 8192 14336 65.49 65.01 1.01x 369.62 361.91 1.02x
64 8192 28672 108.33 113.45 0.95x 667.00 664.82 1.00x
64 9216 7168 44.46 43.31 1.03x 235.20 230.72 1.02x
64 10240 8192 51.12 49.84 1.03x 271.65 276.34 0.98x
256 512 7168 20.96 20.90 1.00x 49.90 50.00 1.00x
256 896 1024 7.66 6.26 1.23x 17.71 17.74 1.00x
256 1024 7168 22.19 22.27 1.00x 73.68 71.20 1.03x
256 4608 7168 34.67 32.51 1.07x 170.64 156.79 1.09x
256 7168 256 5.49 5.52 0.99x 38.90 39.28 0.99x
256 7168 512 7.41 7.28 1.02x 52.51 47.41 1.11x
256 7168 2304 14.56 15.98 0.91x 109.94 99.42 1.11x
256 7168 4608 35.25 27.57 1.28x 159.07 153.97 1.03x
256 9216 7168 40.30 40.94 0.98x 250.43 235.38 1.06x
512 896 5120 18.94 18.90 1.00x 61.30 53.65 1.14x
512 1280 8192 23.07 22.99 1.00x 91.09 92.77 0.98x
512 1792 5120 22.64 22.83 0.99x 97.03 87.98 1.10x
512 2560 8192 34.77 29.50 1.18x 143.19 138.42 1.03x
512 3584 5120 32.75 24.94 1.31x 136.55 129.43 1.06x
512 5120 640 9.31 9.30 1.00x 61.04 60.90 1.00x
512 5120 1024 11.07 11.23 0.99x 62.82 65.34 0.96x
512 5120 1280 11.57 12.18 0.95x 75.86 73.46 1.03x
512 5120 2048 16.08 14.35 1.12x 89.41 85.70 1.04x
512 5120 2560 15.90 16.19 0.98x 99.34 96.16 1.03x
512 5120 4096 22.69 22.62 1.00x 128.18 135.65 0.94x
512 5120 5120 28.24 29.98 0.94x 157.41 153.09 1.03x
512 5120 8192 40.10 40.30 0.99x 201.84 204.10 0.99x
512 5120 16384 80.00 80.09 1.00x 385.19 450.56 0.85x
512 7168 5120 45.28 40.22 1.13x 175.25 176.79 0.99x
512 8192 1024 14.98 13.92 1.08x 84.94 88.22 0.96x
512 8192 2048 22.69 19.49 1.16x 121.91 114.98 1.06x
512 8192 3584 36.18 28.99 1.25x 154.24 157.12 0.98x
512 8192 4096 38.70 32.43 1.19x 168.34 168.74 1.00x
512 8192 7168 60.58 56.67 1.07x 261.33 257.76 1.01x
512 8192 8192 65.92 63.39 1.04x 288.55 296.61 0.97x
512 8192 14336 112.25 108.30 1.04x 470.79 471.16 1.00x
512 8192 28672 212.93 202.01 1.05x 980.86 1103.15 0.89x
512 10240 8192 75.14 78.35 0.96x 338.56 339.22 1.00x
1024 512 7168 22.05 22.02 1.00x 76.56 80.66 0.95x
1024 896 1024 8.29 6.67 1.24x 32.16 32.45 0.99x
1024 1024 7168 25.74 23.15 1.11x 104.77 106.29 0.99x
1024 4608 7168 59.04 56.91 1.04x 239.03 239.97 1.00x
1024 7168 256 12.45 12.85 0.97x 100.98 101.36 1.00x
1024 7168 512 17.39 18.02 0.97x 104.19 105.27 0.99x
1024 7168 4608 67.54 69.47 0.97x 249.79 251.81 0.99x
1024 9216 7168 115.14 117.41 0.98x 423.97 433.32 0.98x

RTX 5090 (SM120) geomean: 1.04x (147 shapes)
DGX Spark (SM121) geomean: 1.01x (147 shapes)

🔍 Related Issues

#3013

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Optimizations
    • Optimized FP4 GEMM kernel execution by refining scheduler configuration, epilogue fusion operations, and memory staging parameters to improve performance and resource utilization.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 9, 2026

📝 Walkthrough

Walkthrough

CUTLASS FP4 GEMM template for SM120 updated to enable Programmatic Dependent Launch, restructure epilogue and mainloop scheduler configurations, and rewire kernel tile scheduling strategies from fixed staging to auto-carveout and static persistent scheduler parameters.

Changes

Cohort / File(s) Summary
FP4 GEMM Template Configuration
include/flashinfer/gemm/fp4_gemm_template_sm120.h
Enabled PDL in gemm.run(); replaced fixed StageCount and KernelScheduleAuto with StageCountAutoCarveout and KernelTmaWarpSpecializedCooperative; introduced FusionOperation for epilogue and switched to TmaWarpSpecialized scheduling; rewired DP kernel to use StaticPersistentScheduler and StreamK kernel scheduler template parameter ordering.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related issues

Possibly related PRs

Suggested labels

run-ci, op: moe

Suggested reviewers

  • yzh119
  • jimmyzho
  • nv-yunzheq
  • jiahanc
  • yongwww
  • nvmbreughe

Poem

🐰 A kernel born anew, with PDL's gentle glow,
Schedulers dance in sync, in epilogue's soft flow,
Static persistence meets the TMA warp's design,
Auto-carveout stages make the timing align,
SM120 sparkles bright with optimized performance divine! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main change: porting TRT-LLM SM120/SM121 FP4 CUTLASS GEMM optimizations and enabling PDL.
Description check ✅ Passed The pull request description is comprehensive and well-structured, covering all template requirements with detailed explanations of changes and performance data.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@bkryu bkryu changed the title Sm120 mm fp4 opt perf: Port TRT-LLM SM120/SM121 FP4 CUTLASS GEMM optimizations. Add PDL Apr 9, 2026
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/gemm/fp4_gemm_template_sm120.h`:
- Around line 267-270: Wrap the FP4 SM120 kernel typedefs GemmKernelDefault and
GemmKernelStreamK with the Sm12xOnly architecture guard (same pattern used in
Sm10x11xOnly/Sm12x examples): create a Sm12xOnly wrapper that checks the
architecture at runtime, prints an error message when unsupported, calls
__trap(), and otherwise resolves to the underlying
cutlass::gemm::kernel::GemmUniversal instantiation (referencing
CollectiveMainloop, CollectiveEpilogue, TileSchedulerTag); replace the raw
typedefs with this guarded alias so the kernels bail out on non-SM12x hardware.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2be120a5-f83a-4f0d-8746-aa2d09a23182

📥 Commits

Reviewing files that changed from the base of the PR and between 77a179f and 589990a.

📒 Files selected for processing (1)
  • include/flashinfer/gemm/fp4_gemm_template_sm120.h

Comment thread include/flashinfer/gemm/fp4_gemm_template_sm120.h
@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Apr 9, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !531 has been created, and the CI pipeline #48156327 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu self-assigned this Apr 9, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the FP4 GEMM implementation for SM120/SM121 by enabling PDL and refactoring the collective builder configurations to use TmaWarpSpecialized schedules and dynamic stage carveout. The review feedback suggests further improving readability by defining explicit aliases for the epilogue and mainloop schedules, which would simplify the builder declarations.

Comment thread include/flashinfer/gemm/fp4_gemm_template_sm120.h
@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[FAILED] Pipeline #48156327: 11/20 passed

@bkryu bkryu merged commit c0adda7 into flashinfer-ai:main Apr 10, 2026
38 of 58 checks passed
@sjug
Copy link
Copy Markdown

sjug commented Apr 10, 2026

Thanks for the fix @bkryu, the problem is that StageCountAutoCarveout in cutlass doesn't calculate the SMEM correctly for SM120.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants