Wrapper support for prefill with KV cache #5
Merged
yzh119 merged 1 commit into batch-prefill on Oct 11, 2023
Conversation
yzh119
approved these changes
Oct 11, 2023
jimmyzho
referenced
this pull request
in jimmyzho/flashinfer
Sep 25, 2025
…infer-ai#790)

* Fix MHA kernel

  Summary: ATT
  Test Plan:
  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* Extend DualGemm to support batched mode (#5)

  Following the GemmUniversalMode::kBatched implementation, batched mode is added to the DualGemm (under examples/45_dual_gemm). DualGemmMode::kBatched and SplitKSerial are not compatible: Status::kErrorInvalidProblem is returned if both are set.

* Decouple LayoutB0 and LayoutB1 in DualGemm

  The DualGemm template assumed the same layout, LayoutB, for both right-operand matrices B0 and B1. This is problematic if the layouts of the two matrices differ. In particular, this may be the case when one of the matrices is row-major while the other is a (column) vector that has to be broadcast in column-major with zero stride (e.g., as {B1.device_data(), 0}) for the DualGemm implementation to be able to process B0 and B1 simultaneously. In this commit, LayoutB0 and LayoutB1 are decoupled throughout the DualGemm code (device, kernel, and mma). Additionally, the batch strides of B0 and B1 are decoupled to accommodate the column-vector B1 case described above.

* Remove comment as no longer relevant

* Revert "Fix MHA kernel"

---------

Co-authored-by: mikeiovine <mikeiovine@fb.com>
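The zero-stride broadcast described in that commit (passing `{B1.device_data(), 0}` so one column vector serves every batch) has a close analogue in PyTorch's stride-0 `expand`; a minimal illustration, with arbitrary shapes:

```python
import torch

# A single column vector standing in for B1.
b1 = torch.randn(4096, 1)

# expand() broadcasts along the second dim by setting that stride to 0;
# no data is copied, and every "column" aliases the same underlying vector.
b1_broadcast = b1.expand(4096, 512)
print(b1_broadcast.stride())                     # (1, 0)
assert b1_broadcast.data_ptr() == b1.data_ptr()  # shared storage
```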
diptorupd
referenced
this pull request
in ROCm/flashinfer
Sep 29, 2025
This PR fixes some of the unit test failures that occur in Single Decode. It also disables clang formatting of headers.
Running clang-format on the headers causes compilation issues: the compiler is unable to find the `HIP WARP SYNC INTRINSICS`, which causes failures. Disabling clang format for these headers fixes the issue.
```
Start 1: MathTest
1/6 Test #1: MathTest ......................... Passed 3.31 sec
Start 2: PosEncTest
2/6 Test #2: PosEncTest ....................... Passed 3.36 sec
Start 3: CascadeTest
3/6 Test #3: CascadeTest ...................... Passed 3.35 sec
Start 4: PageTest
4/6 Test #4: PageTest ......................... Passed 114.08 sec
Start 5: SingleDecodeTest
5/6 Test #5: SingleDecodeTest ................. Passed 35.22 sec
Start 6: BatchDecodeTest
6/6 Test #6: BatchDecodeTest .................. Passed 559.75 sec
100% tests passed, 0 tests failed out of 6
Total Test time (real) = 719.07 sec
```
diptorupd
referenced
this pull request
in ROCm/flashinfer
Sep 29, 2025
In this PR, we add infra for enabling decode via the flashinfer gpu_iface. It does not change the existing infrastructure; decode can still be built using AOT and JIT.
Tested locally:
```
Start 5: SingleDecodeTest
5/6 Test #5: SingleDecodeTest ................. Passed 35.12 sec
Start 6: BatchDecodeTest
6/6 Test #6: BatchDecodeTest .................. Passed 541.87 sec
```
A follow-up PR will enable AOT decode using the flashinfer gpu_iface.
diptorupd
referenced
this pull request
in ROCm/flashinfer
Sep 29, 2025
The C++ test suite was using `hipified` headers. In this PR, we port the unit tests over to `gpu_iface`. This is necessary because our next step is to move the build infrastructure to `gpu_iface`.
This PR has been tested locally:
```
Test project /root/flashinfer/libflashinfer/tests/hip/build
Start 1: MathTest
1/6 Test #1: MathTest ......................... Passed 3.40 sec
Start 2: PosEncTest
2/6 Test #2: PosEncTest ....................... Passed 3.40 sec
Start 3: CascadeTest
3/6 Test #3: CascadeTest ...................... Passed 985.27 sec
Start 4: PageTest
4/6 Test #4: PageTest ......................... Passed 112.40 sec
Start 5: SingleDecodeTest
5/6 Test #5: SingleDecodeTest ................. Passed 35.46 sec
Start 6: BatchDecodeTest
6/6 Test #6: BatchDecodeTest .................. Passed 556.81 sec
100% tests passed, 0 tests failed out of 6
```
To replicate the tests:
```
cd flashinfer/libflashinfer/tests/hip
mkdir build && cd build/
cmake -DCMAKE_PREFIX_PATH=/root/libtorch -DCMAKE_CXX_COMPILER:PATH=/opt/rocm/bin/amdclang++ -DFLASHINFER_INCLUDE_DIRS=/root/flashinfer/libflashinfer/include/ ..
make
ctest
```
diptorupd
referenced
this pull request
in ROCm/flashinfer
Sep 29, 2025
In this PR I remove the `libtorch` dependency and remove `test_page.cpp`, the only unit test that uses libtorch. However, we also have a pytest for testing page, and we will use that for validation.
Removing the libtorch dependency will help us speed up Docker builds and drop an additional dependency.
```
Test project /root/flashinfer/libflashinfer/tests/hip/build
Start 1: MathTest
1/8 Test #1: MathTest ............................ Passed 0.31 sec
Start 2: PosEncTest
2/8 Test #2: PosEncTest .......................... Passed 0.31 sec
Start 3: CascadeTest
3/8 Test #3: CascadeTest ......................... Passed 1369.12 sec
Start 4: SingleDecodeTest
4/8 Test #4: SingleDecodeTest .................... Passed 7726.35 sec
Start 5: BatchDecodeTest
5/8 Test #5: BatchDecodeTest ..................... Passed 811.61 sec
Start 6: test_mfma_fp32_16x16x16fp16
6/8 Test #6: test_mfma_fp32_16x16x16fp16 ......... Passed 0.30 sec
Start 7: test_transpose_4x4_half_registers
7/8 Test #7: test_transpose_4x4_half_registers ... Passed 0.28 sec
Start 8: test_rowsum
8/8 Test #8: test_rowsum ......................... Passed 0.27 sec
100% tests passed, 0 tests failed out of 8
```
zhou-yuxin
pushed a commit
to zhou-yuxin/flashinfer
that referenced
this pull request
Mar 6, 2026
Refactor input layout.
leejnau
added a commit
to leejnau/flashinfer
that referenced
this pull request
Apr 23, 2026
…4.4.2
Belt-and-suspenders sanity check deferred from the 2026-04-23 session. NGC 1.3.0rc5.post2 pins nvidia-cutlass-dsl==4.3.4; flashinfer's requirements.txt declares >=4.4.2, and the ported kernels bridge the gap via monkey-patches. Port-parity deltas are unaffected (both sides run under the same in-container DSL compiler) — this rerun only validates absolute numbers against external references.
leejnau
added a commit
to leejnau/flashinfer
that referenced
this pull request
Apr 24, 2026
Code-reading review 2026-04-24: `convert_sf_to_mma_layout` is a pure `.view(...).permute(...)` strided view — it does not move data; the underlying GPU bytes ARE the input SF bytes. The kernel reads via `data_ptr()` + stride metadata, getting the same bytes TRT-LLM's kernel reads. TRT-LLM's `swizzle_sf(unswizzle_sf(sf, ...))` is a round-trip empirically verified byte-identical to the input SF. Both paths hand the CuteDSL kernel the same bytes.

Also: the 6D layout (32, 4, m//128, 4, k//4, num_groups) uses M=128 as the fundamental sub-tile REGARDLESS of tile_size. The 2CTA variant at tile_size=256 reads 2 adjacent m_tiles across two CTAs; the SF byte layout doesn't change. The mechanism originally proposed for this candidate (a tile_size-dependent SF layout mismatch) was based on a misreading of the layout.

Kept the abandoned sf_layout_diff_test.py attempt as a record — its .contiguous()-on-strided-view comparison produced a false 88.72 percent divergence report that was a test-harness artifact, not a real finding. The corrected interpretation supersedes that test's nominal verdict.

Working suspicion now moves to moe_permute (the JIT-compiled sibling of moe_sort in moe_utils.py) — it consumes moe_sort's now-verified output, is explicitly tile_size-parameterized, and has not been isolated by any prior probe.

Candidates ruled out so far:
- kernel bodies (deep audit)
- flashinfer-ai#1 MbarrierArray shim (2026-04-23 revert experiment)
- flashinfer-ai#2 moe_sort / routing tables (2026-04-24 self-consistency)
- flashinfer-ai#4 SF layout conversion (2026-04-24 code reading)

Candidates still open: flashinfer-ai#3 fence_proxy shim (low prior), flashinfer-ai#5 orchestration / buffer sizing, flashinfer-ai#6 top-level wrappers. moe_permute is now promoted to primary suspect (it wasn't cleanly separated in the original flashinfer-ai#2 entry; test script in progress).
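A minimal sketch of the strided-view property this review relies on: view + permute only rewrite shape/stride metadata, while `.contiguous()` materializes a new buffer. The shapes are illustrative stand-ins (a 5D toy layout in place of the real 6D one):

```python
import torch

# Stand-in SF tensor; the real MMA layout is 6D: (32, 4, m//128, 4, k//4, num_groups).
sf = torch.arange(32 * 4 * 2 * 4 * 8, dtype=torch.uint8)
mma_view = sf.view(32, 4, 2, 4, 8).permute(2, 0, 1, 3, 4)

# view + permute share storage: the kernel-visible bytes ARE the input SF bytes.
assert mma_view.data_ptr() == sf.data_ptr()

# .contiguous() on the permuted view copies into fresh storage with new strides;
# comparing that copy byte-for-byte against the original buffer is exactly the
# test-harness artifact described above, not a kernel-facing difference.
copy = mma_view.contiguous()
assert copy.data_ptr() != sf.data_ptr()
```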
leejnau
added a commit
to leejnau/flashinfer
that referenced
this pull request
Apr 24, 2026
…lu gap at tile_size=256

Now that correctness at tile_size=256 is established (the previous commit reclassifies flashinfer-ai#3067 as a test artifact), the real open work item is the perf gap that the audit had mis-attributed to the correctness bug.

2026-04-24 experiment, forced tile_size=256 on flashinfer via a tuner.py patch (for tile_size in [256]):
- fi gemm1_swiglu at N=16384, tile=256: 2.644 ms
- trt gemm1_swiglu at N=16384, tile=256: 1.806 ms
- +46.4 percent at the SAME tactic

At tile=128, flashinfer's gemm1 was 2.711 ms — so enabling tile=256 barely helps flashinfer, while TRT-LLM at the same tile=256 tactic gets much more throughput. The large-batch +27 percent top-line regression is NOT resolved by un-gating tile_size=256.

Both sides compile the identical CuteDSL kernel source (the deep audit established semantic identity of the kernel bodies), so the SASS or the runtime differs despite identical Python source. Working hypotheses:
- 8a. Compile-time parameter / constexpr drift between the wrappers' invocations causes different SASS
- 8b. Launch-grid / cluster config mis-set on flashinfer (2CTA effectively running as 1CTA)
- 8c. Input buffer alignment / stride differences the kernel optimizer exploits
- 8d. Stream / cooperative-group sync context difference

Suggested first probe: an nsys trace at N=16384 for both sides at tile_size=256; compare launch config + SASS identifier + span alignment. That should quickly narrow 8a vs 8b vs 8d.

Note on prior follow-up flashinfer-ai#1: this supersedes its framing. "MbarrierArray shim / tile_gating causes the regression" is now invalidated by the forced-tile=256 experiment showing that un-gating does not recover the regression. The rest of the flashinfer-ai#3067-era candidates (flashinfer-ai#3 fence_proxy, flashinfer-ai#5 orchestration, flashinfer-ai#6 top-level wrappers) were framed around a correctness bug that doesn't exist and are now subordinate. Follow-up flashinfer-ai#8 replaces them as the primary perf work item.
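A minimal sketch of the per-kernel timing harness numbers like those above imply, using CUDA events; the launcher names in the usage comment are hypothetical, not the audit's actual bench:

```python
import torch

def time_kernel_ms(launch, iters=30, warmup=5):
    """Mean per-call latency of `launch()` in milliseconds, via CUDA events."""
    for _ in range(warmup):
        launch()
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        launch()
    stop.record()
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / iters

# Usage (hypothetical launchers for the two gemm1_swiglu paths at tile=256):
# fi_ms  = time_kernel_ms(lambda: fi_gemm1_swiglu(x, w, tile_size=256))
# trt_ms = time_kernel_ms(lambda: trt_gemm1_swiglu(x, w, tile_size=256))
# print(f"fi {fi_ms:.3f} ms vs trt {trt_ms:.3f} ms")
```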
leejnau
added a commit
to leejnau/flashinfer
that referenced
this pull request
Apr 28, 2026
…ground-truth verification

A single nsys trace at N=8192 with `--nsys-capture-range` (commit 40bb77e) bracketing only the timed measurement passes resolved both remaining measurement-related follow-ups.

flashinfer-ai#4 (`bench_gpu_time_with_cupti(use_cuda_graph=True)` 2× inflation): direct wall-clock comparison at N=16384 / 30 iters shows identical wall-clocks with and without `--use-cupti` (1m12.5s vs 1m9.5s; the 3 s delta is autotune-compile + Python-startup variance, well below the ~240 ms of actual GPU measurement work in 70+ s of total wall-clock). The historical 2× signature was always a `cupti-python` span-attribution artifact, never real GPU work — and it does not reproduce under the current methodology. A smaller asymmetric bias (~13% under-report on `trt_ms` vs ~5% on `fi_ms`) persists, which is the rationale for keeping `--use-cupti` opt-in (default off).

flashinfer-ai#6 (in-bench vs standalone 19% gap on trt `gemm2_finalize`): nsys ground truth at N=8192 = 0.737 ms; the current in-bench reports 0.7465 ms (a 1.3% delta, within noise); the standalone reports 0.685 ms (7.1% below ground truth, harness-to-harness rounding tolerance). The original 19% gap was specific to the older `--use-cupti` config against the standalone — under the current methodology there is no systematic bias.

Audit changes:
- New "Ground-truth nsys verification (2026-04-28)" section immediately after the post-fix verification, documenting the run command, the per-kernel ground truth, the resolution of both follow-ups with quantitative tables, and a note that the trace also serves as a third independent kernel-port faithfulness check (kernel mangled-name structure matches modulo the encoded module path).
- Follow-up flashinfer-ai#1 marked RESOLVED (the original `MbarrierArray` framing was wrong; the actual cause was the gemm2-enumeration gap fixed at d291d17e/f0cf8cd0 on the standalone PR branch).
- Follow-up flashinfer-ai#4 and flashinfer-ai#6 entries replaced with closure notes.
- Top-of-file correction section title updated to "2026-04-24/25/28" and the short summary expanded to mention the verification round.

The original "open mysteries" list (flashinfer-ai#1, flashinfer-ai#4, flashinfer-ai#6, flashinfer-ai#8) is now fully closed. Items remaining in *Follow-ups queued* (flashinfer-ai#2, flashinfer-ai#3, flashinfer-ai#5, flashinfer-ai#7) are all scope-expansions, not investigations.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
leejnau
added a commit
to leejnau/flashinfer
that referenced
this pull request
Apr 28, 2026
…red-and-skipped
After 2026-04-24 / 25 / 28 we have five independent fi-vs-trt
agreement proofs:
1. source byte-identical with rc5.post2 (deep audit)
2. compiled PTX byte-identical at tile_size=256 (md5 401ebca6...)
3. per-call timing match within 0.1% at apples-to-apples tactic
4. 45 (size, EP) parity cells all pass within 0.5 FP4 step
5. nsys ground-truth verification with kernel mangled-name
structure agreement
The probability of fi+trt being wrong-but-agreeing across all five
is essentially zero, so a PyTorch FP4 third-reference check no
longer provides meaningful incremental confidence.
Additionally, `compute_reference_moe_fp4` has known limitations:
its PyTorch-eager FP4 simulation is stricter than the kernel's
actual FP4 representation, which made it ambiguous to interpret
during the original flashinfer-ai#3067 framing. Disagreement between bench
output and the reference would not unambiguously indicate a kernel
bug.
Cost is also non-trivial: Python-eager per-token / per-expert
loops would require running at a small problem-size subset to
keep wall-clock bearable.
Cost-to-incremental-confidence ratio is bad enough that this
follow-up is consciously skipped, not deferred. Future evidence
of a fi-vs-trt agreement that's actually wrong would re-elevate
it; otherwise no value.
Effective remaining open follow-ups: flashinfer-ai#3 (production-convention
scaling), flashinfer-ai#5 (cutlass-dsl 4.4.2 sanity rerun), flashinfer-ai#9 (EP=16 tactic-
divergence root cause).
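Proof 2 in the list above reduces to a normalize-then-hash comparison of the compiled PTX; a minimal sketch, where the file paths and the symbol-masking regex are assumptions rather than the audit's actual script:

```python
import hashlib
import re

def normalized_md5(ptx_path: str) -> str:
    """md5 of a PTX file after masking mangled symbol names, so two compiles
    that differ only in encoded module paths hash identically."""
    text = open(ptx_path).read()
    # Mask mangled identifiers (e.g. _ZN...E) with a fixed token.
    text = re.sub(r"_Z[A-Za-z0-9_]+", "_SYM_", text)
    return hashlib.md5(text.encode()).hexdigest()

# fi_md5  = normalized_md5("fi_gemm1_swiglu_tile256.ptx")
# trt_md5 = normalized_md5("trt_gemm1_swiglu_tile256.ptx")
# assert fi_md5 == trt_md5, "PTX differs after symbol normalization"
```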
leejnau
added a commit
to leejnau/flashinfer
that referenced
this pull request
Apr 28, 2026
…d-and-skipped

flashinfer-ai#3 is a real coverage gap (at gs=1.0 the scale-conversion code paths run with degenerate values, so a divergent FP8_MAX or scale-convention mismatch between the two sides would silently produce identical output here and divergent output at non-trivial scales). Skipped because:
- Both sides have their own internal scale-plumbing tests with non-trivial scales. TRT-LLM's tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py and flashinfer's tests/moe/test_cute_dsl_fused_moe.py both exercise non-trivial-gs configurations against PyTorch references. The bench wouldn't add coverage their CI doesn't already have.
- A failure here wouldn't tell us what we want to know. The most likely cause of fi-vs-trt parity divergence at production scales would be a scale-convention disagreement BETWEEN the two implementations — not fixable in flashinfer (TRT-LLM is upstream), and dramatically out of scope.
- Risk of getting nerd-sniped on bench-side mistakes. Scale plumbing has many surfaces (alpha, weight_scale_2, input_scale, is_sf_* flags, fp4_quantize signatures); any tiny bench-side mismatch produces a parity failure that looks like a port bug. That's the same failure pattern as flashinfer-ai#4 / flashinfer-ai#6 / flashinfer-ai#8 from earlier in the audit.

Scope-difference note: the audit's load-bearing question is "is the kernel port faithful?" — closed via byte-identical PTX + matching timing + 45 parity cells. flashinfer-ai#3 is about "do flashinfer and TRT-LLM agree on NVFP4 scaling conventions?" — a separate question, real but tangential to port-faithfulness.

One scenario that would re-elevate this: a planned production ship of CuteDslMoEWrapper to a caller using non-trivial scales, where the team wants one independent cross-check (against CuteDslFusedMoE under the same scales) before merging. In that scenario flashinfer-ai#3 is exactly the right pre-ship sanity check. For closing out the investigation audit, it's tangential.

Effective remaining open follow-ups: flashinfer-ai#5 (cutlass-dsl 4.4.2 sanity rerun) and flashinfer-ai#9 (EP=16 tactic-divergence root cause).
leejnau
added a commit
to leejnau/flashinfer
that referenced
this pull request
Apr 28, 2026
…flashinfer-ai#9 remains

Three close-out edits to wrap the CuteDSL MoE FP4 port audit:
- Mark flashinfer-ai#5 (cutlass-dsl 4.4.2 sanity rerun) considered-and-skipped, matching the closure pattern used for flashinfer-ai#2 and flashinfer-ai#3. Three reasons: (a) port-parity claims are unaffected by DSL-compiler version since both sides use the same in-container compiler, (b) flashinfer and TRT-LLM each already have CI testing 4.4.2, (c) install hassle plus unsupported-config risk produces the same ambiguous-failure pattern that cost time on flashinfer-ai#4/flashinfer-ai#6/flashinfer-ai#8. Auto-resolves whenever NGC bumps the image. The install recipe is preserved for future absolute-latency probes.
- Add a "Version-skew caveat (2026-04-28)" subsection to the top-of-file correction. The bench compares flashinfer-with-post-rc5-forward-ports (bb2f88329, 6b8ae6fa8, fae498579) vs TRT-LLM-rc5.post2-without-them, so the +3.5% / +4.1% headlines at EP=1 N=16384 may partly reflect the version asymmetry. Load-bearing claims (port faithfulness via byte-identical source + PTX, 45/45 parity, the flashinfer-ai#3067 fix) are unaffected because they do not depend on absolute deltas. Naturally re-baselines when NGC publishes a 1.3.x.x image absorbing the post-rc5 commits.
- Update the "Open follow-ups remaining" summary: flashinfer-ai#5 added to the considered-and-skipped list alongside flashinfer-ai#2/flashinfer-ai#3, leaving flashinfer-ai#9 (EP=16 tactic-divergence root cause) as the only effective open follow-up.

Audit declared closed 2026-04-28.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
leejnau
added a commit
to leejnau/flashinfer
that referenced
this pull request
Apr 29, 2026
…5.post2 at content level

The "Version-skew caveat (2026-04-28)" section claimed three TRT-LLM commits (`bb2f88329`, `6b8ae6fa8`, `fae498579`) had landed upstream *after* the rc5.post2 cut and were forward-ported into flashinfer ahead of NGC. That claim was factually wrong and is retracted.

Verification done 2026-04-29 (triggered by Lee's question about whether to upgrade NGC to v1.3.0rc13):
- The three commits are dated 2026-01-06 / 2026-01-07 / 2026-01-27 — well before v1.3.0rc5.post2 was tagged on 2026-04-17.
- All three are present in NGC's bundled rc5.post2 image, verified by content inspection: `raster_along_m` from `bb2f88329` appears 14 times in the gather kernel at the rc5.post2 tag; `blk_reduce_bf16` and `blk_reduce_fp32` from `fae498579` are present in `utils.py` at the rc5.post2 tag.

Comprehensive port-completeness verification (now captured in the audit doc as the new *Apples-to-apples confirmation* sub-section):
- flashinfer's port date for the four CuteDSL kernel files is 2026-02-06 (commit `99562e5e`, *"feat: cuteDSL fp4 moe for better DSR1 performance"* (flashinfer-ai#2398)).
- All TRT-LLM commits to those four files from 2025-09-01 through fi's port date are present in fi's initial port (verified by distinctive content markers — fi's `utils.py` contains all 8 functions added by the `21a93fbf9` PDL/memset commit; the gemm kernels contain `raster_along_m`; etc.).
- The only TRT-LLM commit to those four files between fi's port date and v1.3.0rc5.post2 is `e92ee4fe5` (DWDP support, 2026-04-02), intentionally not ported (fi uses its own a2a path; LOW-A in the per-component analysis).
- The only TRT-LLM commit to those four files between v1.3.0rc5.post2 and v1.3.0rc13 is also `e92ee4fe5`. So an NGC upgrade to rc13 only adds DWDP — which fi intentionally omits — and changes nothing fi cares about.
- Compiled-output verification holds: at tile_size=256, fi's gemm1_swiglu PTX is byte-identical to rc5.post2's after symbol normalization (md5 401ebca6...; zero diff lines); per-call latency matches within 0.1%.

Implications for residual Δ% interpretation:
- The EP=1 +3.9% residual is purely wrapper overhead (CuteDslMoEWrapper's Python-layer cost vs trt's C++ thop wrapper). An earlier framing as "version-skew + wrapper-overhead" was retracted: the version-skew premise is false.
- An NGC image bump does NOT auto-resolve the EP=1 residual. The thin-adapter refactor of `CuteDslMoEWrapper` is the only path that closes it.
- The bench is comparing apples-to-apples at the kernel level. GPU kernel work is byte-identical; only the wrapper / orchestrator architecture differs (intentional, divergent_architecture).

Edits in this commit:
- Retitled and rewrote the "Version-skew caveat (2026-04-28)" section as "Apples-to-apples confirmation: fi is at parity with NGC TRT-LLM 1.3.0rc5.post2 at the kernel level (verified 2026-04-29)". The new section explicitly captures the retraction, the content-level verification, and the implications.
- Updated the *Final state* section's EP=1 paragraph: removed the "+ version-skew" framing, added the wrapper-overhead-only attribution with a back-reference to the apples-to-apples sub-section.
- Updated the *Comprehensive 30-cell sweep* section's EP=1 paragraph: same correction.
- Updated the *Fresh-container reproduction matrix* section's EP=1 explanation: same correction.
- Updated the Executive Verdict's "Comprehensive runtime perf" paragraph to drop the "version-skew" attribution.
- Updated the cutlass-dsl 4.4.2 follow-up (flashinfer-ai#5) to drop the cross-reference to the now-retracted version-skew caveat.
- Updated the title's "last updated" tag to note the retraction.
- Removed the duplicated "flashinfer root / TensorRT-LLM root / Post-port cutoff" footer triplet that arose from the rewrite.

Memory file `project_cutedsl_moe_fp4_port_audit_closed.md` updated with the same corrections.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
No description provided.