make DeepGEMM swapAB available for linear gemm SM90 #2131
yzh119 merged 9 commits into flashinfer-ai:main from
Conversation
Walkthrough
Adds an SM90-optimized FP8 block-scale GEMM path: a CUDA TVM-FFI binding and runner with runtime dtype dispatch and workspace management, Python JIT spec and high-level API wiring, unit tests, and a runtime DeepGemm enablement accessor.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant User as User/App
participant PyAPI as Python API\n(fp8_blockscale_gemm_sm90)
participant JIT as JIT Module Loader
participant Runner as SM90 Runner\n(Fp8BlockScaleGemmRunner)
participant Kernel as Cutlass Kernel
User->>PyAPI: call fp8_blockscale_gemm_sm90(input, weight, scales...)
PyAPI->>PyAPI: validate arch, dtypes, K%128, shapes, scales
PyAPI->>JIT: gen_fp8_blockscale_gemm_sm90_module() -> Module (if needed)
PyAPI->>Runner: configure_workspace(workspace_buffer)
PyAPI->>Runner: run_gemm(input_ptr, weight_ptr, output_ptr, scale_ptrs, shapes)
Runner->>Runner: selectRunner(input_is_fp8, weight_is_fp8)
Runner->>Kernel: launch selected Cutlass GEMM (workspace, scales)
Kernel-->>Runner: kernel completes
Runner-->>PyAPI: status / result
PyAPI-->>User: return output tensor
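For orientation, here is a minimal caller-side sketch of the happy path the diagram describes, mirroring the BF16 x BF16 calls in the new unit tests (shapes are illustrative; an SM90a device and the package-level export added in flashinfer/gemm/__init__.py are assumed):

```python
import torch
from flashinfer.gemm import fp8_blockscale_gemm_sm90

# Illustrative sizes; K must be divisible by the 128-wide scale block.
m, n, k = 16, 256, 512
input_bf16 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
weight_bf16 = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

# BF16 inputs take the internal-quantization path, so no scales are passed;
# FP8 inputs would instead require input_scale / weight_scale tensors.
out = fp8_blockscale_gemm_sm90(input_bf16, weight_bf16)
assert out.shape == (m, n) and out.dtype == torch.bfloat16
```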
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning, 1 inconclusive)
📜 Recent review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 72c4f7bf4e54d69bf22b488cfc2683547cc9e66d and c2e0540.
📒 Files selected for processing (3)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Applied to files:
🧬 Code graph analysis (3)
csrc/fp8_blockscale_gemm_sm90_binding.cu (1)
flashinfer/gemm/gemm_base.py (2)
tests/gemm/test_fp8_blockscale_gemm.py (5)
🪛 Ruff (0.14.8)
flashinfer/gemm/gemm_base.py
3391-3393, 3397, 3399, 3405-3407, 3412-3414, 3417, 3427-3429, 3434, 3436, 3438-3441, 3444-3447, 3449-3452, 3456, 3463-3467, 3469, 3472-3475, 3477-3480, 3485-3487, 3489-3491, 3493-3495, 3497-3499, 3505-3507: Avoid specifying long messages outside the exception class (TRY003)
tests/gemm/test_fp8_blockscale_gemm.py
1-1: The file is executable but no shebang is present (EXE002)
326-326, 355-355: Pattern passed to match= contains metacharacters but is neither escaped nor raw (RUF043)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
8086118 to 9bbf63f
Actionable comments posted: 5
🧹 Nitpick comments (6)
tests/gemm/test_fp8_blockscale_gemm.py (5)
1-15: Minor: Copyright year is 2024 for new 2025 code. The copyright header shows 2024, but this is new code being added in 2025. Consider updating to 2025.
32-45: Unused utility function calc_diff. This function is defined but never called in any of the tests - they all use F.cosine_similarity instead. Consider either using this function for consistency with TRT-LLM's metric (as documented in the docstring), or removing it to avoid dead code.
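For reference, a self-contained sketch of that metric as stated in the docstring quoted later in this thread (an illustration, not the test file's exact code):

```python
import torch

def calc_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # diff = 1 - sim, where sim = 2*<x, y> / (||x||^2 + ||y||^2);
    # diff < 0.001 corresponds to > 99.9% similarity.
    x, y = x.double(), y.double()
    sim = 2 * (x * y).sum() / (x.square().sum() + y.square().sum())
    return float(1 - sim)
```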
234-261: Consider vectorizing dequantization and removing the print statement.
- The nested Python loops (lines 235-248) could be vectorized using tensor operations for better readability and performance:
  # Vectorized dequantization
  input_dequant = input_fp8.to(torch.bfloat16).view(m, k // 128, 128) * input_scale.unsqueeze(-1)
  input_dequant = input_dequant.view(m, k)
- The print statement could also be removed for cleaner test output.
264-306: LGTM! Good test coverage for per-block weight scales. Minor suggestion: consider removing the print statement at line 305 for cleaner test output.
366-396: Use raw strings for regex patterns in pytest.raises(match=...). The patterns contain regex metacharacters (. and *). While they work as intended, using raw strings (r"...") makes the intent clearer and follows Python best practices for regex patterns.
- with pytest.raises(ValueError, match="FP8.*or BF16"):
+ with pytest.raises(ValueError, match=r"FP8.*or BF16"):
      fp8_blockscale_gemm_sm90(input, weight)
- with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"):
+ with pytest.raises(ValueError, match=r"FP8 input.*BF16 weight.*not supported"):
      fp8_blockscale_gemm_sm90(input_fp8, weight, input_scale, None)
flashinfer/gemm/gemm_base.py (1)
3381-3385: The "90a" check is ineffective.The
_match_sm_versionfunction constructsdevice_archasf"{major * 10 + minor}", producing "90" for SM90/SM90a devices. The "90a" string in the list will never match since the 'a' suffix doesn't affect the reported compute capability.- if not _match_sm_version(input.device, ["90", "90a"]): + if not _match_sm_version(input.device, ["90"]):
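A two-line illustration of why that list entry is dead code, assuming a Hopper device reporting compute capability (9, 0):

```python
major, minor = 9, 0                    # torch.cuda.get_device_capability() on SM90/SM90a
device_arch = f"{major * 10 + minor}"  # -> "90"; no "a" suffix is ever produced
assert device_arch == "90" and device_arch != "90a"
```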
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 6bb01d1 and 710d0a7f91ab3f386e7c4b00f4256d67d2a50ac8.
📒 Files selected for processing (6)
csrc/fp8_blockscale_gemm_sm90_binding.cu (1 hunks)
flashinfer/gemm/__init__.py (2 hunks)
flashinfer/gemm/gemm_base.py (2 hunks)
flashinfer/jit/gemm/__init__.py (2 hunks)
flashinfer/jit/gemm/fp8_blockscale.py (1 hunks)
tests/gemm/test_fp8_blockscale_gemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/fp8_blockscale_gemm_sm90_binding.cu
🧬 Code graph analysis (4)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/gemm_base.py (1)
fp8_blockscale_gemm_sm90(3300-3542)
flashinfer/jit/gemm/fp8_blockscale.py (2)
flashinfer/jit/core.py (2)
JitSpec (213-312)
gen_jit_spec (315-381)
flashinfer/jit/cpp_ext.py (1)
is_cuda_version_at_least (86-87)
csrc/fp8_blockscale_gemm_sm90_binding.cu (2)
flashinfer/comm/cuda_ipc.py (1)
Function (37-40)
csrc/tvm_ffi_utils.h (2)
encode_dlpack_dtype (30-32)
get_stream (277-279)
flashinfer/gemm/gemm_base.py (2)
flashinfer/jit/gemm/fp8_blockscale.py (1)
gen_fp8_blockscale_gemm_sm90_module (10-54)
csrc/fp8_blockscale_gemm_sm90_binding.cu (8)
init (211-214), init (211-211), input (92-174), input (92-93), input_is_fp8 (78-90), input_is_fp8 (78-79), workspace (189-196), workspace (189-189)
🪛 Ruff (0.14.8)
tests/gemm/test_fp8_blockscale_gemm.py
1-1: The file is executable but no shebang is present
(EXE002)
366-366: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
395-395: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
flashinfer/jit/gemm/fp8_blockscale.py
1-1: The file is executable but no shebang is present
(EXE002)
12-17: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
flashinfer/gemm/gemm_base.py
3383-3385: Avoid specifying long messages outside the exception class
(TRY003)
3389-3389: Avoid specifying long messages outside the exception class
(TRY003)
3391-3391: Avoid specifying long messages outside the exception class
(TRY003)
3397-3399: Avoid specifying long messages outside the exception class
(TRY003)
3404-3406: Avoid specifying long messages outside the exception class
(TRY003)
3416-3418: Avoid specifying long messages outside the exception class
(TRY003)
3423-3423: Avoid specifying long messages outside the exception class
(TRY003)
3428-3431: Avoid specifying long messages outside the exception class
(TRY003)
3433-3433: Avoid specifying long messages outside the exception class
(TRY003)
3435-3438: Avoid specifying long messages outside the exception class
(TRY003)
3441-3444: Avoid specifying long messages outside the exception class
(TRY003)
3446-3449: Avoid specifying long messages outside the exception class
(TRY003)
3453-3453: Avoid specifying long messages outside the exception class
(TRY003)
3460-3464: Avoid specifying long messages outside the exception class
(TRY003)
3466-3466: Avoid specifying long messages outside the exception class
(TRY003)
3469-3472: Avoid specifying long messages outside the exception class
(TRY003)
3474-3477: Avoid specifying long messages outside the exception class
(TRY003)
3482-3484: Avoid specifying long messages outside the exception class
(TRY003)
3486-3488: Avoid specifying long messages outside the exception class
(TRY003)
3490-3492: Avoid specifying long messages outside the exception class
(TRY003)
3498-3500: Avoid specifying long messages outside the exception class
(TRY003)
3533-3536: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (19)
csrc/fp8_blockscale_gemm_sm90_binding.cu (7)
1-26: LGTM! The includes and is_fp8_e4m3fn helper are well-structured. The conditional compilation with FLASHINFER_ENABLE_FP8_E4M3 gracefully handles the case when FP8 support is disabled.
37-72: LGTM! The class structure is well-designed with clear separation of concerns. The GetFunction dispatch correctly exposes the three required interfaces, and the documentation comment clearly lists the supported dtype combinations.
74-90: LGTM! The runtime dtype dispatch correctly handles all four possible combinations with clear comments explaining why FP8 input + BF16 weight is unsupported.
113-140: LGTM! The scale validation logic is thorough and handles both per-token and per-block scale formats correctly. The error messages are descriptive with expected vs actual shapes, which aids debugging.
156-173: LGTM! The dual-path dispatch correctly handles W8A8 vs internal quantization paths with appropriate type casts. The comments explaining the two TRT-LLM method signatures are helpful for maintainability.
176-187: Consider wrapping workspace size queries in try-catch. Based on learnings from this codebase, getWorkspaceSizeImpl calls can legitimately fail when probing configurations due to SMEM constraints. If getWorkspaceSizeBase can throw std::runtime_error, these calls should be wrapped in try-catch blocks. Verify whether getWorkspaceSizeBase can throw and consider:
int64_t getWorkspaceSize(int64_t shape_m, int64_t shape_n, int64_t shape_k) {
  size_t max_size = 0;
-  max_size =
-      std::max(max_size, runner_bf16_bf16_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
-  max_size =
-      std::max(max_size, runner_bf16_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
-  max_size =
-      std::max(max_size, runner_fp8_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+  // Swallow errors when SMEM exceeds maximum allowed
+  try {
+    max_size = std::max(max_size, runner_bf16_bf16_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+  } catch (std::runtime_error&) {}
+  try {
+    max_size = std::max(max_size, runner_bf16_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+  } catch (std::runtime_error&) {}
+  try {
+    max_size = std::max(max_size, runner_fp8_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+  } catch (std::runtime_error&) {}
  return max_size;
}
198-216: LGTM! The member variables use unique_ptr for proper RAII memory management, and the module initialization correctly follows TVM FFI patterns.
flashinfer/jit/gemm/__init__.py (1)
30-44: LGTM! The new import and __all__ entry follow the existing patterns in this file, correctly exposing the SM90 FP8 block-scale GEMM module generator.
flashinfer/gemm/__init__.py (1)
16-34: LGTM! The new public API export follows the existing patterns consistently, correctly exposing fp8_blockscale_gemm_sm90 at the package level.
flashinfer/jit/gemm/fp8_blockscale.py (1)
19-54: LGTM! The JIT spec is well-structured with comprehensive source files, include paths for TRT-LLM integration, and appropriate linker flags for CUDA runtime compilation support.
tests/gemm/test_fp8_blockscale_gemm.py (5)
48-57: LGTM! The JIT warmup fixture correctly handles conditional compilation based on hardware support and JIT cache availability, ensuring tests have a stable GPU execution environment.
60-104: LGTM! Well-structured parametrized test with appropriate hardware checks, clear assertions, and reasonable similarity thresholds for the BF16 internal quantization path.
106-189: LGTM with minor note. Good test coverage for the supported dtype combinations. The TODO at line 182 about checking the threshold is noted - consider addressing or tracking this before the feature is finalized.
308-343: LGTM! Good coverage of common LLM inference shapes with appropriate correctness validation.
399-430: LGTM! Good test coverage for pre-allocated output buffer functionality, correctly verifying both buffer identity and output correctness.
flashinfer/gemm/gemm_base.py (4)
59-59: LGTM! Import statement is correctly placed with other JIT gemm imports.
3293-3296: LGTM! The cached runner initialization follows the established pattern in this file and correctly chains the JIT spec build, load, and init calls.
3513-3538: LGTM! The input scale transformation correctly handles the layout conversion from the user-provided (M, K//128) format to TRT-LLM's expected (K//128, M) format with proper padding and stride verification. The stride check at lines 3532-3536 ensures memory layout compatibility.
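For readers following along, a rough sketch of that relayout (a hypothetical standalone helper; the real logic lives in gemm_base.py around the lines cited above, and the exact padding of M is an assumption here):

```python
import torch

def relayout_input_scale(input_scale: torch.Tensor, m_padded: int) -> torch.Tensor:
    # input_scale: (M, K // 128) float32, as produced by per-token 1x128 quantization.
    m, k_blocks = input_scale.shape
    out = torch.zeros(k_blocks, m_padded, dtype=torch.float32, device=input_scale.device)
    out[:, :m] = input_scale.t()          # transpose to (K // 128, M), padding M to m_padded
    assert out.stride() == (m_padded, 1)  # kernel expects leading dimension m_padded
    return out
```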
3503-3542: LGTM! The runner invocation correctly follows the pattern: workspace allocation, configuration, and GEMM execution. The argument order matches the C++ runGemm function signature.
794812c to d4d19f7
Actionable comments posted: 5
🧹 Nitpick comments (3)
flashinfer/gemm/gemm_base.py (2)
59-60: Cache loader looks fine; include-dir wiring is reasonable.
One nit: _match_sm_version can't ever match "90a" (it only builds "90", "100", etc.), so the later arch check effectively ignores the "90a" string. Consider switching to an explicit is_sm90a_supported(...) check (consistent with tests) or just checking "90".
Also applies to: 3293-3304
3525-3564: Clarify/validate the scale buffer layout and simplify the size math.
input_scale_size is computed via a non-obvious expression (line 3529). If the intended layout is a flat (K_blocks * M_padded) buffer, this should be expressed directly (and ideally commented as "(K_blocks, M_padded) with ld=M_padded"). Also consider validating weight_scale.device == input.device (you validate this for input_scale but not for weight_scale).
@@
- K_blocks = (K + BLOCK_SIZE - 1) // BLOCK_SIZE
- input_scale_size = ((K * M_padded * 4 + BLOCK_SIZE - 1) // BLOCK_SIZE) // 4
+ K_blocks = K // BLOCK_SIZE
+ # Flat buffer representing (K_blocks, M_padded) with leading-dim M_padded.
+ input_scale_size = K_blocks * M_padded
@@
  if weight_is_fp8:
@@
      if weight_scale.dtype != torch.float32:
          raise ValueError(f"weight_scale must be float32, got {weight_scale.dtype}")
+     if weight_scale.device != weight.device:
+         raise ValueError(
+             f"weight_scale device mismatch. Expected {weight.device}, got {weight_scale.device}"
+         )
csrc/fp8_blockscale_gemm_sm90_binding.cu (1)
206-217: Workspace-size probing should be resilient to per-runner shape failures.
If getWorkspaceSizeBase(...) can throw for certain shapes/configs (e.g., SMEM constraints), consider catching and swallowing std::runtime_error per the established pattern in this repo ("Swallow errors when SMEM exceeds maximum allowed"). Based on learnings, this is considered acceptable for probing multiple candidate kernels.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 65dd3fdbe65c0416b058e525b66bd66930ea76d0 and d4d19f727815db03af3ab5273d249ee3a0b42582.
📒 Files selected for processing (3)
csrc/fp8_blockscale_gemm_sm90_binding.cu (1 hunks)
flashinfer/gemm/gemm_base.py (2 hunks)
tests/gemm/test_fp8_blockscale_gemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
csrc/fp8_blockscale_gemm_sm90_binding.cu
🪛 Ruff (0.14.8)
flashinfer/gemm/gemm_base.py
3391-3393: Avoid specifying long messages outside the exception class
(TRY003)
3397-3397: Avoid specifying long messages outside the exception class
(TRY003)
3399-3399: Avoid specifying long messages outside the exception class
(TRY003)
3405-3407: Avoid specifying long messages outside the exception class
(TRY003)
3412-3414: Avoid specifying long messages outside the exception class
(TRY003)
3424-3426: Avoid specifying long messages outside the exception class
(TRY003)
3431-3431: Avoid specifying long messages outside the exception class
(TRY003)
3436-3439: Avoid specifying long messages outside the exception class
(TRY003)
3441-3441: Avoid specifying long messages outside the exception class
(TRY003)
3443-3446: Avoid specifying long messages outside the exception class
(TRY003)
3449-3452: Avoid specifying long messages outside the exception class
(TRY003)
3454-3457: Avoid specifying long messages outside the exception class
(TRY003)
3461-3461: Avoid specifying long messages outside the exception class
(TRY003)
3468-3472: Avoid specifying long messages outside the exception class
(TRY003)
3474-3474: Avoid specifying long messages outside the exception class
(TRY003)
3477-3480: Avoid specifying long messages outside the exception class
(TRY003)
3482-3485: Avoid specifying long messages outside the exception class
(TRY003)
3490-3492: Avoid specifying long messages outside the exception class
(TRY003)
3494-3496: Avoid specifying long messages outside the exception class
(TRY003)
3498-3500: Avoid specifying long messages outside the exception class
(TRY003)
3502-3504: Avoid specifying long messages outside the exception class
(TRY003)
3510-3512: Avoid specifying long messages outside the exception class
(TRY003)
3556-3559: Avoid specifying long messages outside the exception class
(TRY003)
tests/gemm/test_fp8_blockscale_gemm.py
1-1: The file is executable but no shebang is present
(EXE002)
425-425: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
454-454: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
🔇 Additional comments (1)
flashinfer/gemm/gemm_base.py (1)
3515-3524: Unable to locate the specified file (flashinfer/gemm/gemm_base.py) in the current repository state. The review comment cannot be fully verified without access to:
- The Python file at the specified location (lines 3515-3524)
- The C++ GEMM runner implementation (get_fp8_blockscale_gemm_runner_sm90)
- The workspace configuration logic in C++
Critical discrepancy: The learning context indicates that in flashinfer GEMM implementations, it's acceptable to catch and silently ignore runtime_error exceptions when probing configurations due to SMEM constraints—suggesting that zero-workspace or failed-configuration scenarios may legitimately occur and be handled in the C++ layer rather than requiring Python-side guards like
max(int(workspace_size), 1).
The review's core assumption (that unconditional configure_workspace calls with a minimum 1-byte allocation are necessary) conflicts with this documented pattern and requires verification against the actual C++ implementation behavior.
|
|
||
| #include <tvm/ffi/extra/module.h> | ||
|
|
||
| #include <cstddef> | ||
| #include <cstdint> | ||
| #include <functional> | ||
| #include <type_traits> | ||
| #include <vector> | ||
|
|
||
| #include "tensorrt_llm/common/cudaUtils.h" | ||
| #include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h" | ||
| #include "tvm_ffi_utils.h" | ||
|
|
Missing required standard headers (may fail to compile depending on transitive includes).
You use std::unique_ptr/std::make_unique (Lines 42-49, 228-237), std::max (Lines 209-215), and std::numeric_limits (Lines 192-193) without including <memory>, <algorithm>, <limits>. Please add them explicitly.
@@
#include <cstddef>
#include <cstdint>
+#include <algorithm>
#include <functional>
+#include <limits>
+#include <memory>
#include <type_traits>
#include <vector>
Also applies to: 38-50, 206-217
🤖 Prompt for AI Agents
In csrc/fp8_blockscale_gemm_sm90_binding.cu around lines 1 to 13, the file uses
std::unique_ptr/std::make_unique, std::max, and std::numeric_limits in later
code (refs: ~lines 42-49, 192-215) but does not explicitly include the required
standard headers; add #include <memory> for unique_ptr/make_unique, #include
<algorithm> for std::max, and #include <limits> for std::numeric_limits at the
top of this file (alongside the existing includes) so compilation does not rely
on transitive includes.
| void runGemm(const TensorView& input, const TensorView& weight, const TensorView& output, | ||
| const Optional<TensorView>& scales_a, const Optional<TensorView>& scales_b) { | ||
| auto stream = get_stream(input.device()); | ||
|
|
||
| auto input_ptr = input.data_ptr(); | ||
| auto weight_ptr = weight.data_ptr(); | ||
| auto output_ptr = output.data_ptr(); | ||
|
|
||
| int64_t shape_m = input.size(0); | ||
| int64_t shape_k = input.size(1); | ||
| int64_t shape_n = weight.size(0); | ||
|
|
||
| TVM_FFI_ICHECK(input_ptr != nullptr) << "input is null"; | ||
| TVM_FFI_ICHECK(weight_ptr != nullptr) << "weight is null"; | ||
| TVM_FFI_ICHECK(output_ptr != nullptr) << "output is null"; | ||
| TVM_FFI_ICHECK(shape_k == weight.size(1)) << "K dimension mismatch"; | ||
| TVM_FFI_ICHECK(shape_k % 16 == 0) << "N must be a multiple of 16, (K=" << shape_k << ")"; | ||
| TVM_FFI_ICHECK(shape_n % 16 == 0) << "N must be a multiple of 16, (N=" << shape_n << ")"; | ||
|
|
||
| // Determine dtypes for runner selection | ||
| bool input_is_fp8 = is_fp8_e4m3fn(input.dtype()); | ||
| bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype()); | ||
|
|
||
| // Validate scale requirements | ||
| if (input_is_fp8) { | ||
| TVM_FFI_ICHECK(scales_a.has_value() && scales_a.value().data_ptr() != nullptr) | ||
| << "scales_a is required for FP8 input"; | ||
| // TensorRT-LLM expects scale shape: (K/128, M) after transpose | ||
| // int64_t expected_scale_k = (shape_k + 127) / 128; | ||
| // TVM_FFI_ICHECK(scales_a.value().size(0) == expected_scale_k && | ||
| // scales_a.value().size(1) == shape_m) | ||
| // << "scales_a shape mismatch: expected (" << expected_scale_k << ", " << shape_m | ||
| // << "), got (" << scales_a.value().size(0) << ", " << scales_a.value().size(1) << ")"; | ||
| } | ||
|
|
||
| if (weight_is_fp8) { | ||
| TVM_FFI_ICHECK(scales_b.has_value() && scales_b.value().data_ptr() != nullptr) | ||
| << "scales_b is required for FP8 weight"; | ||
| // Validate scale shape: should be (N, K/128) for per-token or (N/128, K/128) for per-block | ||
| int64_t expected_scale_k = (shape_k + 127) / 128; | ||
| int64_t scale_dim0 = scales_b.value().size(0); | ||
| int64_t scale_dim1 = scales_b.value().size(1); | ||
|
|
||
| bool is_per_token = (scale_dim0 == shape_n && scale_dim1 == expected_scale_k); | ||
| bool is_per_block = (scale_dim0 == (shape_n + 127) / 128 && scale_dim1 == expected_scale_k); | ||
|
|
||
| TVM_FFI_ICHECK(is_per_token || is_per_block) | ||
| << "scales_b shape mismatch: expected (" << shape_n << ", " << expected_scale_k | ||
| << ") for per-token or (" << ((shape_n + 127) / 128) << ", " << expected_scale_k | ||
| << ") for per-block, got (" << scale_dim0 << ", " << scale_dim1 << ")"; | ||
| } | ||
|
|
||
| // Extract scale pointers | ||
| float const* scales_a_ptr = scales_a.has_value() | ||
| ? reinterpret_cast<float const*>(scales_a.value().data_ptr()) | ||
| : nullptr; | ||
| float const* scales_b_ptr = scales_b.has_value() | ||
| ? reinterpret_cast<float const*>(scales_b.value().data_ptr()) | ||
| : nullptr; | ||
|
|
||
| // Select appropriate runner | ||
| auto* runner = selectRunner(input_is_fp8, weight_is_fp8); | ||
| TVM_FFI_ICHECK(runner != nullptr) << "Unsupported dtype combination"; | ||
| TVM_FFI_ICHECK(workspace_ != nullptr) | ||
| << "Workspace not configured. Call configure_workspace first."; | ||
|
|
||
| // TensorRT-LLM has two gemm() methods: | ||
| // 1. gemm(void*, ...) - for internal quantization (BF16 inputs) | ||
| // 2. gemm(__nv_fp8_e4m3*, int, __nv_fp8_e4m3*, int, ...) - for pre-quantized FP8 inputs | ||
| if (input_is_fp8 && weight_is_fp8) { | ||
| // W8A8: Use the pre-quantized FP8 path | ||
| auto* fp8_input = reinterpret_cast<__nv_fp8_e4m3*>(input_ptr); | ||
| auto* fp8_weight = reinterpret_cast<__nv_fp8_e4m3*>(weight_ptr); | ||
| auto* bf16_output = reinterpret_cast<__nv_bfloat16*>(output_ptr); | ||
|
|
||
| runner->gemm(fp8_input, shape_k, // input with leading dimension | ||
| fp8_weight, shape_k, // weight with leading dimension | ||
| bf16_output, shape_n, // output with leading dimension | ||
| shape_m, shape_n, shape_k, scales_a_ptr, scales_b_ptr, stream); | ||
| } else { | ||
| // BF16+BF16 or BF16+FP8: Use internal quantization path | ||
| runner->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, stream, | ||
| scales_a_ptr, scales_b_ptr); | ||
| } | ||
| } |
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find and read the actual file
find . -name "fp8_blockscale_gemm_sm90_binding.cu" -type fRepository: flashinfer-ai/flashinfer
Length of output: 110
🏁 Script executed:
#!/bin/bash
# Check the file structure
fd "fp8_blockscale_gemm_sm90_binding.cu" --type fRepository: flashinfer-ai/flashinfer
Length of output: 108
🏁 Script executed:
#!/bin/bash
# Search for the file in common locations
find . -path "*/csrc/*" -name "*fp8*blockscale*" -type fRepository: flashinfer-ai/flashinfer
Length of output: 624
🏁 Script executed:
#!/bin/bash
# Broader search
git ls-files | grep -i "fp8.*blockscale\|blockscale.*binding"Repository: flashinfer-ai/flashinfer
Length of output: 913
🏁 Script executed:
#!/bin/bash
# List all files matching the pattern
git ls-files "**/*fp8_blockscale*" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 913
🏁 Script executed:
#!/bin/bash
# Find any references to this filename
rg "fp8_blockscale_gemm_sm90_binding" --type cuRepository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Find all .cu files in csrc
find ./csrc -name "*.cu" -type f 2>/dev/null | head -20Repository: flashinfer-ai/flashinfer
Length of output: 667
🏁 Script executed:
#!/bin/bash
# Search for blockscale in cu files
rg "blockscale" --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Check if csrc directory exists
ls -la csrc/ 2>/dev/null | head -20 || echo "csrc not found in root"
# Try alternative structure
find . -type d -name csrc 2>/dev/nullRepository: flashinfer-ai/flashinfer
Length of output: 1515
🏁 Script executed:
#!/bin/bash
# Search for similar error message patterns and K%16 checks
rg "N must be a multiple of 16" --type cuRepository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Search for shape_k % 16 checks
rg "shape_k.*%.*16" --type cu -B2 -A2Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Search for output dtype validation patterns
rg "output.*dtype\|dtype.*output" --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Search for bfloat16 or BF16 checks
rg "bfloat16|bf16" --type cu -i | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for TVM FFI utilities for dtype checking
find . -name "tvm_ffi_utils.h" -o -name "*tvm*utils*.h" | head -10Repository: flashinfer-ai/flashinfer
Length of output: 90
🏁 Script executed:
#!/bin/bash
# Find other GEMM binding files
find . -type f -name "*gemm*binding*" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 311
🏁 Script executed:
#!/bin/bash
# Search in the referenced header file from learnings
cat include/flashinfer/gemm/bf16_gemm_cutlass_template.h 2>/dev/null | head -50Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Search all files that validate output dtype
rg "encode_dlpack_dtype|dtype.*code\|TVM_FFI_ICHECK.*dtype" --type cu -A2 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for how dtype is validated in binding files
rg "dtype\(\)" --type cu -B2 -A2 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Find files using TVM_FFI_ICHECK
rg "TVM_FFI_ICHECK" --type cu | wc -lRepository: flashinfer-ai/flashinfer
Length of output: 113
🏁 Script executed:
#!/bin/bash
# Get actual usage of TVM_FFI_ICHECK with dtype
rg "TVM_FFI_ICHECK.*dtype" --type cu -B1 -A1 | head -50Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for blockscale references in Python files
rg "blockscale" --type py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 2030
🏁 Script executed:
#!/bin/bash
# Search for any fp8_blockscale references across all file types
rg "fp8_blockscale" -i | head -30Repository: flashinfer-ai/flashinfer
Length of output: 3152
🏁 Script executed:
#!/bin/bash
# Show top-level directory structure
ls -la | head -20Repository: flashinfer-ai/flashinfer
Length of output: 1257
🏁 Script executed:
#!/bin/bash
# Find all .cu files
find . -name "*.cu" -type f 2>/dev/null | head -30Repository: flashinfer-ai/flashinfer
Length of output: 970
🏁 Script executed:
#!/bin/bash
# Check git ls-files for cu files
git ls-files "*.cu" | grep -E "fp8|blockscale" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 782
🏁 Script executed:
#!/bin/bash
# All cu files in the repo
git ls-files "*.cu" | head -30Repository: flashinfer-ai/flashinfer
Length of output: 910
🏁 Script executed:
#!/bin/bash
# Look for any mentions of shape_k or shape_n dimension checks
rg "shape_k.*shape_n|K.*multiple.*16|N.*multiple.*16" --type cu -B1 -A1 | head -50Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Search for files with "binding" in the name
find . -name "*binding*" -type f 2>/dev/nullRepository: flashinfer-ai/flashinfer
Length of output: 1184
🏁 Script executed:
#!/bin/bash
# Get complete cu file list
git ls-files | grep "\.cu$"Repository: flashinfer-ai/flashinfer
Length of output: 6582
🏁 Script executed:
#!/bin/bash
# Search for files containing the actual code from the review
rg "N must be a multiple of 16"Repository: flashinfer-ai/flashinfer
Length of output: 337
🏁 Script executed:
#!/bin/bash
# Search for shape_k % 16 with context
rg "shape_k.*%.*16" -A1 -B1Repository: flashinfer-ai/flashinfer
Length of output: 451
🏁 Script executed:
#!/bin/bash
# Search for runner templates and output types
rg "runner.*gemm|W8A8.*output" --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for BF16 output validation
rg "__nv_bfloat16|output.*bf16" -i --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Verify the logic: check which dimension is checked with %16
cat > /tmp/verify_logic.py << 'EOF'
# The review states:
# Line 114 checks: shape_k % 16 == 0
# But message says: "N must be a multiple of 16, (K=" << shape_k << ")"
#
# This is indeed wrong - should be "K must be..."
# The check is on shape_k (the K dimension in matrix multiply)
# But the error message says "N"
print("Issue 1: Error message mismatch")
print("Code checks: shape_k % 16 == 0")
print("Message says: 'N must be a multiple of 16'")
print("Should say: 'K must be a multiple of 16'")
print("Status: CONFIRMED - this is a clear bug")
print()
print("Issue 2: Output BF16 validation")
print("Review claims runners are BF16-output only")
print("Code shows: auto* bf16_output = reinterpret_cast<__nv_bfloat16*>(output_ptr)")
print("This suggests BF16 output is assumed")
print("But no validation is done on output.dtype()")
print("If caller passes FP16 or other type, memory corruption could occur")
print("Status: NEEDS VERIFICATION - is this actually a safety issue?")
EOF
python /tmp/verify_logic.pyRepository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
#!/bin/bash
# Search for how other bindings handle output dtype validation
rg "output.*dtype\(\)" --type cu -B3 -A3 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for available dtype checking functions/macros
rg "encode_dlpack_dtype|is_fp8_e4m3fn|dtype.*code" -B1 -A1 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 4074
🏁 Script executed:
#!/bin/bash
# Check what dtype checking utilities exist
rg "bool.*dtype|is_.*dtype" --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Search for dtype validation patterns in bindings
rg "TVM_FFI_ICHECK.*dtype\|is_fp8_e4m3fn" --type cu -B2 -A2 | head -60Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look at is_fp8_e4m3fn usage
rg "is_fp8_e4m3fn" -B1 -A1Repository: flashinfer-ai/flashinfer
Length of output: 847
🏁 Script executed:
#!/bin/bash
# Search for bfloat16_code or similar dtype codes
rg "bfloat16_code|e4m3fn_code|dtype.*code" -B1 -A1 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 3614
🏁 Script executed:
#!/bin/bash
# search for how dtype is compared
rg "\.dtype()\|input\.dtype\(\)|weight\.dtype\(\)" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 236
🏁 Script executed:
#!/bin/bash
# Search for bfloat16 type checking
rg "is_bfloat16|bfloat16|bf16" -i --type cu | grep -E "is_|check|valid" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for how types are compared in existing code
rg "dtype().*==" --type cu -B1 -A1 | head -30Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Check all cu files for dtype comparison patterns
rg "encode_dlpack_dtype|\.code\(\)" --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for dtype constants or enums
rg "kDLFloat16|kDLBfloat" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 153
🏁 Script executed:
#!/bin/bash
# Get a comprehensive view of all dtype-related validations
rg "shape_[a-z] % 16\|N must be\|K must be" -B2 -A2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Verify the logic of the error message issue
python3 << 'EOF'
# Issue 1: Error Message on Line 114
# The code checks: TVM_FFI_ICHECK(shape_k % 16 == 0) << "N must be..."
# But it's checking shape_k (K dimension)
# Message says "N must be a multiple of 16"
# This is clearly wrong - message should say "K must be..."
# This is a straightforward factual error in the message
# Issue 2: Output dtype validation
# The code assumes output is BF16 (reinterpret_cast<__nv_bfloat16*>)
# But there's no validation that output.dtype() is actually BF16
# If someone passes FP16 or FP32, undefined behavior/memory corruption
# Both issues appear valid in principle
# Issue 1 is definitely a bug (wrong error message)
# Issue 2 is a type safety concern (missing validation)
print("Issue 1 (Error message): VALID BUG")
print("Issue 2 (Type validation): VALID CONCERN if runners are truly BF16-only")
EOFRepository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
#!/bin/bash
# Check for dtype validation in similar files
rg "is_fp8_e4m3fn" --type cu -B5 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for pattern: how to check if output is BF16
rg "bfloat16|bf16" --type cu -i -B2 -A2 | grep -E "TVM_FFI_ICHECK|if.*dtype|is_" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Search for any dtype comparison that might give us a pattern
rg "dtype\(\).*==" --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Check for type_code or similar
rg "type_code|code\(\)" --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 148
🏁 Script executed:
#!/bin/bash
# Find function definitions for dtype checking
rg "^bool is_|^bool.*dtype" --type cu | head -20Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for is_fp8_e4m3fn definition or usage context
rg "is_fp8_e4m3fn" -B3 -A3 | head -50Repository: flashinfer-ai/flashinfer
Length of output: 1403
🏁 Script executed:
#!/bin/bash
# Final verification: Check if there's any existing dtype validation for output
rg "output\)\|TVM_FFI_ICHECK.*output" --type cu | grep dtype | head -10Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
#!/bin/bash
# Verify the K% 16 error message one more time with exact pattern
rg 'shape_k.*%.*16.*0.*<<.*"N must' --type cuRepository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
head -n 200 csrc/fp8_blockscale_gemm_sm90_binding.cuRepository: flashinfer-ai/flashinfer
Length of output: 8648
🏁 Script executed:
sed -n '1,50p' csrc/fp8_blockscale_gemm_sm90_binding.cuRepository: flashinfer-ai/flashinfer
Length of output: 1528
🏁 Script executed:
sed -n '110,125p' csrc/fp8_blockscale_gemm_sm90_binding.cuRepository: flashinfer-ai/flashinfer
Length of output: 937
🏁 Script executed:
cat csrc/tvm_ffi_utils.h | head -100Repository: flashinfer-ai/flashinfer
Length of output: 4489
🏁 Script executed:
sed -n '98,182p' csrc/fp8_blockscale_gemm_sm90_binding.cuRepository: flashinfer-ai/flashinfer
Length of output: 4471
🏁 Script executed:
grep -n "encode_dlpack_dtype\|kDLBfloat\|bfloat16_code" csrc/fp8_blockscale_gemm_sm90_binding.cuRepository: flashinfer-ai/flashinfer
Length of output: 129
🏁 Script executed:
grep -n "encode_dlpack_dtype\|DLDataType\|kDL" csrc/tvm_ffi_utils.h | head -30Repository: flashinfer-ai/flashinfer
Length of output: 2158
Fix error message on line 114 and add output BF16 dtype validation.
Line 114 checks shape_k % 16 == 0 (the K dimension) but the error message incorrectly says "N must be a multiple of 16"—it should say "K must be a multiple of 16" to match what is being checked.
Additionally, the code casts output_ptr to __nv_bfloat16* without validating that output.dtype() is actually BF16. If a caller passes FP16 or another type, this will cause memory corruption. Use the bfloat16_code constant already available in tvm_ffi_utils.h (and the encode_dlpack_dtype() pattern already used in this file) to validate the output dtype.
- TVM_FFI_ICHECK(shape_k % 16 == 0) << "N must be a multiple of 16, (K=" << shape_k << ")";
+ TVM_FFI_ICHECK(shape_k % 16 == 0) << "K must be a multiple of 16 (K=" << shape_k << ")";
TVM_FFI_ICHECK(shape_n % 16 == 0) << "N must be a multiple of 16, (N=" << shape_n << ")";
+
+ // Output must be BF16 for these runner instantiations.
+ TVM_FFI_ICHECK(encode_dlpack_dtype(output.dtype()) == bfloat16_code)
+ << "output must be BF16 (bfloat16) for fp8_blockscale_gemm_sm90";🤖 Prompt for AI Agents
In csrc/fp8_blockscale_gemm_sm90_binding.cu around lines 98 to 182, change the
incorrect error text on the shape_k modulus check to mention "K must be a
multiple of 16" instead of "N", and add a defensive dtype check before casting
output_ptr to __nv_bfloat16*: use the existing
encode_dlpack_dtype(output.dtype()) comparison against the bfloat16_code
constant (from tvm_ffi_utils.h) and assert with TVM_FFI_ICHECK that the output
dtype is BF16; perform this validation prior to any reinterpret_cast to
__nv_bfloat16* so we avoid potential memory corruption.
|
|
||
| # Validate dtype combinations | ||
| input_is_fp8 = input.dtype == torch.float8_e4m3fn | ||
| weight_is_fp8 = weight.dtype == torch.float8_e4m3fn | ||
| input_is_bf16 = input.dtype == torch.bfloat16 | ||
| weight_is_bf16 = weight.dtype == torch.bfloat16 | ||
|
|
||
| # Explicitly reject FP8 input + BF16 weight (missing kernel implementation) | ||
| if input_is_fp8 and weight_is_bf16: | ||
| raise ValueError( | ||
| "FP8 input + BF16 weight is not supported (missing kernel implementation). " | ||
| ) | ||
|
|
||
| # Validate scale requirements for FP8 inputs | ||
| if input_is_fp8: | ||
| if input_scale is None: | ||
| raise ValueError("input_scale is required when input is FP8. ") | ||
| # Users provide input_scale in shape (M, K//128), matching per_token_cast_to_fp8 output. | ||
| # We transpose it internally to (K//128, M) to match TensorRT-LLM kernel expectations. | ||
| expected_scale_shape = (M, K // BLOCK_SIZE) | ||
| if input_scale.shape != expected_scale_shape: | ||
| raise ValueError( | ||
| f"input_scale shape mismatch. Expected {expected_scale_shape}, " | ||
| f"got {input_scale.shape}" | ||
| ) | ||
| if input_scale.dtype != torch.float32: | ||
| raise ValueError(f"input_scale must be float32, got {input_scale.dtype}") | ||
| if input_scale.device != input.device: | ||
| raise ValueError( | ||
| f"input_scale device mismatch. Expected {input.device}, " | ||
| f"got {input_scale.device}" | ||
| ) | ||
| else: | ||
| if not input_is_bf16: | ||
| raise ValueError( | ||
| f"Input must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " | ||
| f"got {input.dtype}" | ||
| ) | ||
| if input_scale is not None: | ||
| raise ValueError( | ||
| "input_scale should not be provided for BF16 inputs. " | ||
| "Use FP8 inputs if you want to provide external scales." | ||
| ) | ||
|
|
||
| if weight_is_fp8: | ||
| if weight_scale is None: | ||
| raise ValueError("weight_scale is required when weight is FP8. ") | ||
| expected_per_token_shape = (N, K // BLOCK_SIZE) | ||
| expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE) | ||
| is_per_token = weight_scale.shape == expected_per_token_shape | ||
| is_per_block = weight_scale.shape == expected_per_block_shape | ||
|
|
||
| if not (is_per_token or is_per_block): | ||
| raise ValueError( | ||
| f"weight_scale shape mismatch. Expected either {expected_per_token_shape} " | ||
| f"(per-token, 1x128 blocks) or {expected_per_block_shape} " | ||
| f"(per-block, 128x128 blocks), got {weight_scale.shape}" | ||
| ) | ||
| if weight_scale.dtype != torch.float32: | ||
| raise ValueError(f"weight_scale must be float32, got {weight_scale.dtype}") | ||
| else: | ||
| if not weight_is_bf16: | ||
| raise ValueError( | ||
| f"Weight must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " | ||
| f"got {weight.dtype}" | ||
| ) | ||
| if weight_scale is not None: | ||
| raise ValueError( | ||
| "weight_scale should not be provided for BF16 weights. " | ||
| "Use FP8 weights if you want to provide external scales." | ||
| ) | ||
|
|
||
| # Validate output tensor if provided | ||
| if out is not None: | ||
| if out.shape != (M, N): | ||
| raise ValueError( | ||
| f"Output shape mismatch. Expected ({M}, {N}), got {out.shape}" | ||
| ) | ||
| if out.device != input.device: | ||
| raise ValueError( | ||
| f"Output device mismatch. Expected {input.device}, got {out.device}" | ||
| ) | ||
| if out.dtype not in [torch.bfloat16, torch.float16]: | ||
| raise ValueError( | ||
| f"Output dtype must be torch.bfloat16 or torch.float16, got {out.dtype}" | ||
| ) | ||
| if out_dtype is not None and out.dtype != out_dtype: | ||
| raise ValueError( | ||
| f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}" | ||
| ) | ||
| out_dtype = out.dtype | ||
| else: | ||
| # Allocate output | ||
| out_dtype = out_dtype or torch.bfloat16 | ||
| if out_dtype not in [torch.bfloat16, torch.float16]: | ||
| raise ValueError( | ||
| f"Output dtype must be torch.bfloat16 or torch.float16, got {out_dtype}" | ||
| ) | ||
| out = torch.empty(M, N, dtype=out_dtype, device=input.device) | ||
|
|
Output dtype contract is inconsistent with the C++ binding (likely BF16-only).
The C++ runner instantiations use __nv_bfloat16 output and the binding comment says “Output is always BF16”, but Python accepts/allocates FP16 outputs (Lines 3497-3513). That can lead to writing BF16 bits into an FP16 tensor (silent corruption).
@@
- if out is not None:
+ if out is not None:
@@
- if out.dtype not in [torch.bfloat16, torch.float16]:
+ if out.dtype not in [torch.bfloat16]:
raise ValueError(
- f"Output dtype must be torch.bfloat16 or torch.float16, got {out.dtype}"
+ f"Output dtype must be torch.bfloat16, got {out.dtype}"
)
@@
- else:
+ else:
# Allocate output
out_dtype = out_dtype or torch.bfloat16
- if out_dtype not in [torch.bfloat16, torch.float16]:
+ if out_dtype not in [torch.bfloat16]:
raise ValueError(
- f"Output dtype must be torch.bfloat16 or torch.float16, got {out_dtype}"
+ f"Output dtype must be torch.bfloat16, got {out_dtype}"
)If FP16 output is actually desired, the fix should be on the C++ side (add an FP16-output runner + dtype dispatch) rather than allowing it only in Python.
🧰 Tools
🪛 Ruff (0.14.8)
3391-3393: Avoid specifying long messages outside the exception class
(TRY003)
3397-3397: Avoid specifying long messages outside the exception class
(TRY003)
3399-3399: Avoid specifying long messages outside the exception class
(TRY003)
3405-3407: Avoid specifying long messages outside the exception class
(TRY003)
3412-3414: Avoid specifying long messages outside the exception class
(TRY003)
3424-3426: Avoid specifying long messages outside the exception class
(TRY003)
3431-3431: Avoid specifying long messages outside the exception class
(TRY003)
3436-3439: Avoid specifying long messages outside the exception class
(TRY003)
3441-3441: Avoid specifying long messages outside the exception class
(TRY003)
3443-3446: Avoid specifying long messages outside the exception class
(TRY003)
3449-3452: Avoid specifying long messages outside the exception class
(TRY003)
3454-3457: Avoid specifying long messages outside the exception class
(TRY003)
3461-3461: Avoid specifying long messages outside the exception class
(TRY003)
3468-3472: Avoid specifying long messages outside the exception class
(TRY003)
3474-3474: Avoid specifying long messages outside the exception class
(TRY003)
3477-3480: Avoid specifying long messages outside the exception class
(TRY003)
3482-3485: Avoid specifying long messages outside the exception class
(TRY003)
3490-3492: Avoid specifying long messages outside the exception class
(TRY003)
3494-3496: Avoid specifying long messages outside the exception class
(TRY003)
3498-3500: Avoid specifying long messages outside the exception class
(TRY003)
3502-3504: Avoid specifying long messages outside the exception class
(TRY003)
3510-3512: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
flashinfer/gemm/gemm_base.py around lines 3497-3513: the Python API currently
accepts/allocates torch.float16 outputs even though the C++ runner/binding is
BF16-only, risking silent corruption. Fix by restricting outputs to
torch.bfloat16: validate that provided out.dtype (and out_dtype when supplied)
is torch.bfloat16 and raise an error otherwise; when allocating out, always
allocate with dtype=torch.bfloat16 (ignore or reject requests for float16). If
FP16 outputs are actually required, implement the FP16 runner and dtype dispatch
in the C++ bindings instead of allowing FP16 only in Python.
| @pytest.fixture( | ||
| autouse=not has_flashinfer_jit_cache(), | ||
| scope="module", | ||
| ) | ||
| def warmup_jit(): | ||
| """Warm up JIT compilation for FP8 block-scale GEMM if not cached.""" | ||
| if is_sm90a_supported(torch.device("cuda:0")): | ||
| jit_specs = [gen_fp8_blockscale_gemm_sm90_module()] | ||
| flashinfer.jit.build_jit_specs(jit_specs, verbose=False) | ||
| yield | ||
|
|
Warmup fixture should guard CUDA availability to avoid import-time failures.
Because it’s autouse, it may execute before individual tests can skip(). Add a torch.cuda.is_available() guard (or skip) before touching cuda:0.
@@
def warmup_jit():
"""Warm up JIT compilation for FP8 block-scale GEMM if not cached."""
+ if not torch.cuda.is_available():
+ yield
+ return
if is_sm90a_supported(torch.device("cuda:0")):
jit_specs = [gen_fp8_blockscale_gemm_sm90_module()]
flashinfer.jit.build_jit_specs(jit_specs, verbose=False)
yield📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @pytest.fixture( | |
| autouse=not has_flashinfer_jit_cache(), | |
| scope="module", | |
| ) | |
| def warmup_jit(): | |
| """Warm up JIT compilation for FP8 block-scale GEMM if not cached.""" | |
| if is_sm90a_supported(torch.device("cuda:0")): | |
| jit_specs = [gen_fp8_blockscale_gemm_sm90_module()] | |
| flashinfer.jit.build_jit_specs(jit_specs, verbose=False) | |
| yield | |
| @pytest.fixture( | |
| autouse=not has_flashinfer_jit_cache(), | |
| scope="module", | |
| ) | |
| def warmup_jit(): | |
| """Warm up JIT compilation for FP8 block-scale GEMM if not cached.""" | |
| if not torch.cuda.is_available(): | |
| yield | |
| return | |
| if is_sm90a_supported(torch.device("cuda:0")): | |
| jit_specs = [gen_fp8_blockscale_gemm_sm90_module()] | |
| flashinfer.jit.build_jit_specs(jit_specs, verbose=False) | |
| yield |
🤖 Prompt for AI Agents
In tests/gemm/test_fp8_blockscale_gemm.py around lines 48 to 58, the autouse
fixture warmup_jit can touch torch.device("cuda:0") during import even when CUDA
is unavailable; update the fixture to first check torch.cuda.is_available() (or
call pytest.skip) and return/skip early if CUDA is not available, then proceed
to call is_sm90a_supported and build the JIT specs only when CUDA is present to
avoid import-time failures.
| def test_fp8_blockscale_gemm_error_handling(): | ||
| """Test that proper errors are raised for invalid inputs.""" | ||
| compute_capability = get_compute_capability(torch.device("cuda")) | ||
| if compute_capability[0] < 9: | ||
| pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") | ||
|
|
||
| if not is_sm90a_supported(torch.device("cuda")): | ||
| pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") | ||
|
|
||
| device = "cuda" | ||
| m, n, k = 16, 256, 256 | ||
|
|
||
| # Test: K not divisible by 128 | ||
| input = torch.randn(m, 127, device=device, dtype=torch.bfloat16) | ||
| weight = torch.randn(n, 127, device=device, dtype=torch.bfloat16) | ||
| with pytest.raises(ValueError, match="divisible by block size"): | ||
| fp8_blockscale_gemm_sm90(input, weight) | ||
|
|
||
| # Test: FP16 not supported | ||
| input = torch.randn(m, k, device=device, dtype=torch.float16) | ||
| weight = torch.randn(n, k, device=device, dtype=torch.float16) | ||
| with pytest.raises(ValueError, match="FP8.*or BF16"): | ||
| fp8_blockscale_gemm_sm90(input, weight) | ||
|
|
||
| # Test: FP8 weight without scale (naive conversion) | ||
| input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) | ||
| weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) | ||
| weight_fp8_naive = weight_bf16.to(torch.float8_e4m3fn) | ||
| with pytest.raises(ValueError, match="weight_scale is required when weight is FP8"): | ||
| fp8_blockscale_gemm_sm90(input_bf16, weight_fp8_naive, None, None) | ||
|
|
||
| # Test: BF16 input with scale (should raise error) | ||
| input = torch.randn(m, k, device=device, dtype=torch.bfloat16) | ||
| weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) | ||
| fake_scale = torch.ones(m, k // 128, device=device, dtype=torch.float32) | ||
| with pytest.raises(ValueError, match="input_scale should not be provided for BF16"): | ||
| fp8_blockscale_gemm_sm90(input, weight, input_scale=fake_scale) | ||
|
|
||
| # Test: Wrong scale shape for FP8 input | ||
| input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) | ||
| input_fp8, _ = per_token_cast_to_fp8(input_bf16) | ||
| weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) | ||
| wrong_scale = torch.ones(m, k // 64, device=device, dtype=torch.float32) | ||
| with pytest.raises(ValueError): | ||
| fp8_blockscale_gemm_sm90(input_fp8, weight, input_scale=wrong_scale) | ||
|
|
||
| # Test: FP8 input + BF16 weight is NOT supported | ||
| input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) | ||
| input_fp8, input_scale = per_token_cast_to_fp8(input_bf16) | ||
| weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) | ||
| with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"): | ||
| fp8_blockscale_gemm_sm90(input_fp8, weight, input_scale, None) | ||
|
|
Use raw strings / escape regex metacharacters in pytest.raises(..., match=...).
Ruff’s RUF043 is valid here: patterns like "FP8.*or BF16" should be raw strings to avoid unintended escapes / regex surprises.
@@
- with pytest.raises(ValueError, match="FP8.*or BF16"):
+ with pytest.raises(ValueError, match=r"FP8.*or BF16"):
fp8_blockscale_gemm_sm90(input, weight)
@@
- with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"):
+ with pytest.raises(ValueError, match=r"FP8 input.*BF16 weight.*not supported"):
fp8_blockscale_gemm_sm90(input_fp8, weight, input_scale, None)🧰 Tools
🪛 Ruff (0.14.8)
425-425: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
454-454: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
🤖 Prompt for AI Agents
In tests/gemm/test_fp8_blockscale_gemm.py around lines 404 to 456, the regex
strings passed to pytest.raises(..., match=...) are regular Python strings and
should be raw strings to avoid accidental escapes or unintended regex behavior;
update the match arguments to use raw string literals (prefix with r'...') for
patterns such as "FP8.*or BF16" and "FP8 input.*BF16 weight.*not supported" (and
any other match patterns that include backslashes or regex metacharacters), or
alternatively escape literal regex metacharacters if you intend a literal match.
/bot run
[CANCELING] Pipeline #40112860: canceled
| << "scales_a is required for FP8 input"; | ||
| // TensorRT-LLM expects scale shape: (K/128, M) after transpose | ||
| // int64_t expected_scale_k = (shape_k + 127) / 128; | ||
| // TVM_FFI_ICHECK(scales_a.value().size(0) == expected_scale_k && |
Why do we comment these out?
Hi Zihao, this was added before porting the quant kernel but then relaxed because TRT-LLM internally expects a 1-D tensor. Thanks!
| Returns diff = 1 - sim, where sim = 2*<x,y> / (||x||² + ||y||²) | ||
| This is similar to cosine similarity but uses squared norms in denominator. | ||
|
|
||
| diff < 0.001 corresponds to >99.9% similarity. |
Is this mechanism designed to be aligned with trtllm?
Oops, I think this was added by a previous developer but was never used.
I changed the criterion to cosine similarity, the same as VLLM, and passed all tests with 99.9% similarity (both VLLM and FlashInfer).
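For context, the vLLM-style check mentioned here is along these lines (a sketch with an assumed helper name and threshold, not the exact test code):

```python
import torch
import torch.nn.functional as F

def assert_cosine_close(out: torch.Tensor, ref: torch.Tensor, threshold: float = 0.999) -> None:
    # Flatten both tensors and compare direction; 0.999 matches the 99.9% figure above.
    sim = F.cosine_similarity(out.float().flatten(), ref.float().flatten(), dim=0)
    assert sim > threshold, f"cosine similarity {sim.item():.4f} below {threshold}"
```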
| @@ -1163,10 +1163,13 @@ __global__ void convert_kernel(OutputType* output, InputType const* const input, | |||
| } | |||
|
|
|||
| static int kNumDeviceSMs = -1; | |||
| static bool kDeepGemmEnabled = []() -> bool { | |||
| // a function that returns kDeepGemmEnabled | |||
| bool getDeepGemmEnabled() { | |||
Thanks for making this change, it makes sense to me.
For more background, @katec846 found that we would silently skip DeepGEMM after #2090 because kDeepGemmEnabled returned false; in the future I think we should surface this information in the logging (e.g., whether DeepGEMM is enabled or not).
cc @bkryu
5d74d09 to 5e5d685
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
5e5d685 to 72c4f7b
72c4f7b to a4ae22c
Hi @yzh119, thanks for reviewing. This PR is considered ready now. I found and noted in vllm-project/vllm#29213 (comment) that with my implementation DS-R1 has an accuracy degradation. I've cross-referenced this PR with TRTLLM's implementation, as well as VLLM's FP8 Linear op unit test, and there seemed to be no error at all in this FlashInfer PR. There's some part of DS-R1 that's having an accuracy diff in VLLM. Thank you again!!
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
a4ae22c to c2e0540
/bot run
yzh119 left a comment
Thanks for your contribution! The failed UTs are not relevant and let's merge it now.
📌 Description
This PR is adding W8A8 on top of PR2101
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- [x] I have installed the hooks with pre-commit install.
- [x] I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
- FP8 block-scale GEMM for SM90 (Hopper): supports BF16/FP8 input combinations, per-token/per-block scales, runtime-dtype dispatch, workspace configuration, and robust input/output validation.
Chores
- Exposed SM90 FP8 GEMM runner in package APIs and added a JIT module generator to build/load the SM90 kernel.
Tests
- Added comprehensive tests for shapes, dtypes, scales, preallocated outputs, correctness thresholds, hardware guards, and error handling.
Refactor
- DeepGemm enablement now evaluated at runtime for dynamic behavior.