
make DeepGEMM swapAB available for linear gemm SM90 #2131

Merged
yzh119 merged 9 commits into flashinfer-ai:main from katec846:vchen/dg_swapab_linear
Dec 17, 2025

Conversation

@katec846
Contributor

@katec846 commented Nov 22, 2025

📌 Description

This PR adds W8A8 support on top of PR #2101.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • FP8 block-scale GEMM for SM90 (Hopper): supports BF16/FP8 input combinations, per-token/per-block scales, runtime-dtype dispatch, workspace configuration, and robust input/output validation.
  • Chores
    • Exposed SM90 FP8 GEMM runner in package APIs and added a JIT module generator to build/load the SM90 kernel.
  • Tests
    • Added comprehensive tests for shapes, dtypes, scales, preallocated outputs, correctness thresholds, hardware guards, and error handling.
  • Refactor
    • DeepGemm enablement now evaluated at runtime for dynamic behavior.


@coderabbitai
Contributor

coderabbitai Bot commented Nov 22, 2025

Walkthrough

Adds an SM90-optimized FP8 block-scale GEMM path: a CUDA TVM-FFI binding and runner with runtime dtype dispatch and workspace management, Python JIT spec and high-level API wiring, unit tests, and a runtime DeepGemm enablement accessor.

Changes

  • CUDA Binding & Runner (csrc/fp8_blockscale_gemm_sm90_binding.cu): New Fp8BlockScaleGemmRunner TVM-FFI module: constructs three Cutlass runners (BF16-BF16-BF16, BF16-FP8-BF16, FP8-FP8-BF16), exposes run_gemm, get_workspace_size, configure_workspace, performs pointer/shape/scale validation, runtime dtype dispatch, workspace sizing/management, and exports init() via FFI.
  • Cutlass Kernel Runtime Flag (csrc/.../fp8_blockscale_gemm_kernel.cuh): Adds getDeepGemmEnabled() and replaces most static flag uses with runtime calls so DeepGemm enablement is evaluated at runtime while retaining the static variable.
  • Python GEMM High-Level API (flashinfer/gemm/gemm_base.py): Adds get_fp8_blockscale_gemm_runner_sm90() loader and fp8_blockscale_gemm_sm90(...) API: SM90 gating, K%128 check, dtype/scale/shape validation, FP8 quantization/prep, workspace allocation/configuration, and invocation of the SM90 runner. (Note: duplicated entrypoints present in the patch.)
  • Python Package Exports (flashinfer/gemm/__init__.py): Exposes fp8_blockscale_gemm_sm90 in __all__.
  • JIT Module Generation (flashinfer/jit/gemm/fp8_blockscale.py, flashinfer/jit/gemm/__init__.py): Adds gen_fp8_blockscale_gemm_sm90_module(use_fast_build=False) JIT spec: source list (including fp8_blockscale_gemm.cu and binding), SM90/TMA/BF16/FP8 nvcc flags (cond. enable block-scale FP8 for CUDA>=12.8), include/linker flags, and exports the generator.
  • Tests (tests/gemm/test_fp8_blockscale_gemm.py): New comprehensive tests: BF16/FP8 combinations (including W8A8), per-token/per-block scales, shapes, error cases, pre-allocated output support, cosine-similarity checks, and SM90 capability gating.
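
For orientation, a minimal sketch of how the new JIT pieces are exercised, based on the table above and the test fixture quoted later in this thread (the flashinfer.jit.gemm import path is assumed from the package-export entries):

import flashinfer.jit
from flashinfer.jit.gemm import gen_fp8_blockscale_gemm_sm90_module

# Build the SM90 FP8 block-scale GEMM module ahead of time; otherwise it is
# JIT-compiled lazily on the first call into the high-level API.
spec = gen_fp8_blockscale_gemm_sm90_module()
flashinfer.jit.build_jit_specs([spec], verbose=False)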

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant User as User/App
    participant PyAPI as Python API\n(fp8_blockscale_gemm_sm90)
    participant JIT as JIT Module Loader
    participant Runner as SM90 Runner\n(Fp8BlockScaleGemmRunner)
    participant Kernel as Cutlass Kernel

    User->>PyAPI: call fp8_blockscale_gemm_sm90(input, weight, scales...)
    PyAPI->>PyAPI: validate arch, dtypes, K%128, shapes, scales
    PyAPI->>JIT: gen_fp8_blockscale_gemm_sm90_module() -> Module (if needed)
    PyAPI->>Runner: configure_workspace(workspace_buffer)
    PyAPI->>Runner: run_gemm(input_ptr, weight_ptr, output_ptr, scale_ptrs, shapes)
    Runner->>Runner: selectRunner(input_is_fp8, weight_is_fp8)
    Runner->>Kernel: launch selected Cutlass GEMM (workspace, scales)
    Kernel-->>Runner: kernel completes
    Runner-->>PyAPI: status / result
    PyAPI-->>User: return output tensor
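
For concreteness, a minimal usage sketch of the flow above (BF16 x BF16 path with internal quantization; shapes and the K % 128 constraint follow the validation described in this PR, so treat this as an illustration rather than the PR's own example):

import torch
from flashinfer.gemm import fp8_blockscale_gemm_sm90

M, N, K = 128, 4096, 4096  # K must be a multiple of 128
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # activations (M, K)
b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")  # weights (N, K), K along the last dim

# BF16 inputs take the internal-quantization path, so no scales are passed;
# FP8 inputs would instead require float32 scales of shape (M, K//128) / (N, K//128).
out = fp8_blockscale_gemm_sm90(a, b)  # (M, N) BF16 output, SM90/Hopper only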

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Review focus:
    • csrc/fp8_blockscale_gemm_sm90_binding.cu — runtime selection, pointer/shape/scale validation, workspace alignment/lifetime, FFI signatures.
    • flashinfer/gemm/gemm_base.py — quantization/scale handling and duplicated API definitions.
    • csrc/.../fp8_blockscale_gemm_kernel.cuh — runtime getDeepGemmEnabled semantics and side effects.
    • tests/gemm/test_fp8_blockscale_gemm.py — hardware gating and test coverage correctness.

Possibly related PRs

Suggested reviewers

  • djmmoss
  • yongwww
  • cyx-6
  • wenscarl
  • bkryu
  • nvmbreughe
  • jimmyzho
  • jiahanc

Poem

🐇 I hopped through kernels, scales in paw,
Matrices aligned beneath my claw,
Runners chose the fastest track,
Workspaces set — no looking back,
SM90 bells wink — hop, hop, hurrah!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 50.00%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Title check: ❓ Inconclusive. The PR title 'make DeepGEMM swapAB available for linear gemm SM90' is specific and directly related to the main objective described in the PR objectives, but the actual changeset primarily implements FP8 block-scale GEMM functionality (W8A8 support) with supporting infrastructure, which is broader than just making DeepGEMM swapAB available. Resolution: clarify whether the title should emphasize 'Add FP8 block-scale GEMM (W8A8) support for SM90' or confirm that DeepGEMM swapAB is the primary focus of this changeset.

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 72c4f7bf4e54d69bf22b488cfc2683547cc9e66d and c2e0540.

📒 Files selected for processing (3)
  • csrc/fp8_blockscale_gemm_sm90_binding.cu (1 hunks)
  • flashinfer/gemm/gemm_base.py (2 hunks)
  • tests/gemm/test_fp8_blockscale_gemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/fp8_blockscale_gemm_sm90_binding.cu
  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (3)
csrc/fp8_blockscale_gemm_sm90_binding.cu (1)
csrc/tvm_ffi_utils.h (2)
  • encode_dlpack_dtype (30-32)
  • get_stream (294-296)
flashinfer/gemm/gemm_base.py (2)
flashinfer/jit/gemm/fp8_blockscale.py (1)
  • gen_fp8_blockscale_gemm_sm90_module (10-56)
csrc/fp8_blockscale_gemm_sm90_binding.cu (10)
  • init (235-238)
  • init (235-235)
  • input (98-176)
  • input (98-99)
  • input (178-198)
  • input (178-179)
  • input_is_fp8 (84-96)
  • input_is_fp8 (84-85)
  • workspace (213-220)
  • workspace (213-213)
tests/gemm/test_fp8_blockscale_gemm.py (5)
flashinfer/gemm/gemm_base.py (1)
  • fp8_blockscale_gemm_sm90 (3307-3521)
flashinfer/testing/utils.py (1)
  • per_token_cast_to_fp8 (39-46)
flashinfer/utils.py (2)
  • get_compute_capability (258-261)
  • is_sm90a_supported (531-533)
flashinfer/jit/gemm/fp8_blockscale.py (1)
  • gen_fp8_blockscale_gemm_sm90_module (10-56)
flashinfer/jit/core.py (1)
  • build_jit_specs (395-417)
🪛 Ruff (0.14.8)
flashinfer/gemm/gemm_base.py

3391-3393, 3397, 3399, 3405-3407, 3412-3414, 3417, 3427-3429, 3434, 3436, 3438-3441, 3444-3447, 3449-3452, 3456, 3463-3467, 3469, 3472-3475, 3477-3480, 3485-3487, 3489-3491, 3493-3495, 3497-3499, 3505-3507: Avoid specifying long messages outside the exception class (TRY003)

tests/gemm/test_fp8_blockscale_gemm.py

1-1: The file is executable but no shebang is present (EXE002)
326-326: Pattern passed to match= contains metacharacters but is neither escaped nor raw (RUF043)
355-355: Pattern passed to match= contains metacharacters but is neither escaped nor raw (RUF043)



@jhaotingc force-pushed the vchen/dg_swapab_linear branch from 8086118 to 9bbf63f on December 9, 2025 19:49
@katec846 marked this pull request as ready for review on December 10, 2025 17:16
Contributor

@coderabbitai Bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (6)
tests/gemm/test_fp8_blockscale_gemm.py (5)

1-15: Minor: Copyright year is 2024 for new 2025 code.

The copyright header shows 2024, but this is new code being added in 2025. Consider updating to 2025.


32-45: Unused utility function calc_diff.

This function is defined but never called in any of the tests - they all use F.cosine_similarity instead. Consider either using this function for consistency with TRT-LLM's metric (as documented in the docstring), or removing it to avoid dead code.


234-261: Consider vectorizing dequantization and removing print statement.

  1. The nested Python loops (lines 235-248) could be vectorized using tensor operations for better readability and performance:
# Vectorized dequantization
input_dequant = input_fp8.to(torch.bfloat16).view(m, k // 128, 128) * input_scale.unsqueeze(-1)
input_dequant = input_dequant.view(m, k)
  2. The print statement at line 261 will add noise to test output. Consider using pytest's logging or removing it.

264-306: LGTM!

Good test coverage for per-block weight scales. Minor suggestion: consider removing the print statement at line 305 for cleaner test output.


366-396: Use raw strings for regex patterns in pytest.raises(match=...).

The patterns contain regex metacharacters (. and *). While they work as intended, using raw strings (r"...") makes the intent clearer and follows Python best practices for regex patterns.

-    with pytest.raises(ValueError, match="FP8.*or BF16"):
+    with pytest.raises(ValueError, match=r"FP8.*or BF16"):
         fp8_blockscale_gemm_sm90(input, weight)
-    with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"):
+    with pytest.raises(ValueError, match=r"FP8 input.*BF16 weight.*not supported"):
         fp8_blockscale_gemm_sm90(input_fp8, weight, input_scale, None)
flashinfer/gemm/gemm_base.py (1)

3381-3385: The "90a" check is ineffective.

The _match_sm_version function constructs device_arch as f"{major * 10 + minor}", producing "90" for SM90/SM90a devices. The "90a" string in the list will never match since the 'a' suffix doesn't affect the reported compute capability.

-    if not _match_sm_version(input.device, ["90", "90a"]):
+    if not _match_sm_version(input.device, ["90"]):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6bb01d1 and 710d0a7f91ab3f386e7c4b00f4256d67d2a50ac8.

📒 Files selected for processing (6)
  • csrc/fp8_blockscale_gemm_sm90_binding.cu (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (2 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/fp8_blockscale.py (1 hunks)
  • tests/gemm/test_fp8_blockscale_gemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/fp8_blockscale_gemm_sm90_binding.cu
🧬 Code graph analysis (4)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/gemm_base.py (1)
  • fp8_blockscale_gemm_sm90 (3300-3542)
flashinfer/jit/gemm/fp8_blockscale.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/jit/cpp_ext.py (1)
  • is_cuda_version_at_least (86-87)
csrc/fp8_blockscale_gemm_sm90_binding.cu (2)
flashinfer/comm/cuda_ipc.py (1)
  • Function (37-40)
csrc/tvm_ffi_utils.h (2)
  • encode_dlpack_dtype (30-32)
  • get_stream (277-279)
flashinfer/gemm/gemm_base.py (2)
flashinfer/jit/gemm/fp8_blockscale.py (1)
  • gen_fp8_blockscale_gemm_sm90_module (10-54)
csrc/fp8_blockscale_gemm_sm90_binding.cu (8)
  • init (211-214)
  • init (211-211)
  • input (92-174)
  • input (92-93)
  • input_is_fp8 (78-90)
  • input_is_fp8 (78-79)
  • workspace (189-196)
  • workspace (189-189)
🪛 Ruff (0.14.8)
tests/gemm/test_fp8_blockscale_gemm.py

1-1: The file is executable but no shebang is present (EXE002)
366-366: Pattern passed to match= contains metacharacters but is neither escaped nor raw (RUF043)
395-395: Pattern passed to match= contains metacharacters but is neither escaped nor raw (RUF043)

flashinfer/jit/gemm/fp8_blockscale.py

1-1: The file is executable but no shebang is present (EXE002)
12-17: Consider iterable unpacking instead of concatenation; replace with iterable unpacking (RUF005)

flashinfer/gemm/gemm_base.py

3383-3385, 3389, 3391, 3397-3399, 3404-3406, 3416-3418, 3423, 3428-3431, 3433, 3435-3438, 3441-3444, 3446-3449, 3453, 3460-3464, 3466, 3469-3472, 3474-3477, 3482-3484, 3486-3488, 3490-3492, 3498-3500, 3533-3536: Avoid specifying long messages outside the exception class (TRY003)

🔇 Additional comments (19)
csrc/fp8_blockscale_gemm_sm90_binding.cu (7)

1-26: LGTM!

The includes and is_fp8_e4m3fn helper are well-structured. The conditional compilation with FLASHINFER_ENABLE_FP8_E4M3 gracefully handles the case when FP8 support is disabled.


37-72: LGTM!

The class structure is well-designed with clear separation of concerns. The GetFunction dispatch correctly exposes the three required interfaces, and the documentation comment clearly lists the supported dtype combinations.


74-90: LGTM!

The runtime dtype dispatch correctly handles all four possible combinations with clear comments explaining why FP8 input + BF16 weight is unsupported.


113-140: LGTM!

The scale validation logic is thorough and handles both per-token and per-block scale formats correctly. The error messages are descriptive with expected vs actual shapes, which aids debugging.


156-173: LGTM!

The dual-path dispatch correctly handles W8A8 vs internal quantization paths with appropriate type casts. The comments explaining the two TRT-LLM method signatures are helpful for maintainability.


176-187: Consider wrapping workspace size queries in try-catch.

Based on learnings from this codebase, getWorkspaceSizeImpl calls can legitimately fail when probing configurations due to SMEM constraints. If getWorkspaceSizeBase can throw std::runtime_error, these calls should be wrapped in try-catch blocks.

Verify if getWorkspaceSizeBase can throw and consider:

   int64_t getWorkspaceSize(int64_t shape_m, int64_t shape_n, int64_t shape_k) {
     size_t max_size = 0;

-    max_size =
-        std::max(max_size, runner_bf16_bf16_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
-    max_size =
-        std::max(max_size, runner_bf16_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
-    max_size =
-        std::max(max_size, runner_fp8_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+    // Swallow errors when SMEM exceeds maximum allowed
+    try {
+      max_size = std::max(max_size, runner_bf16_bf16_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+    } catch (std::runtime_error&) {}
+    try {
+      max_size = std::max(max_size, runner_bf16_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+    } catch (std::runtime_error&) {}
+    try {
+      max_size = std::max(max_size, runner_fp8_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+    } catch (std::runtime_error&) {}

     return max_size;
   }

198-216: LGTM!

The member variables use unique_ptr for proper RAII memory management, and the module initialization correctly follows TVM FFI patterns.

flashinfer/jit/gemm/__init__.py (1)

30-44: LGTM!

The new import and __all__ entry follow the existing patterns in this file, correctly exposing the SM90 FP8 block-scale GEMM module generator.

flashinfer/gemm/__init__.py (1)

16-34: LGTM!

The new public API export follows the existing patterns consistently, correctly exposing fp8_blockscale_gemm_sm90 at the package level.

flashinfer/jit/gemm/fp8_blockscale.py (1)

19-54: LGTM!

The JIT spec is well-structured with comprehensive source files, include paths for TRT-LLM integration, and appropriate linker flags for CUDA runtime compilation support.

tests/gemm/test_fp8_blockscale_gemm.py (5)

48-57: LGTM!

The JIT warmup fixture correctly handles conditional compilation based on hardware support and JIT cache availability, ensuring tests have a stable GPU execution environment.


60-104: LGTM!

Well-structured parametrized test with appropriate hardware checks, clear assertions, and reasonable similarity thresholds for the BF16 internal quantization path.


106-189: LGTM with minor note.

Good test coverage for the supported dtype combinations. The TODO at line 182 about checking the threshold is noted - consider addressing or tracking this before the feature is finalized.


308-343: LGTM!

Good coverage of common LLM inference shapes with appropriate correctness validation.


399-430: LGTM!

Good test coverage for pre-allocated output buffer functionality, correctly verifying both buffer identity and output correctness.

flashinfer/gemm/gemm_base.py (4)

59-59: LGTM!

Import statement is correctly placed with other JIT gemm imports.


3293-3296: LGTM!

The cached runner initialization follows the established pattern in this file and correctly chains the JIT spec build, load, and init calls.


3513-3538: LGTM!

The input scale transformation correctly handles the layout conversion from user-provided (M, K//128) format to TRT-LLM's expected (K//128, M) format with proper padding and stride verification. The stride check at line 3532-3536 ensures memory layout compatibility.
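
To make that layout concrete, a small sketch of the conversion this comment describes (the M-padding value here is a placeholder assumption; the exact padding rule lives in gemm_base.py):

import torch

M, K, BLOCK_SIZE = 3, 256, 128
M_padded = 4  # placeholder padding of M, not the exact rule used by the implementation
user_scales = torch.rand(M, K // BLOCK_SIZE, dtype=torch.float32)  # (M, K//128) as provided by callers

kernel_scales = torch.zeros(K // BLOCK_SIZE, M_padded, dtype=torch.float32)
kernel_scales[:, :M] = user_scales.t()  # (K//128, M) layout expected by the TRT-LLM kernel
assert kernel_scales.stride() == (M_padded, 1)  # leading dimension M_padded, matching the stride check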


3503-3542: LGTM!

The runner invocation correctly follows the pattern: workspace allocation, configuration, and GEMM execution. The argument order matches the C++ runGemm function signature.

Collapsed comment threads:
  • csrc/fp8_blockscale_gemm_sm90_binding.cu (outdated)
  • flashinfer/gemm/gemm_base.py
  • flashinfer/gemm/gemm_base.py (outdated)
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/fp8_blockscale.py
@jhaotingc requested a review from yongwww as a code owner on December 11, 2025 04:50
@jhaotingc force-pushed the vchen/dg_swapab_linear branch from 794812c to d4d19f7 on December 12, 2025 03:52
Contributor

@coderabbitai Bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (3)
flashinfer/gemm/gemm_base.py (2)

59-60: Cache loader looks fine; include-dir wiring is reasonable.
One nit: _match_sm_version can’t ever match "90a" (it only builds "90", "100", etc.), so the later arch check effectively ignores the "90a" string. Consider switching to an explicit is_sm90a_supported(...) check (consistent with tests) or just checking "90".

Also applies to: 3293-3304
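
A two-line illustration of the nit, assuming _match_sm_version builds its arch string as f"{major * 10 + minor}" (as noted in the earlier review pass):

major, minor = 9, 0  # torch.cuda.get_device_capability() on SM90/SM90a hardware
device_arch = f"{major * 10 + minor}"  # -> "90"; no "a" suffix is ever produced
assert device_arch in ["90", "90a"]  # passes, but only via "90"; the "90a" entry is dead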


3525-3564: Clarify/validate the scale buffer layout and simplify the size math.
input_scale_size is computed via a non-obvious expression (Line 3529). If the intended layout is a flat (K_blocks * M_padded) buffer, this should be expressed directly (and ideally commented as “(K_blocks, M_padded) with ld=M_padded”). Also consider validating weight_scale.device == input.device (you validate this for input_scale but not for weight_scale).

@@
-        K_blocks = (K + BLOCK_SIZE - 1) // BLOCK_SIZE
-        input_scale_size = ((K * M_padded * 4 + BLOCK_SIZE - 1) // BLOCK_SIZE) // 4
+        K_blocks = K // BLOCK_SIZE
+        # Flat buffer representing (K_blocks, M_padded) with leading-dim M_padded.
+        input_scale_size = K_blocks * M_padded
@@
     if weight_is_fp8:
@@
         if weight_scale.dtype != torch.float32:
             raise ValueError(f"weight_scale must be float32, got {weight_scale.dtype}")
+        if weight_scale.device != weight.device:
+            raise ValueError(
+                f"weight_scale device mismatch. Expected {weight.device}, got {weight_scale.device}"
+            )
csrc/fp8_blockscale_gemm_sm90_binding.cu (1)

206-217: Workspace-size probing should be resilient to per-runner shape failures.
If getWorkspaceSizeBase(...) can throw for certain shapes/configs (e.g., SMEM constraints), consider catching and swallowing std::runtime_error per the established pattern in this repo (“Swallow errors when SMEM exceeds maximum allowed”). Based on learnings, this is considered acceptable for probing multiple candidate kernels.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 65dd3fdbe65c0416b058e525b66bd66930ea76d0 and d4d19f727815db03af3ab5273d249ee3a0b42582.

📒 Files selected for processing (3)
  • csrc/fp8_blockscale_gemm_sm90_binding.cu (1 hunks)
  • flashinfer/gemm/gemm_base.py (2 hunks)
  • tests/gemm/test_fp8_blockscale_gemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/gemm/gemm_base.py
  • csrc/fp8_blockscale_gemm_sm90_binding.cu
🪛 Ruff (0.14.8)
flashinfer/gemm/gemm_base.py

3391-3393, 3397, 3399, 3405-3407, 3412-3414, 3424-3426, 3431, 3436-3439, 3441, 3443-3446, 3449-3452, 3454-3457, 3461, 3468-3472, 3474, 3477-3480, 3482-3485, 3490-3492, 3494-3496, 3498-3500, 3502-3504, 3510-3512, 3556-3559: Avoid specifying long messages outside the exception class (TRY003)

tests/gemm/test_fp8_blockscale_gemm.py

1-1: The file is executable but no shebang is present (EXE002)
425-425: Pattern passed to match= contains metacharacters but is neither escaped nor raw (RUF043)
454-454: Pattern passed to match= contains metacharacters but is neither escaped nor raw (RUF043)

🔇 Additional comments (1)
flashinfer/gemm/gemm_base.py (1)

3515-3524: Unable to locate the specified file (flashinfer/gemm/gemm_base.py) in the current repository state. The review comment cannot be fully verified without access to:

  • The Python file at the specified location (lines 3515-3524)
  • The C++ GEMM runner implementation (get_fp8_blockscale_gemm_runner_sm90)
  • The workspace configuration logic in C++

Critical discrepancy: The learning context indicates that in flashinfer GEMM implementations, it's acceptable to catch and silently ignore runtime_error exceptions when probing configurations due to SMEM constraints—suggesting that zero-workspace or failed-configuration scenarios may legitimately occur and be handled in the C++ layer rather than requiring Python-side guards like max(int(workspace_size), 1).

The review's core assumption (that unconditional configure_workspace calls with a minimum 1-byte allocation are necessary) conflicts with this documented pattern and requires verification against the actual C++ implementation behavior.

Comment on lines +1 to +13

#include <tvm/ffi/extra/module.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h"
#include "tvm_ffi_utils.h"

Contributor


⚠️ Potential issue | 🔴 Critical

Missing required standard headers (may fail to compile depending on transitive includes).
You use std::unique_ptr/std::make_unique (Lines 42-49, 228-237), std::max (Lines 209-215), and std::numeric_limits (Lines 192-193) without including <memory>, <algorithm>, <limits>. Please add them explicitly.

@@
 #include <cstddef>
 #include <cstdint>
+#include <algorithm>
 #include <functional>
+#include <limits>
+#include <memory>
 #include <type_traits>
 #include <vector>

Also applies to: 38-50, 206-217

🤖 Prompt for AI Agents
In csrc/fp8_blockscale_gemm_sm90_binding.cu around lines 1 to 13, the file uses
std::unique_ptr/std::make_unique, std::max, and std::numeric_limits in later
code (refs: ~lines 42-49, 192-215) but does not explicitly include the required
standard headers; add #include <memory> for unique_ptr/make_unique, #include
<algorithm> for std::max, and #include <limits> for std::numeric_limits at the
top of this file (alongside the existing includes) so compilation does not rely
on transitive includes.

Comment on lines +98 to +182
  void runGemm(const TensorView& input, const TensorView& weight, const TensorView& output,
               const Optional<TensorView>& scales_a, const Optional<TensorView>& scales_b) {
    auto stream = get_stream(input.device());

    auto input_ptr = input.data_ptr();
    auto weight_ptr = weight.data_ptr();
    auto output_ptr = output.data_ptr();

    int64_t shape_m = input.size(0);
    int64_t shape_k = input.size(1);
    int64_t shape_n = weight.size(0);

    TVM_FFI_ICHECK(input_ptr != nullptr) << "input is null";
    TVM_FFI_ICHECK(weight_ptr != nullptr) << "weight is null";
    TVM_FFI_ICHECK(output_ptr != nullptr) << "output is null";
    TVM_FFI_ICHECK(shape_k == weight.size(1)) << "K dimension mismatch";
    TVM_FFI_ICHECK(shape_k % 16 == 0) << "N must be a multiple of 16, (K=" << shape_k << ")";
    TVM_FFI_ICHECK(shape_n % 16 == 0) << "N must be a multiple of 16, (N=" << shape_n << ")";

    // Determine dtypes for runner selection
    bool input_is_fp8 = is_fp8_e4m3fn(input.dtype());
    bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype());

    // Validate scale requirements
    if (input_is_fp8) {
      TVM_FFI_ICHECK(scales_a.has_value() && scales_a.value().data_ptr() != nullptr)
          << "scales_a is required for FP8 input";
      // TensorRT-LLM expects scale shape: (K/128, M) after transpose
      // int64_t expected_scale_k = (shape_k + 127) / 128;
      // TVM_FFI_ICHECK(scales_a.value().size(0) == expected_scale_k &&
      //                scales_a.value().size(1) == shape_m)
      //     << "scales_a shape mismatch: expected (" << expected_scale_k << ", " << shape_m
      //     << "), got (" << scales_a.value().size(0) << ", " << scales_a.value().size(1) << ")";
    }

    if (weight_is_fp8) {
      TVM_FFI_ICHECK(scales_b.has_value() && scales_b.value().data_ptr() != nullptr)
          << "scales_b is required for FP8 weight";
      // Validate scale shape: should be (N, K/128) for per-token or (N/128, K/128) for per-block
      int64_t expected_scale_k = (shape_k + 127) / 128;
      int64_t scale_dim0 = scales_b.value().size(0);
      int64_t scale_dim1 = scales_b.value().size(1);

      bool is_per_token = (scale_dim0 == shape_n && scale_dim1 == expected_scale_k);
      bool is_per_block = (scale_dim0 == (shape_n + 127) / 128 && scale_dim1 == expected_scale_k);

      TVM_FFI_ICHECK(is_per_token || is_per_block)
          << "scales_b shape mismatch: expected (" << shape_n << ", " << expected_scale_k
          << ") for per-token or (" << ((shape_n + 127) / 128) << ", " << expected_scale_k
          << ") for per-block, got (" << scale_dim0 << ", " << scale_dim1 << ")";
    }

    // Extract scale pointers
    float const* scales_a_ptr = scales_a.has_value()
                                    ? reinterpret_cast<float const*>(scales_a.value().data_ptr())
                                    : nullptr;
    float const* scales_b_ptr = scales_b.has_value()
                                    ? reinterpret_cast<float const*>(scales_b.value().data_ptr())
                                    : nullptr;

    // Select appropriate runner
    auto* runner = selectRunner(input_is_fp8, weight_is_fp8);
    TVM_FFI_ICHECK(runner != nullptr) << "Unsupported dtype combination";
    TVM_FFI_ICHECK(workspace_ != nullptr)
        << "Workspace not configured. Call configure_workspace first.";

    // TensorRT-LLM has two gemm() methods:
    // 1. gemm(void*, ...) - for internal quantization (BF16 inputs)
    // 2. gemm(__nv_fp8_e4m3*, int, __nv_fp8_e4m3*, int, ...) - for pre-quantized FP8 inputs
    if (input_is_fp8 && weight_is_fp8) {
      // W8A8: Use the pre-quantized FP8 path
      auto* fp8_input = reinterpret_cast<__nv_fp8_e4m3*>(input_ptr);
      auto* fp8_weight = reinterpret_cast<__nv_fp8_e4m3*>(weight_ptr);
      auto* bf16_output = reinterpret_cast<__nv_bfloat16*>(output_ptr);

      runner->gemm(fp8_input, shape_k,    // input with leading dimension
                   fp8_weight, shape_k,   // weight with leading dimension
                   bf16_output, shape_n,  // output with leading dimension
                   shape_m, shape_n, shape_k, scales_a_ptr, scales_b_ptr, stream);
    } else {
      // BF16+BF16 or BF16+FP8: Use internal quantization path
      runner->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, stream,
                   scales_a_ptr, scales_b_ptr);
    }
  }
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find and read the actual file
find . -name "fp8_blockscale_gemm_sm90_binding.cu" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 110


🏁 Script executed:

#!/bin/bash
# Check the file structure
fd "fp8_blockscale_gemm_sm90_binding.cu" --type f

Repository: flashinfer-ai/flashinfer

Length of output: 108


🏁 Script executed:

#!/bin/bash
# Search for the file in common locations
find . -path "*/csrc/*" -name "*fp8*blockscale*" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 624


🏁 Script executed:

#!/bin/bash
# Broader search
git ls-files | grep -i "fp8.*blockscale\|blockscale.*binding"

Repository: flashinfer-ai/flashinfer

Length of output: 913


🏁 Script executed:

#!/bin/bash
# List all files matching the pattern
git ls-files "**/*fp8_blockscale*" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 913


🏁 Script executed:

#!/bin/bash
# Find any references to this filename
rg "fp8_blockscale_gemm_sm90_binding" --type cu

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Find all .cu files in csrc
find ./csrc -name "*.cu" -type f 2>/dev/null | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 667


🏁 Script executed:

#!/bin/bash
# Search for blockscale in cu files
rg "blockscale" --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Check if csrc directory exists
ls -la csrc/ 2>/dev/null | head -20 || echo "csrc not found in root"
# Try alternative structure
find . -type d -name csrc 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 1515


🏁 Script executed:

#!/bin/bash
# Search for similar error message patterns and K%16 checks
rg "N must be a multiple of 16" --type cu

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Search for shape_k % 16 checks
rg "shape_k.*%.*16" --type cu -B2 -A2

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Search for output dtype validation patterns
rg "output.*dtype\|dtype.*output" --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Search for bfloat16 or BF16 checks
rg "bfloat16|bf16" --type cu -i | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for TVM FFI utilities for dtype checking
find . -name "tvm_ffi_utils.h" -o -name "*tvm*utils*.h" | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 90


🏁 Script executed:

#!/bin/bash
# Find other GEMM binding files
find . -type f -name "*gemm*binding*" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 311


🏁 Script executed:

#!/bin/bash
# Search in the referenced header file from learnings
cat include/flashinfer/gemm/bf16_gemm_cutlass_template.h 2>/dev/null | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Search all files that validate output dtype
rg "encode_dlpack_dtype|dtype.*code\|TVM_FFI_ICHECK.*dtype" --type cu -A2 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for how dtype is validated in binding files
rg "dtype\(\)" --type cu -B2 -A2 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Find files using TVM_FFI_ICHECK
rg "TVM_FFI_ICHECK" --type cu | wc -l

Repository: flashinfer-ai/flashinfer

Length of output: 113


🏁 Script executed:

#!/bin/bash
# Get actual usage of TVM_FFI_ICHECK with dtype
rg "TVM_FFI_ICHECK.*dtype" --type cu -B1 -A1 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for blockscale references in Python files
rg "blockscale" --type py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 2030


🏁 Script executed:

#!/bin/bash
# Search for any fp8_blockscale references across all file types
rg "fp8_blockscale" -i | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 3152


🏁 Script executed:

#!/bin/bash
# Show top-level directory structure
ls -la | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 1257


🏁 Script executed:

#!/bin/bash
# Find all .cu files
find . -name "*.cu" -type f 2>/dev/null | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 970


🏁 Script executed:

#!/bin/bash
# Check git ls-files for cu files
git ls-files "*.cu" | grep -E "fp8|blockscale" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 782


🏁 Script executed:

#!/bin/bash
# All cu files in the repo
git ls-files "*.cu" | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 910


🏁 Script executed:

#!/bin/bash
# Look for any mentions of shape_k or shape_n dimension checks
rg "shape_k.*shape_n|K.*multiple.*16|N.*multiple.*16" --type cu -B1 -A1 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Search for files with "binding" in the name
find . -name "*binding*" -type f 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 1184


🏁 Script executed:

#!/bin/bash
# Get complete cu file list
git ls-files | grep "\.cu$"

Repository: flashinfer-ai/flashinfer

Length of output: 6582


🏁 Script executed:

#!/bin/bash
# Search for files containing the actual code from the review
rg "N must be a multiple of 16"

Repository: flashinfer-ai/flashinfer

Length of output: 337


🏁 Script executed:

#!/bin/bash
# Search for shape_k % 16 with context
rg "shape_k.*%.*16" -A1 -B1

Repository: flashinfer-ai/flashinfer

Length of output: 451


🏁 Script executed:

#!/bin/bash
# Search for runner templates and output types
rg "runner.*gemm|W8A8.*output" --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for BF16 output validation
rg "__nv_bfloat16|output.*bf16" -i --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Verify the logic: check which dimension is checked with %16
cat > /tmp/verify_logic.py << 'EOF'
# The review states:
# Line 114 checks: shape_k % 16 == 0
# But message says: "N must be a multiple of 16, (K=" << shape_k << ")"
#
# This is indeed wrong - should be "K must be..."
# The check is on shape_k (the K dimension in matrix multiply)
# But the error message says "N" 

print("Issue 1: Error message mismatch")
print("Code checks: shape_k % 16 == 0")
print("Message says: 'N must be a multiple of 16'")
print("Should say: 'K must be a multiple of 16'")
print("Status: CONFIRMED - this is a clear bug")
print()
print("Issue 2: Output BF16 validation")
print("Review claims runners are BF16-output only")
print("Code shows: auto* bf16_output = reinterpret_cast<__nv_bfloat16*>(output_ptr)")
print("This suggests BF16 output is assumed")
print("But no validation is done on output.dtype()")
print("If caller passes FP16 or other type, memory corruption could occur")
print("Status: NEEDS VERIFICATION - is this actually a safety issue?")
EOF
python /tmp/verify_logic.py

Repository: flashinfer-ai/flashinfer

Length of output: 154


🏁 Script executed:

#!/bin/bash
# Search for how other bindings handle output dtype validation
rg "output.*dtype\(\)" --type cu -B3 -A3 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for available dtype checking functions/macros
rg "encode_dlpack_dtype|is_fp8_e4m3fn|dtype.*code" -B1 -A1 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 4074


🏁 Script executed:

#!/bin/bash
# Check what dtype checking utilities exist
rg "bool.*dtype|is_.*dtype" --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Search for dtype validation patterns in bindings
rg "TVM_FFI_ICHECK.*dtype\|is_fp8_e4m3fn" --type cu -B2 -A2 | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look at is_fp8_e4m3fn usage
rg "is_fp8_e4m3fn" -B1 -A1

Repository: flashinfer-ai/flashinfer

Length of output: 847


🏁 Script executed:

#!/bin/bash
# Search for bfloat16_code or similar dtype codes
rg "bfloat16_code|e4m3fn_code|dtype.*code" -B1 -A1 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 3614


🏁 Script executed:

#!/bin/bash
# Search for how dtype is compared
rg "\.dtype()\|input\.dtype\(\)|weight\.dtype\(\)" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 236


🏁 Script executed:

#!/bin/bash
# Search for bfloat16 type checking
rg "is_bfloat16|bfloat16|bf16" -i --type cu | grep -E "is_|check|valid" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for how types are compared in existing code
rg "dtype().*==" --type cu -B1 -A1 | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Check all cu files for dtype comparison patterns
rg "encode_dlpack_dtype|\.code\(\)" --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for dtype constants or enums
rg "kDLFloat16|kDLBfloat" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 153


🏁 Script executed:

#!/bin/bash
# Get a comprehensive view of all dtype-related validations
rg "shape_[a-z] % 16\|N must be\|K must be" -B2 -A2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Verify the logic of the error message issue
python3 << 'EOF'
# Issue 1: Error Message on Line 114
# The code checks: TVM_FFI_ICHECK(shape_k % 16 == 0) << "N must be..."
# But it's checking shape_k (K dimension)
# Message says "N must be a multiple of 16" 
# This is clearly wrong - message should say "K must be..."

# This is a straightforward factual error in the message

# Issue 2: Output dtype validation
# The code assumes output is BF16 (reinterpret_cast<__nv_bfloat16*>)
# But there's no validation that output.dtype() is actually BF16
# If someone passes FP16 or FP32, undefined behavior/memory corruption

# Both issues appear valid in principle
# Issue 1 is definitely a bug (wrong error message)
# Issue 2 is a type safety concern (missing validation)

print("Issue 1 (Error message): VALID BUG")
print("Issue 2 (Type validation): VALID CONCERN if runners are truly BF16-only")
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 154


🏁 Script executed:

#!/bin/bash
# Check for dtype validation in similar files
rg "is_fp8_e4m3fn" --type cu -B5 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for pattern: how to check if output is BF16
rg "bfloat16|bf16" --type cu -i -B2 -A2 | grep -E "TVM_FFI_ICHECK|if.*dtype|is_" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Search for any dtype comparison that might give us a pattern
rg "dtype\(\).*==" --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Check for type_code or similar
rg "type_code|code\(\)" --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 148


🏁 Script executed:

#!/bin/bash
# Find function definitions for dtype checking
rg "^bool is_|^bool.*dtype" --type cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for is_fp8_e4m3fn definition or usage context
rg "is_fp8_e4m3fn" -B3 -A3 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1403


🏁 Script executed:

#!/bin/bash
# Final verification: Check if there's any existing dtype validation for output
rg "output\)\|TVM_FFI_ICHECK.*output" --type cu | grep dtype | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Verify the K% 16 error message one more time with exact pattern
rg 'shape_k.*%.*16.*0.*<<.*"N must' --type cu

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

head -n 200 csrc/fp8_blockscale_gemm_sm90_binding.cu

Repository: flashinfer-ai/flashinfer

Length of output: 8648


🏁 Script executed:

sed -n '1,50p' csrc/fp8_blockscale_gemm_sm90_binding.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1528


🏁 Script executed:

sed -n '110,125p' csrc/fp8_blockscale_gemm_sm90_binding.cu

Repository: flashinfer-ai/flashinfer

Length of output: 937


🏁 Script executed:

cat csrc/tvm_ffi_utils.h | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4489


🏁 Script executed:

sed -n '98,182p' csrc/fp8_blockscale_gemm_sm90_binding.cu

Repository: flashinfer-ai/flashinfer

Length of output: 4471


🏁 Script executed:

grep -n "encode_dlpack_dtype\|kDLBfloat\|bfloat16_code" csrc/fp8_blockscale_gemm_sm90_binding.cu

Repository: flashinfer-ai/flashinfer

Length of output: 129


🏁 Script executed:

grep -n "encode_dlpack_dtype\|DLDataType\|kDL" csrc/tvm_ffi_utils.h | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 2158


Fix error message on line 114 and add output BF16 dtype validation.

Line 114 checks shape_k % 16 == 0 (the K dimension) but the error message incorrectly says "N must be a multiple of 16"—it should say "K must be a multiple of 16" to match what is being checked.

Additionally, the code casts output_ptr to __nv_bfloat16* without validating that output.dtype() is actually BF16. If a caller passes FP16 or another type, this will cause memory corruption. Use the bfloat16_code constant already available in tvm_ffi_utils.h (and the encode_dlpack_dtype() pattern already used in this file) to validate the output dtype.

-    TVM_FFI_ICHECK(shape_k % 16 == 0) << "N must be a multiple of 16, (K=" << shape_k << ")";
+    TVM_FFI_ICHECK(shape_k % 16 == 0) << "K must be a multiple of 16 (K=" << shape_k << ")";
     TVM_FFI_ICHECK(shape_n % 16 == 0) << "N must be a multiple of 16, (N=" << shape_n << ")";
+
+    // Output must be BF16 for these runner instantiations.
+    TVM_FFI_ICHECK(encode_dlpack_dtype(output.dtype()) == bfloat16_code)
+        << "output must be BF16 (bfloat16) for fp8_blockscale_gemm_sm90";
🤖 Prompt for AI Agents
In csrc/fp8_blockscale_gemm_sm90_binding.cu around lines 98 to 182, change the
incorrect error text on the shape_k modulus check to mention "K must be a
multiple of 16" instead of "N", and add a defensive dtype check before casting
output_ptr to __nv_bfloat16*: use the existing
encode_dlpack_dtype(output.dtype()) comparison against the bfloat16_code
constant (from tvm_ffi_utils.h) and assert with TVM_FFI_ICHECK that the output
dtype is BF16; perform this validation prior to any reinterpret_cast to
__nv_bfloat16* so we avoid potential memory corruption.

Comment on lines +3306 to +3514

    # Validate dtype combinations
    input_is_fp8 = input.dtype == torch.float8_e4m3fn
    weight_is_fp8 = weight.dtype == torch.float8_e4m3fn
    input_is_bf16 = input.dtype == torch.bfloat16
    weight_is_bf16 = weight.dtype == torch.bfloat16

    # Explicitly reject FP8 input + BF16 weight (missing kernel implementation)
    if input_is_fp8 and weight_is_bf16:
        raise ValueError(
            "FP8 input + BF16 weight is not supported (missing kernel implementation). "
        )

    # Validate scale requirements for FP8 inputs
    if input_is_fp8:
        if input_scale is None:
            raise ValueError("input_scale is required when input is FP8. ")
        # Users provide input_scale in shape (M, K//128), matching per_token_cast_to_fp8 output.
        # We transpose it internally to (K//128, M) to match TensorRT-LLM kernel expectations.
        expected_scale_shape = (M, K // BLOCK_SIZE)
        if input_scale.shape != expected_scale_shape:
            raise ValueError(
                f"input_scale shape mismatch. Expected {expected_scale_shape}, "
                f"got {input_scale.shape}"
            )
        if input_scale.dtype != torch.float32:
            raise ValueError(f"input_scale must be float32, got {input_scale.dtype}")
        if input_scale.device != input.device:
            raise ValueError(
                f"input_scale device mismatch. Expected {input.device}, "
                f"got {input_scale.device}"
            )
    else:
        if not input_is_bf16:
            raise ValueError(
                f"Input must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), "
                f"got {input.dtype}"
            )
        if input_scale is not None:
            raise ValueError(
                "input_scale should not be provided for BF16 inputs. "
                "Use FP8 inputs if you want to provide external scales."
            )

    if weight_is_fp8:
        if weight_scale is None:
            raise ValueError("weight_scale is required when weight is FP8. ")
        expected_per_token_shape = (N, K // BLOCK_SIZE)
        expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE)
        is_per_token = weight_scale.shape == expected_per_token_shape
        is_per_block = weight_scale.shape == expected_per_block_shape

        if not (is_per_token or is_per_block):
            raise ValueError(
                f"weight_scale shape mismatch. Expected either {expected_per_token_shape} "
                f"(per-token, 1x128 blocks) or {expected_per_block_shape} "
                f"(per-block, 128x128 blocks), got {weight_scale.shape}"
            )
        if weight_scale.dtype != torch.float32:
            raise ValueError(f"weight_scale must be float32, got {weight_scale.dtype}")
    else:
        if not weight_is_bf16:
            raise ValueError(
                f"Weight must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), "
                f"got {weight.dtype}"
            )
        if weight_scale is not None:
            raise ValueError(
                "weight_scale should not be provided for BF16 weights. "
                "Use FP8 weights if you want to provide external scales."
            )

    # Validate output tensor if provided
    if out is not None:
        if out.shape != (M, N):
            raise ValueError(
                f"Output shape mismatch. Expected ({M}, {N}), got {out.shape}"
            )
        if out.device != input.device:
            raise ValueError(
                f"Output device mismatch. Expected {input.device}, got {out.device}"
            )
        if out.dtype not in [torch.bfloat16, torch.float16]:
            raise ValueError(
                f"Output dtype must be torch.bfloat16 or torch.float16, got {out.dtype}"
            )
        if out_dtype is not None and out.dtype != out_dtype:
            raise ValueError(
                f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}"
            )
        out_dtype = out.dtype
    else:
        # Allocate output
        out_dtype = out_dtype or torch.bfloat16
        if out_dtype not in [torch.bfloat16, torch.float16]:
            raise ValueError(
                f"Output dtype must be torch.bfloat16 or torch.float16, got {out_dtype}"
            )
        out = torch.empty(M, N, dtype=out_dtype, device=input.device)

Contributor


⚠️ Potential issue | 🔴 Critical

Output dtype contract is inconsistent with the C++ binding (likely BF16-only).
The C++ runner instantiations use __nv_bfloat16 output and the binding comment says “Output is always BF16”, but Python accepts/allocates FP16 outputs (Lines 3497-3513). That can lead to writing BF16 bits into an FP16 tensor (silent corruption).

@@
-    if out is not None:
+    if out is not None:
@@
-        if out.dtype not in [torch.bfloat16, torch.float16]:
+        if out.dtype not in [torch.bfloat16]:
             raise ValueError(
-                f"Output dtype must be torch.bfloat16 or torch.float16, got {out.dtype}"
+                f"Output dtype must be torch.bfloat16, got {out.dtype}"
             )
@@
-    else:
+    else:
         # Allocate output
         out_dtype = out_dtype or torch.bfloat16
-        if out_dtype not in [torch.bfloat16, torch.float16]:
+        if out_dtype not in [torch.bfloat16]:
             raise ValueError(
-                f"Output dtype must be torch.bfloat16 or torch.float16, got {out_dtype}"
+                f"Output dtype must be torch.bfloat16, got {out_dtype}"
             )

If FP16 output is actually desired, the fix should be on the C++ side (add an FP16-output runner + dtype dispatch) rather than allowing it only in Python.
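To make the "silent corruption" concern concrete, here is a small, hypothetical illustration (not code from the PR) of what reading BF16 bit patterns through an FP16 tensor does:

```python
import torch

# A BF16 value and the FP16 value obtained by reinterpreting its raw 16 bits
# share no numerical relationship; this is the failure mode described above.
x = torch.randn(4, dtype=torch.bfloat16)
reinterpreted = x.view(torch.float16)   # bitwise reinterpretation, not a cast
print(x)                     # the intended BF16 results
print(reinterpreted)         # unrelated magnitudes when read back as FP16
print(x.to(torch.float16))   # what a real dtype conversion would give instead
```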

🧰 Tools
🪛 Ruff (0.14.8)

TRY003 (Avoid specifying long messages outside the exception class) reported at lines 3391-3393, 3397, 3399, 3405-3407, 3412-3414, 3424-3426, 3431, 3436-3439, 3441, 3443-3446, 3449-3452, 3454-3457, 3461, 3468-3472, 3474, 3477-3480, 3482-3485, 3490-3492, 3494-3496, 3498-3500, 3502-3504, and 3510-3512.

🤖 Prompt for AI Agents
flashinfer/gemm/gemm_base.py around lines 3497-3513: the Python API currently
accepts/allocates torch.float16 outputs even though the C++ runner/binding is
BF16-only, risking silent corruption. Fix by restricting outputs to
torch.bfloat16: validate that provided out.dtype (and out_dtype when supplied)
is torch.bfloat16 and raise an error otherwise; when allocating out, always
allocate with dtype=torch.bfloat16 (ignore or reject requests for float16). If
FP16 outputs are actually required, implement the FP16 runner and dtype dispatch
in the C++ bindings instead of allowing FP16 only in Python.
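For reference, a minimal, hypothetical sketch of the calling convention that the shape and scale checks above imply (import path and BF16-only output assumed, per the binding contract discussed in this thread):

```python
import torch
from flashinfer.gemm import fp8_blockscale_gemm_sm90  # assumed package-level export

M, N, K = 16, 4096, 4096                    # K must be a multiple of 128
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# BF16 x BF16: the runner quantizes internally, so no scales are passed.
y = fp8_blockscale_gemm_sm90(x, w)          # (M, N), bfloat16

# FP8 x FP8 instead requires float32 scales matching the shapes validated above:
#   input_scale  -> (M, K // 128)           per-token 1x128 blocks
#   weight_scale -> (N // 128, K // 128)    per-block 128x128 blocks
#                   or (N, K // 128)        per-token weight scales
```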

Comment on lines +48 to +58
@pytest.fixture(
autouse=not has_flashinfer_jit_cache(),
scope="module",
)
def warmup_jit():
"""Warm up JIT compilation for FP8 block-scale GEMM if not cached."""
if is_sm90a_supported(torch.device("cuda:0")):
jit_specs = [gen_fp8_blockscale_gemm_sm90_module()]
flashinfer.jit.build_jit_specs(jit_specs, verbose=False)
yield


⚠️ Potential issue | 🟠 Major

Warmup fixture should guard CUDA availability to avoid import-time failures.
Because it’s autouse, it may execute before individual tests can skip(). Add a torch.cuda.is_available() guard (or skip) before touching cuda:0.

@@
 def warmup_jit():
     """Warm up JIT compilation for FP8 block-scale GEMM if not cached."""
+    if not torch.cuda.is_available():
+        yield
+        return
     if is_sm90a_supported(torch.device("cuda:0")):
         jit_specs = [gen_fp8_blockscale_gemm_sm90_module()]
         flashinfer.jit.build_jit_specs(jit_specs, verbose=False)
     yield
🤖 Prompt for AI Agents
In tests/gemm/test_fp8_blockscale_gemm.py around lines 48 to 58, the autouse
fixture warmup_jit can touch torch.device("cuda:0") during import even when CUDA
is unavailable; update the fixture to first check torch.cuda.is_available() (or
call pytest.skip) and return/skip early if CUDA is not available, then proceed
to call is_sm90a_supported and build the JIT specs only when CUDA is present to
avoid import-time failures.

Comment on lines +404 to +456
def test_fp8_blockscale_gemm_error_handling():
"""Test that proper errors are raised for invalid inputs."""
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] < 9:
pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")

if not is_sm90a_supported(torch.device("cuda")):
pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")

device = "cuda"
m, n, k = 16, 256, 256

# Test: K not divisible by 128
input = torch.randn(m, 127, device=device, dtype=torch.bfloat16)
weight = torch.randn(n, 127, device=device, dtype=torch.bfloat16)
with pytest.raises(ValueError, match="divisible by block size"):
fp8_blockscale_gemm_sm90(input, weight)

# Test: FP16 not supported
input = torch.randn(m, k, device=device, dtype=torch.float16)
weight = torch.randn(n, k, device=device, dtype=torch.float16)
with pytest.raises(ValueError, match="FP8.*or BF16"):
fp8_blockscale_gemm_sm90(input, weight)

# Test: FP8 weight without scale (naive conversion)
input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
weight_fp8_naive = weight_bf16.to(torch.float8_e4m3fn)
with pytest.raises(ValueError, match="weight_scale is required when weight is FP8"):
fp8_blockscale_gemm_sm90(input_bf16, weight_fp8_naive, None, None)

# Test: BF16 input with scale (should raise error)
input = torch.randn(m, k, device=device, dtype=torch.bfloat16)
weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)
fake_scale = torch.ones(m, k // 128, device=device, dtype=torch.float32)
with pytest.raises(ValueError, match="input_scale should not be provided for BF16"):
fp8_blockscale_gemm_sm90(input, weight, input_scale=fake_scale)

# Test: Wrong scale shape for FP8 input
input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
input_fp8, _ = per_token_cast_to_fp8(input_bf16)
weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)
wrong_scale = torch.ones(m, k // 64, device=device, dtype=torch.float32)
with pytest.raises(ValueError):
fp8_blockscale_gemm_sm90(input_fp8, weight, input_scale=wrong_scale)

# Test: FP8 input + BF16 weight is NOT supported
input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
input_fp8, input_scale = per_token_cast_to_fp8(input_bf16)
weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)
with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"):
fp8_blockscale_gemm_sm90(input_fp8, weight, input_scale, None)


⚠️ Potential issue | 🟡 Minor

Use raw strings / escape regex metacharacters in pytest.raises(..., match=...).
Ruff’s RUF043 is valid here: patterns like "FP8.*or BF16" should be raw strings to avoid unintended escapes / regex surprises.

@@
-    with pytest.raises(ValueError, match="FP8.*or BF16"):
+    with pytest.raises(ValueError, match=r"FP8.*or BF16"):
         fp8_blockscale_gemm_sm90(input, weight)
@@
-    with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"):
+    with pytest.raises(ValueError, match=r"FP8 input.*BF16 weight.*not supported"):
         fp8_blockscale_gemm_sm90(input_fp8, weight, input_scale, None)
🧰 Tools
🪛 Ruff (0.14.8)

RUF043 (Pattern passed to match= contains metacharacters but is neither escaped nor raw) reported at lines 425 and 454.

🤖 Prompt for AI Agents
In tests/gemm/test_fp8_blockscale_gemm.py around lines 404 to 456, the regex
strings passed to pytest.raises(..., match=...) are regular Python strings and
should be raw strings to avoid accidental escapes or unintended regex behavior;
update the match arguments to use raw string literals (prefix with r'...') for
patterns such as "FP8.*or BF16" and "FP8 input.*BF16 weight.*not supported" (and
any other match patterns that include backslashes or regex metacharacters), or
alternatively escape literal regex metacharacters if you intend a literal match.
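The tests above also rely on a per_token_cast_to_fp8 helper that is not shown in this thread. Below is a rough, hypothetical sketch of what such a 1x128 per-token FP8 (E4M3) quantizer typically looks like, matching the (M, K // 128) float32 scale shape the API validates; the exact helper lives in the test file:

```python
import torch

def per_token_cast_to_fp8(x: torch.Tensor):
    """Illustrative 1x128 per-token FP8 quantizer (not the test's exact helper)."""
    m, k = x.shape
    assert k % 128 == 0, "K must be a multiple of the 128-wide quantization block"
    blocks = x.view(m, k // 128, 128).float()
    amax = blocks.abs().amax(dim=2, keepdim=True).clamp(min=1e-4)
    scale = amax / 448.0                      # 448 is the max finite float8_e4m3fn value
    x_fp8 = (blocks / scale).to(torch.float8_e4m3fn).view(m, k)
    return x_fp8, scale.squeeze(-1)           # scales: (M, K // 128), float32
```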

@yzh119

yzh119 commented Dec 12, 2025

/bot run

@flashinfer-bot

GitLab MR !192 has been created, and the CI pipeline #40112860 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[CANCELING] Pipeline #40112860: canceled


@yzh119 yzh119 left a comment


Hi @katec846 would you mind fixing the pre-commit issues?

<< "scales_a is required for FP8 input";
// TensorRT-LLM expects scale shape: (K/128, M) after transpose
// int64_t expected_scale_k = (shape_k + 127) / 128;
// TVM_FFI_ICHECK(scales_a.value().size(0) == expected_scale_k &&

Why do we comment these out?


Hi Zihao, this check was added before porting the quant kernel but was then relaxed because TRT-LLM internally expects a 1-D tensor. Thanks!

Comment thread on tests/gemm/test_fp8_blockscale_gemm.py (outdated)
Returns diff = 1 - sim, where sim = 2*<x,y> / (||x||² + ||y||²)
This is similar to cosine similarity but uses squared norms in denominator.

diff < 0.001 corresponds to >99.9% similarity.

Is this mechanism designed to be aligned with trtllm?


Oops, I think this was added by a previous developer but was never used.
I changed the criterion to plain cosine similarity, the same as vLLM, and all tests pass with >99.9% similarity (for both vLLM and FlashInfer).
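A minimal sketch of the two acceptance metrics being discussed, assuming both are computed over the flattened outputs (names are illustrative):

```python
import torch

def norm_sim_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # Older metric quoted above: diff = 1 - 2*<x, y> / (||x||^2 + ||y||^2)
    x, y = x.float().flatten(), y.float().flatten()
    return float(1.0 - 2.0 * torch.dot(x, y) / (x.square().sum() + y.square().sum()))

def cosine_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # Criterion now used in the tests: 1 - cosine similarity, as in vLLM
    x, y = x.float().flatten(), y.float().flatten()
    return float(1.0 - torch.dot(x, y) / (x.norm() * y.norm()))

# In both cases, diff < 1e-3 corresponds to > 99.9% similarity.
```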

@@ -1163,10 +1163,13 @@ __global__ void convert_kernel(OutputType* output, InputType const* const input,
}

static int kNumDeviceSMs = -1;
static bool kDeepGemmEnabled = []() -> bool {
// a function that returns kDeepGemmEnabled
bool getDeepGemmEnabled() {
@yzh119 yzh119 Dec 14, 2025

Thanks for making this change, it makes sense to me.

For more background, @katec846 found that after #2090 we would silently skip DeepGemm because kDeepGemmEnabled returned false; in the future I think we should surface this information in the logging (e.g. whether DeepGemm is enabled or not).
cc @bkryu

@jhaotingc jhaotingc requested a review from jimmyzho as a code owner December 16, 2025 03:28
@jhaotingc jhaotingc force-pushed the vchen/dg_swapab_linear branch from 5d74d09 to 5e5d685 Compare December 16, 2025 03:32
xuanzic and others added 8 commits December 15, 2025 19:32
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
@jhaotingc jhaotingc force-pushed the vchen/dg_swapab_linear branch from 5e5d685 to 72c4f7b Compare December 16, 2025 03:33
@jhaotingc jhaotingc force-pushed the vchen/dg_swapab_linear branch from 72c4f7b to a4ae22c Compare December 16, 2025 03:48
@jhaotingc

Hi @yzh119 thanks for reviewing. This PR is considered ready now.

I found and noted in vllm-project/vllm#29213 (comment) that DS-R1 shows an accuracy degradation with my implementation. I've cross-referenced this PR against TRT-LLM's implementation as well as vLLM's FP8 Linear op unit test, and there appears to be no error in this FlashInfer PR; some part of DS-R1 is producing the accuracy diff in vLLM.
For the FlashInfer part, this is ready.

Thank you again!!

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
@jhaotingc jhaotingc force-pushed the vchen/dg_swapab_linear branch from a4ae22c to c2e0540 Compare December 17, 2025 03:02
@yzh119

yzh119 commented Dec 17, 2025

/bot run

@flashinfer-bot

GitLab MR !192 has been updated with latest changes, and the CI pipeline #40345187 is currently running. I'll report back once the pipeline job completes.


@yzh119 yzh119 left a comment


Thanks for your contribution! The failed UTs are not relevant; let's merge it now.

@yzh119 yzh119 merged commit 15a819e into flashinfer-ai:main Dec 17, 2025
4 checks passed
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Co-authored-by: Vivian Chen <140748220+xuanzic@users.noreply.github.com>
Co-authored-by: Vivian Chen <xuanzic@nvidia.com>
Co-authored-by: Jhao-Ting Chen <jhaotingc@nvidia.com>