perf: Port TRT-LLM SM120/SM121 FP4 CUTLASS GEMM optimizations. Add PDL #3026
bkryu merged 6 commits into flashinfer-ai:main
Conversation
📝 Walkthrough
The CUTLASS FP4 GEMM template for SM120 is updated to enable Programmatic Dependent Launch (PDL), restructure the epilogue and mainloop schedule configurations, and switch the kernel from fixed pipeline staging and data-parallel tile scheduling to auto-carveout stage counts and a static persistent scheduler.
Actionable comments posted: 1
Inline comments:
In `include/flashinfer/gemm/fp4_gemm_template_sm120.h`:
- Around lines 267-270: Wrap the FP4 SM120 kernel typedefs `GemmKernelDefault` and `GemmKernelStreamK` with the `Sm12xOnly` architecture guard (the same pattern used by the existing `Sm10x11xOnly` wrapper): the guard checks the architecture at runtime, prints an error message when unsupported, calls `__trap()`, and otherwise resolves to the underlying `cutlass::gemm::kernel::GemmUniversal` instantiation (referencing `CollectiveMainloop`, `CollectiveEpilogue`, `TileSchedulerTag`). Replacing the raw typedefs with this guarded alias makes the kernels bail out cleanly on non-SM12x hardware; a sketch of the pattern follows.
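A minimal sketch of that guard, assuming it mirrors the existing `Sm10x11xOnly` pattern; the exact `__CUDA_ARCH__` values and member names are assumptions, not the FlashInfer source:

```cpp
#include <cstdio>
#include "cutlass/cutlass.h"

// Assumed guard shape: forwards to the wrapped kernel on SM120/SM121 and
// traps everywhere else instead of silently computing garbage.
template <typename Kernel>
struct Sm12xOnly : Kernel {
  using typename Kernel::Params;

  CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210)
    // Supported architecture: run the real GEMM kernel.
    Kernel::operator()(params, smem_buf);
#else
    // Unsupported architecture: report once and abort the kernel.
    if (threadIdx.x == 0 && blockIdx.x == 0) {
      printf("ERROR: this FP4 GEMM kernel requires SM120/SM121 hardware.\n");
    }
    __trap();
#endif
  }
};

// The raw typedefs would then become guarded aliases, e.g. (illustrative):
// using GemmKernelDefault = Sm12xOnly<cutlass::gemm::kernel::GemmUniversal<
//     ProblemShape, CollectiveMainloop, CollectiveEpilogue, TileSchedulerTag>>;
```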
📒 Files selected for processing (1)
include/flashinfer/gemm/fp4_gemm_template_sm120.h
/bot run
Code Review
This pull request updates the FP4 GEMM implementation for SM120/SM121 by enabling PDL and refactoring the collective builder configurations to use `TmaWarpSpecialized` schedules and dynamic stage carveout. The review suggests further improving readability by defining explicit aliases for the epilogue and mainloop schedules, which would simplify the builder declarations.
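A sketch of that suggestion (the alias names are illustrative, not identifiers from the PR):

```cpp
// Name each schedule once, then reference the aliases from both
// CollectiveBuilder declarations so the choice lives in one place.
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
```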
[FAILED] Pipeline #48156327: 11/20 passed

Thanks for the fix @bkryu, the problem is that …
📌 Description
Summary
Details
TRT-LLM kernel parameter port
Updated the SM120 FP4 GEMM kernel template (`fp4_gemm_template_sm120.h`) to match TRT-LLM's optimized configuration:

| Setting | Before | After |
| --- | --- | --- |
| Kernel schedule | `KernelScheduleAuto` | `KernelTmaWarpSpecializedCooperative` |
| Stage count | `StageCount<2>` (fixed) | `StageCountAutoCarveout<sizeof(EpilogueSharedStorage)>` |
| Tile scheduler | `void` (data-parallel) | `StaticPersistentScheduler` |
| Epilogue schedule | `EpilogueScheduleAuto` | `TmaWarpSpecialized` |
| Epilogue op class | `OpClassBlockScaledTensorOp` | `OpClassTensorOp` |

- `KernelScheduleAuto` → `KernelTmaWarpSpecializedCooperative`: removes auto-resolution ambiguity by explicitly selecting the cooperative warp-specialized mainloop, where dedicated warps handle TMA loads while the others run MMA.
- `StageCount<2>` → `StageCountAutoCarveout<sizeof(EpilogueSharedStorage)>`: instead of a hardcoded two-stage pipeline, dynamically computes how many stages fit in shared memory after reserving space for the epilogue. More stages give better latency hiding of TMA loads behind MMA compute.
- `void` → `StaticPersistentScheduler`: the old `void` scheduler launches one CTA per output tile (data-parallel). The persistent scheduler launches fewer CTAs that loop over multiple tiles, reducing kernel launch overhead; this matters most at small M, where the kernel is short.
- `EpilogueScheduleAuto` → `TmaWarpSpecialized`: explicitly selects the TMA-based epilogue with warp specialization for output writes, rather than relying on auto-resolution.
- `OpClassBlockScaledTensorOp` → `OpClassTensorOp` (in the epilogue builder only): the epilogue doesn't need the block-scaled op class, which is only for the mainloop MMA. Using `OpClassTensorOp` matches what TRT-LLM uses and avoids potential misrouting in the epilogue collective builder.

In short: the persistent scheduler reduces launch overhead (most impactful for small M), the dynamic stage carveout adapts pipeline depth to the available shared memory, and explicit cooperative warp specialization avoids auto-resolution ambiguity. A sketch of the resulting builder wiring is shown below.
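A minimal sketch of how these choices fit together in the CUTLASS 3.x collective-builder API. The tile/cluster shapes, element types, and alignments here are illustrative assumptions, not the exact values in `fp4_gemm_template_sm120.h`:

```cpp
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

using MmaTileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;  // assumed
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;        // assumed
using ElementAccum = float;

// Epilogue builder: plain OpClassTensorOp plus the explicit
// TmaWarpSpecialized schedule (no block scaling needed here).
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
    MmaTileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccum, ElementAccum,
    void, cutlass::layout::RowMajor, 4,             // C unused (void source)
    cutlass::half_t, cutlass::layout::RowMajor, 8,  // D output
    cutlass::epilogue::TmaWarpSpecialized>::CollectiveOp;

// Mainloop builder: block-scaled FP4 MMA, cooperative warp-specialized
// schedule, and a stage count carved out around the epilogue's smem needs.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
    cutlass::nv_float4_t<cutlass::float_e2m1_t>, cutlass::layout::RowMajor, 32,
    cutlass::nv_float4_t<cutlass::float_e2m1_t>, cutlass::layout::ColumnMajor, 32,
    ElementAccum,
    MmaTileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp;

// Persistent tile scheduler replaces the data-parallel default (void).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,  // problem shape (M, N, K, L)
    CollectiveMainloop, CollectiveEpilogue,
    cutlass::gemm::StaticPersistentScheduler>;
```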
PDL enablement
Changed `enablePDL=false` → `enablePDL=true` in `runFp4GemmImpl`. The `CUTLASS_ENABLE_GDC_FOR_SM100=1` compile flag is already set (since PR #2780), and both SM100 FP4 GEMM and SM120 MXFP8 GEMM already run with PDL enabled; the `false` was a stale leftover.
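For context, a sketch of where such a flag typically lands, assuming CUTLASS 3.x's `GemmUniversalAdapter::run` with its trailing PDL parameter; `gemm`, `args`, `workspace`, and `stream` are placeholders:

```cpp
// When launch_with_pdl is true, CUTLASS launches the kernel with
// cudaLaunchAttributeProgrammaticStreamSerialization set, allowing it to
// start while the tail of the previous kernel on the stream is draining.
cutlass::Status status = gemm.run(args, workspace, stream,
                                  /*cuda_adapter=*/nullptr,
                                  /*launch_with_pdl=*/true);
if (status != cutlass::Status::kSuccess) {
  // Surface the launch failure to the caller.
}
```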
Performance Numbers on RTX 5090 (SM120) and DGX Spark (SM121)
Performance changes are most relevant on SM120; the effect on Spark is very minor.
Click to view non-autotune backend=cutlass data
RTX 5090 (SM120) geomean: 1.24x (147 shapes)
DGX Spark (SM121) geomean: 1.03x (147 shapes)
Click to view autotuned backend=cutlass data
RTX 5090 (SM120) geomean: 1.04x (147 shapes)
DGX Spark (SM121) geomean: 1.01x (147 shapes)
🔍 Related Issues
#3013
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes