Conversation
No actionable comments were generated in the recent review. Files selected for processing: 1.

Walkthrough: Version number incremented from 0.6.9 to 0.6.10. Changes: version bump.

Estimated code review effort: 1 (Trivial), ~1 minute.
Pre-merge checks: 5 passed.
Code Review
This pull request updates the version from 0.6.9 to 0.6.10. Feedback suggests that a patch version bump is inappropriate because the PR contains breaking changes; a minor version bump or the use of deprecation aliases is recommended to adhere to semantic versioning.
```diff
@@ -1 +1 @@
-0.6.9
+0.6.10
```
The pull request description mentions a breaking change (renaming Sm120BlockScaledDenseGemmKernel to Sm120B12xBlockScaledDenseGemmKernel and removing the old name). According to semantic versioning principles, a patch release (0.6.9 to 0.6.10) should not introduce breaking changes, as this can disrupt downstream users who rely on patch updates for bug fixes without API breakage. Consider either adding a deprecation alias for the old class name in flashinfer/gemm/__init__.py to maintain backward compatibility, or bumping the version to a minor version (e.g., 0.7.0) instead of a patch version.
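If the alias route is taken, one possible shape for it in `flashinfer/gemm/__init__.py` is a lazy module-level alias. This is a hypothetical sketch, not the repository's code: the PEP 562 `__getattr__` approach and the placement are assumptions, and the current file additionally guards this import behind an `is_cute_dsl_available()` check.

```python
# Hypothetical sketch of a lazy deprecation alias in flashinfer/gemm/__init__.py.
import warnings

from .kernels.dense_blockscaled_gemm_sm120_b12x import (
    Sm120B12xBlockScaledDenseGemmKernel,
)


def __getattr__(name: str):
    # Keep the old public name importable, but warn so callers migrate.
    if name == "Sm120BlockScaledDenseGemmKernel":
        warnings.warn(
            "Sm120BlockScaledDenseGemmKernel is deprecated; use "
            "Sm120B12xBlockScaledDenseGemmKernel instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return Sm120B12xBlockScaledDenseGemmKernel
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```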
Code Review: v0.6.10 Version Bump

Overall: This is a clean release-aggregation PR; the only direct diff is the one-line bump of the version file.

Issues / Items Requiring Decision

1. Potentially Breaking Module Rename (flagged in PR body, but needs resolution before release)

The PR body correctly flags this, but it is worth escalating. Since only the new name is re-exported in `flashinfer/gemm/__init__.py`, direct importers of the old name will break.

Recommendation: Add a deprecation alias in the old location for at least one minor version, or explicitly document this as a breaking change in the release notes:

```python
# dense_blockscaled_gemm_sm120.py (kept for compat)
import warnings

from .dense_blockscaled_gemm_sm120_b12x import (
    Sm120B12xBlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
)

warnings.warn(
    "Import from dense_blockscaled_gemm_sm120 is deprecated, "
    "use dense_blockscaled_gemm_sm120_b12x",
    DeprecationWarning,
)
```

2. Internal Bucket Strategy Rename May Affect Downstream Extensions
Positive Observations
Description
Bump version to 0.6.10 for release.
Related Issues (Gated-by PRs)
https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.10
Reviewer Notes
API changes review
API changes since v0.6.9
$ git diff v0.6.9..main -- "*.py" | grep -B5 -A20 "@flashinfer_api" register_custom_op, @@ -67,7 +73,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: ) -@flashinfer_api +@flashinfer_api(trace=silu_and_mul_trace) def silu_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: @@ -112,7 +118,7 @@ def silu_and_mul( return out -@flashinfer_api +@flashinfer_api(trace=gelu_tanh_and_mul_trace) def gelu_tanh_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: @@ -153,7 +159,7 @@ def gelu_tanh_and_mul( return out -@flashinfer_api +@flashinfer_api(trace=gelu_and_mul_trace) def gelu_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: @@ -194,7 +200,7 @@ def gelu_and_mul( return out -@flashinfer_api +@flashinfer_api(trace=silu_and_mul_scaled_nvfp4_experts_quantize_trace) def silu_and_mul_scaled_nvfp4_experts_quantize( a, mask, diff --git a/flashinfer/aot.py b/flashinfer/aot.py index dfb05150..d26d5407 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -543,6 +543,7 @@ def gen_all_modules( if add_comm: from .jit.comm import ( gen_comm_alltoall_module, + gen_dcp_alltoall_module, gen_moe_alltoall_module, gen_trtllm_comm_module, gen_trtllm_mnnvl_comm_module, @@ -554,6 +555,11 @@ def gen_all_modules( jit_specs.append(gen_trtllm_comm_module()) jit_specs.append(gen_trtllm_mnnvl_comm_module()) jit_specs.append(gen_moe_alltoall_module()) + # dcp_alltoall: kernel itself supports SM90+, but ptxas 12.6.0 has -- -def flashinfer_api(func: Callable = None) -> Callable: +# --------------------------------------------------------------------------- +# Trace template registry +# --------------------------------------------------------------------------- +# Populated automatically by _attach_fi_trace whenever @flashinfer_api is +# given a trace= argument. Each entry is (original_func, template, label) +# where label is the template's name_prefix (or op_type as fallback). +# +# For dispatch callables (trace=some_fn), every template listed in +# some_fn.templates is registered if that attribute exists. +# +# Read by tests/trace/test_fi_trace_template_consistency.py to auto-discover +# all registered templates without requiring manual maintenance. +_TRACE_REGISTRY: List[Tuple[Callable, Any, str]] = [] + + +def _attach_fi_trace( + wrapped: Callable, + original: Callable, + trace_template=None, +) -> Callable: + """Attach a ``fi_trace`` callable to *wrapped*. + + Three resolution strategies, tried in order: + -- + + warnings.warn( + f"[flashinfer] Failed to attach fi_trace to '{_func_name}': " + f"{type(_exc).__name__}: {_exc}\n" + f"The function will work normally but fi_trace will be unavailable. " + f"Fix the TraceTemplate passed to @flashinfer_api(trace=...).", + stacklevel=3, + ) + return wrapped + + +def flashinfer_api(func: Callable = None, *, trace=None) -> Callable: """ Decorator to FlashInfer's APIs. @@ -1489,11 +1644,12 @@ def flashinfer_api(func: Callable = None) -> Callable: - The %i pattern is automatically replaced with the process ID for multi-process environments. - The logger does not propagate to the root logger to avoid duplicate logs. """ - # If logging is disabled, return original function with zero overhead + # If logging is disabled, return original function with zero overhead. + # We still attach fi_trace so it is always available regardless of log level. 
if _API_LOG_LEVEL == 0: if func is None: - return lambda f: f - return func -- @functools.cache @@ -135,7 +136,7 @@ class BatchAttention: causal, ) - @flashinfer_api + @flashinfer_api(trace=batch_attention_run_trace) def run( self, q: torch.Tensor, @@ -209,6 +210,8 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper a convenient interface for using attention sinks during prefill or decode attention. """ + # No @flashinfer_api here: parent class BatchPrefillWithPagedKVCacheWrapper + # already decorates __init__, so decorating again produces double log entries. def __init__( self, float_workspace_buffer: torch.Tensor, diff --git a/flashinfer/attention/cute_dsl/__init__.py b/flashinfer/attention/cute_dsl/__init__.py new file mode 100644 index 00000000..3e029627 --- /dev/null +++ b/flashinfer/attention/cute_dsl/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, -- @@ -31,7 +37,7 @@ def get_cascade_module(): return gen_cascade_module().build_and_load() -@flashinfer_api +@flashinfer_api(trace=merge_state_trace) @register_custom_op("flashinfer::merge_state", mutates_args=()) def merge_state( v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor @@ -98,7 +104,7 @@ def _fake_merge_state( return v, s -@flashinfer_api +@flashinfer_api(trace=merge_state_in_place_trace) @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s")) def merge_state_in_place( v: torch.Tensor, @@ -159,7 +165,7 @@ def _fake_merge_state_in_place( pass -@flashinfer_api +@flashinfer_api(trace=merge_states_trace) @register_custom_op("flashinfer::merge_states", mutates_args=()) def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: r"""Merge multiple attention states (v, s). 
@@ -512,7 +518,7 @@ class MultiLevelCascadeAttentionWrapper: begin_forward = plan - @flashinfer_api + @flashinfer_api(trace=multi_level_cascade_run_trace) def run( self, q: torch.Tensor, diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 5f186002..31d23a99 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -65,4 +65,15 @@ from .trtllm_moe_alltoall import ( moe_a2a_wrap_payload_tensor_in_workspace as moe_a2a_wrap_payload_tensor_in_workspace, ) +# DCP A2A (Decode Context Parallel Attention Reduction) +from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall +from .dcp_alltoall import ( + decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace, +) +from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace +from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size + # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo -- from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion @@ -449,7 +450,7 @@ def create_allreduce_fusion_workspace( # ============================================================================ -@flashinfer_api +@flashinfer_api(trace=allreduce_fusion_trace) def allreduce_fusion( input: torch.Tensor, workspace: AllReduceFusionWorkspace, diff --git a/flashinfer/comm/dcp_alltoall.py b/flashinfer/comm/dcp_alltoall.py new file mode 100644 index 00000000..3047f76c --- /dev/null +++ b/flashinfer/comm/dcp_alltoall.py @@ -0,0 +1,255 @@ +""" +DCP All-to-All Operations for DCP Attention Reduction + +Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel +attention reduction. Uses SM90+ features (TMA, mbarrier). + +Usage protocol:: + + # 1. Query workspace size + ws_bytes = decode_cp_a2a_workspace_size(cp_size) + -- + + +# ─── Public API ─────────────────────────────────────────────────────────── + + +@flashinfer_api +def decode_cp_a2a_workspace_size(cp_size: int) -> int: + """Return the workspace size **in bytes** per rank for the given CP group size. + + Args: + cp_size: Context-parallel group size (number of ranks). + + Returns: + Workspace size in bytes per rank. + + Example:: + + >>> decode_cp_a2a_workspace_size(4) + 16778240 + """ + return get_dcp_alltoall_module().get_workspace_size_per_rank(cp_size) + + +@flashinfer_api +def decode_cp_a2a_allocate_workspace( + cp_size: int, + cp_rank: int, + *, + mapping: Optional[Mapping] = None, + mnnvl_config: Optional[MnnvlConfig] = None, +) -> torch.Tensor: + """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``. + + After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a + cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call. + + Two allocation modes: + + - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via + FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks + cannot see each other's device memory directly. + - **Plain device memory** (``mapping=None``): Standard ``torch.zeros`` + allocation. Sufficient for single-node with NVLink P2P. + -- + + ws_elems_per_rank = (ws_bytes + 7) // 8 + return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda") + + +@flashinfer_api +def decode_cp_a2a_init_workspace( + workspace: torch.Tensor, + cp_rank: int, + cp_size: int, +) -> None: + """Initialize the workspace FIFO buffers. Call once before the first alltoall. + + Resets the FIFO buffers in the **local** workspace row + (``workspace[cp_rank]``). 
This function is **synchronous**: when it + returns, the GPU memset is guaranteed to have completed. + + .. important:: + With MNNVL workspaces, **all ranks** must complete + ``decode_cp_a2a_init_workspace`` and execute a cross-rank barrier + (e.g. ``dist.barrier(group)``) before **any** rank calls + :func:`decode_cp_a2a_alltoall`. Without the barrier, a rank may + start writing to a peer's FIFO before that peer has finished + initializing → deadlock. + + Args: -- + # subsequent cross-GPU alltoall can race with the unfinished memset + # on MNNVL memory, causing a deadlock. + torch.cuda.current_stream().synchronize() + + +@flashinfer_api(trace=decode_cp_a2a_alltoall_trace) +def decode_cp_a2a_alltoall( + partial_o: torch.Tensor, + softmax_stats: torch.Tensor, + workspace: torch.Tensor, + cp_rank: int, + cp_size: int, + enable_pdl: Optional[bool] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Perform the DCP all-to-all exchange. + + Each rank sends its ``partial_o[..., peer, :]`` slice to the + corresponding peer and receives all peers' contributions into the + output tensors. + + Args: + partial_o: ``[..., cp_size, D]`` — half or bfloat16. + ``D * element_size`` must be 16-byte aligned. + softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even. + Batch dimensions must match ``partial_o``. + workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from -- + MixedCommOp.ALLREDUCE_ALLGATHER: _allreduce_allgather, + MixedCommOp.REDUCESCATTER_ALLREDUCE: _reducescatter_allreduce, +} + + +@flashinfer_api +@backend_requirement( + backend_checks={}, + common_check=_common_check, +) +def run_mixed_comm( + op: MixedCommOp, + handler: MixedCommHandler, + x_in: torch.Tensor, + x_out: torch.Tensor | None = None, + mode: MixedCommMode | None = None, +) -> torch.Tensor: + """Execute a mixed communication operation. + + This is the main entry point for running communication collectives + through the mixed communication handler. It supports fused GPU kernels + (using virtual memory intra-node and nvshmem inter-node), NCCL-based + fallbacks, and autotuned mode selection. + + Args: + op: The communication operation to perform. 
-- @functools.cache @@ -28,7 +29,7 @@ def get_concat_mla_module(): return gen_concat_mla_module().build_and_load() -@flashinfer_api +@flashinfer_api(trace=concat_mla_k_trace) def concat_mla_k( k: torch.Tensor, k_nope: torch.Tensor, diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py index 195ca2d4..9b593095 100644 --- a/flashinfer/cudnn/decode.py +++ b/flashinfer/cudnn/decode.py @@ -4,6 +4,7 @@ from typing import Optional import torch from ..api_logging import flashinfer_api +from ..trace.templates.attention import cudnn_batch_decode_trace from .utils import get_cudnn_fmha_gen_module try: @@ -253,7 +254,7 @@ def _batch_decode_with_kv_cache( return out -@flashinfer_api +@flashinfer_api(trace=cudnn_batch_decode_trace) def cudnn_batch_decode_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index fc1bbb5f..b16d6043 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -4,6 +4,7 @@ from typing import Optional import torch from ..api_logging import flashinfer_api +from ..trace.templates.attention import cudnn_batch_prefill_trace from .utils import get_cudnn_fmha_gen_module try: @@ -558,7 +559,7 @@ def _batch_prefill_with_kv_cache( return out, None -@flashinfer_api +@flashinfer_api(trace=cudnn_batch_prefill_trace) def cudnn_batch_prefill_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py index 0b50c22c..f25aa6fd 100644 --- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py +++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py @@ -38,6 +38,7 @@ import torch from cutlass import Float32, Int32, Int64, Uint32, Uint8 from ..api_logging import flashinfer_api +from ..trace.templates.norm import add_rmsnorm_fp4quant_trace from ..utils import device_support_pdl from .fp4_common import ( # Constants @@ -1042,7 +1043,7 @@ def _get_compiled_kernel( return tensor_api -@flashinfer_api +@flashinfer_api(trace=add_rmsnorm_fp4quant_trace) def add_rmsnorm_fp4quant( input: torch.Tensor, residual: torch.Tensor, diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py index 333697ab..b7aabc36 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py @@ -20,6 +20,7 @@ import torch from cutlass import Float32, Int32 from flashinfer.api_logging import flashinfer_api +from flashinfer.trace.templates.attention import cute_dsl_batch_mla_run_trace from flashinfer.utils import device_support_pdl from flashinfer.cute_dsl.utils import ( get_max_active_clusters, @@ -519,7 +520,7 @@ class BatchMLADecodeCuteDSLWrapper: f"out_dtype={self._o_dtype}" ) - @flashinfer_api + @flashinfer_api(trace=cute_dsl_batch_mla_run_trace) def run( self, q: torch.Tensor, diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py index 58a24abe..ee0cd5e7 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py @@ -21,6 +21,7 @@ import cutlass.cute as cute from cutlass.cute.typing import Int32 from flashinfer.api_logging import flashinfer_api +from flashinfer.trace.templates.attention import cute_dsl_batch_prefill_run_trace from ..config import AttentionConfig, AttentionFusion from ..fusion.mask import MaskType @@ -371,7 +372,7 @@ class BatchPrefillCuteDSLWrapper: 
f"device={self._device}" ) - @flashinfer_api + @flashinfer_api(trace=cute_dsl_batch_prefill_run_trace) def run( self, q: torch.Tensor, diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py index bc4acffc..97ce68a1 100644 --- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py +++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py @@ -32,6 +32,7 @@ import torch from cutlass import Float32, Int32, Uint8 from ..api_logging import flashinfer_api +from ..trace.templates.norm import rmsnorm_fp4quant_trace from ..utils import device_support_pdl from .fp4_common import ( # Constants @@ -771,7 +772,7 @@ def _get_compiled_kernel( return tensor_api -@flashinfer_api +@flashinfer_api(trace=rmsnorm_fp4quant_trace) def rmsnorm_fp4quant( input: torch.Tensor, weight: torch.Tensor, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 822aca40..5e9eb515 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -22,6 +22,12 @@ from typing import Any, List, Literal, Optional, Tuple, Union, overload import torch from .api_logging import flashinfer_api +from .trace.templates.attention import ( + gqa_paged_decode_trace, + single_decode_with_kv_cache_trace, + trtllm_batch_decode_trace, + xqa_batch_decode_trace, +) ## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility. from .mla import ( @@ -400,7 +406,7 @@ def single_decode_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -@flashinfer_api +@flashinfer_api(trace=single_decode_with_kv_cache_trace) def single_decode_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -1215,7 +1221,7 @@ class BatchDecodeWithPagedKVCacheWrapper: kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_api + @flashinfer_api(trace=gqa_paged_decode_trace) def run( self, q: torch.Tensor, @@ -1577,6 +1583,8 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra :class:`BatchDecodeWithPagedKVCacheWrapper` """ + # No @flashinfer_api here: parent class BatchDecodeWithPagedKVCacheWrapper + # already decorates __init__, so decorating again produces double log entries. def __init__( self, workspace_buffer: torch.Tensor, @@ -2232,7 +2240,7 @@ def get_trtllm_gen_decode_module(*args): ) -@flashinfer_api +@flashinfer_api(trace=trtllm_batch_decode_trace) def trtllm_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2618,7 +2626,7 @@ def trtllm_batch_decode_with_kv_cache( # xqa uses NHD layout -@flashinfer_api +@flashinfer_api(trace=xqa_batch_decode_trace) def xqa_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py new file mode 100644 index 00000000..1104eb6f --- /dev/null +++ b/flashinfer/fi_trace.py @@ -0,0 +1,285 @@ +# Copyright (c) 2025 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-- + +""" +fi_trace: Generate `flashinfer-bench <https://github.com/flashinfer-ai/flashinfer-bench>`_ +compatible definition JSON for FlashInfer APIs. + +Every ``@flashinfer_api(trace=<template>)``-decorated function supports two +usage modes: + +Auto-dump (recommended) +----------------------- +Set environment variables **before** importing flashinfer, then run your +workload normally. No explicit ``fi_trace`` call is needed. + +.. code-block:: bash + + FLASHINFER_TRACE_DUMP=1 \\ + FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out \\ + python my_script.py + +Every decorated function writes a ``<name>.json`` file on its **first** call +for each unique set of const-axis values (e.g. head dimensions, vocab size). +Subsequent calls with the same shape are deduplicated — the file is written +only once per process. The output directory is created automatically. + +Explicit call (for selective or programmatic use) +------------------------------------------------- -- +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Union + +# --------------------------------------------------------------------------- +# Legacy registry — kept for backwards compatibility. +# New code should use @flashinfer_api(trace=TraceTemplate(...)) instead. +# --------------------------------------------------------------------------- + +_REGISTRY: Dict[str, Any] = {} + + +def register_fi_trace(qualname: str, spec: Any) -> None: + """Register a legacy FiTraceSpec for the function with the given qualname. + + .. deprecated:: + Use ``@flashinfer_api(trace=TraceTemplate(...))`` instead. + """ + _REGISTRY[qualname] = spec + + +def build_fi_trace_fn(spec: Any) -> Callable[..., Dict[str, Any]]: + """Build a fi_trace callable from a legacy FiTraceSpec. + + .. deprecated:: + Use ``TraceTemplate.build_fi_trace_fn`` instead. + """ + # Import the old implementation from the trace package for backwards compat. + from .trace.template import ( # noqa: PLC0415,F401 + Const, + Scalar, + Tensor, + TraceTemplate, + Var, + ) + import json # noqa: PLC0415 + import os # noqa: PLC0415 -- + """Generate a flashinfer-bench definition JSON for any FlashInfer API call. + + Parameters + ---------- + func_or_method: + A ``@flashinfer_api``-decorated function or (bound) method. + save_dir: + Directory where the JSON definition file should be written. + Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*. + **kwargs: + The same tensor arguments you would pass to the real API. + + Returns + ------- + dict + A flashinfer-bench compatible definition dictionary. + + Examples + -------- + Standalone function:: + + defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight) + + Bound method (instance.run):: + + defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v)) -- + trace_fn = getattr(actual_func, "fi_trace", None) + if trace_fn is None: + qualname = getattr(actual_func, "__qualname__", repr(actual_func)) + raise ValueError( + f"No fi_trace spec is registered for '{qualname}'. " + "Only @flashinfer_api(trace=...)-decorated functions support fi_trace." + ) + return trace_fn(save_dir=save_dir, **kwargs) diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index df6e1f72..d983f9d4 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -17,6 +17,8 @@ limitations under the License. 
from .core import ( convert_to_block_layout, cutlass_fused_moe, + interleave_moe_scales_for_sm90_mixed_gemm, + interleave_moe_weights_for_sm90_mixed_gemm, gen_cutlass_fused_moe_sm120_module, gen_cutlass_fused_moe_sm103_module, gen_cutlass_fused_moe_sm100_module, @@ -64,6 +66,8 @@ __all__ = [ "WeightLayout", "convert_to_block_layout", "cutlass_fused_moe", + "interleave_moe_scales_for_sm90_mixed_gemm", -- + ), ) -# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121 @flashinfer_api +def interleave_moe_scales_for_sm90_mixed_gemm( + scales: torch.Tensor, + group_size: int = 32, +) -> torch.Tensor: + """Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM. + + The kernel expects scales in layout + ``(num_experts, K // (group_size * 4), rows * 4)`` rather than the natural + ``(num_experts, rows, K // group_size)`` produced by the MXFP4 quantizer. + This helper performs the reshape + permute equivalent to TensorRT-LLM's + ``WFP4A16FusedMoEMethod.load_quant_scales`` (PR #12451), with the fixed + interleave factor of ``128 // group_size`` used for MXFP4. + + Parameters + ---------- + scales: + ``[num_experts, rows, K // group_size]`` uint8 tensor of E8M0 block + scales. + group_size: + MXFP4 quantization group size (default 32). -- + scales.reshape(e, rows, kgs // factor, factor).permute(0, 2, 1, 3).contiguous() + ) + return tmp.reshape(e, kgs // factor, rows * factor) + + +@flashinfer_api +def interleave_moe_weights_for_sm90_mixed_gemm( + weight: torch.Tensor, + quant_type: str = "fp4", +) -> torch.Tensor: + """Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM. + + The SM90 mixed-dtype MoE GEMM (used by ``cutlass_fused_moe`` with + ``use_w4_group_scaling=True``) expects weights in a specific interleaved + layout; without preprocessing, the LUT-based FP4→BF16 conversion reads + bytes from the wrong positions and the output diverges from a dequantized + reference for any K > 128. TensorRT-LLM's W4A16 MoE runs the equivalent + preprocessing at weight-load time (see + ``interleave_4bit_weights_for_Hopper_mixed_gemm`` in TRT-LLM PR #12451). + + Parameters + ---------- + weight: + ``[num_experts, n, k // 2]`` uint8 CUDA tensor (4-bit values packed + two-per-byte). 
+ quant_type: -- + ) + return out + + +# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121 +@flashinfer_api(trace=cutlass_fused_moe_trace) def cutlass_fused_moe( input: torch.Tensor, token_selected_experts: torch.Tensor, @@ -1027,8 +1151,8 @@ def get_trtllm_moe_sm100_module(): DynamicTensorSpec( input_idx, dim_idx, - get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1), - lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens), + get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1), + lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens), initializers, ), ), @@ -2344,7 +2468,7 @@ def _validate_routing_replay_out( raise ValueError("routing_replay_out must be contiguous (packed row-major)") -@flashinfer_api +@flashinfer_api(trace=trtllm_bf16_moe_trace) def trtllm_bf16_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2452,7 +2576,7 @@ def trtllm_bf16_moe( return result -@flashinfer_api +@flashinfer_api(trace=trtllm_bf16_routed_moe_trace) def trtllm_bf16_routed_moe( topk_ids: torch.Tensor, hidden_states: torch.Tensor, @@ -2557,7 +2681,7 @@ def trtllm_bf16_routed_moe( return result -@flashinfer_api +@flashinfer_api(trace=trtllm_fp8_per_tensor_scale_moe_trace) def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2658,7 +2782,7 @@ def trtllm_fp8_per_tensor_scale_moe( return result -@flashinfer_api +@flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch) def trtllm_fp8_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2779,7 +2903,7 @@ def trtllm_fp8_block_scale_moe( return result -@flashinfer_api +@flashinfer_api(trace=trtllm_fp8_block_scale_routed_moe_trace) def trtllm_fp8_block_scale_routed_moe( topk_ids: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2893,7 +3017,7 @@ def trtllm_fp8_block_scale_routed_moe( return result -@flashinfer_api +@flashinfer_api(trace=trtllm_fp4_block_scale_moe_trace_dispatch) def trtllm_fp4_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -3030,7 +3154,7 @@ def trtllm_fp4_block_scale_moe( ) -@flashinfer_api +@flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace) def trtllm_fp4_block_scale_routed_moe( topk_ids: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -3165,7 +3289,7 @@ def trtllm_fp4_block_scale_routed_moe( ) -@flashinfer_api +@flashinfer_api(trace=trtllm_mxint4_block_scale_moe_trace) def trtllm_mxint4_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/flashinfer/fused_moe/cute_dsl/b12x_moe.py b/flashinfer/fused_moe/cute_dsl/b12x_moe.py index d2cbc8b0..34916df5 100644 --- a/flashinfer/fused_moe/cute_dsl/b12x_moe.py +++ b/flashinfer/fused_moe/cute_dsl/b12x_moe.py @@ -42,11 +42,12 @@ from typing import Optional, Tuple import torch from ...api_logging import flashinfer_api +from ...trace.templates.moe import b12x_fused_moe_trace, b12x_moe_wrapper_run_trace from ...utils import supported_compute_capability @supported_compute_capability([120, 121]) -@flashinfer_api +@flashinfer_api(trace=b12x_fused_moe_trace) def b12x_fused_moe( x: torch.Tensor, w1_weight: torch.Tensor, @@ -293,7 +294,7 @@ class B12xMoEWrapper: device=self.device, ) - @flashinfer_api + @flashinfer_api(trace=b12x_moe_wrapper_run_trace) def run( self, x: torch.Tensor, diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py 
b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py index f6cf1b67..e266cb77 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py @@ -89,8 +89,8 @@ from flashinfer.cute_dsl.fp4_common import ( st_global_u64, scatter_add_bf16x2, ) -from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import ( - Sm120BlockScaledDenseGemmKernel as DenseGemmKernel, +from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import ( + Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel, ) diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py index e7fdae92..670b3ad8 100644 -- from .moe_utils import ( @@ -530,7 +534,7 @@ class CuteDslMoEWrapper: enable_pdl=enable_pdl, ) - @flashinfer_api + @flashinfer_api(trace=cute_dsl_moe_wrapper_run_trace) def run( self, x: torch.Tensor, @@ -686,7 +690,7 @@ def _cute_dsl_fused_moe_nvfp4_impl( @supported_compute_capability([100, 103]) -@flashinfer_api +@flashinfer_api(trace=cute_dsl_fused_moe_nvfp4_trace) def cute_dsl_fused_moe_nvfp4( x: torch.Tensor, x_sf: torch.Tensor, diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py index 0cc8628e..636043db 100644 --- a/flashinfer/fused_moe/cute_dsl/tuner.py +++ b/flashinfer/fused_moe/cute_dsl/tuner.py @@ -42,8 +42,8 @@ from ...autotuner import ( TuningConfig, ) from ..utils import ( - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, + get_hybrid_num_tokens_buckets, + map_to_hybrid_bucket, ) logger = logging.getLogger(__name__) @@ -273,10 +273,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner): DynamicTensorSpec( -- import torch @@ -137,7 +138,7 @@ def get_dsv3_fused_routing_module(): @backend_requirement({}, common_check=_check_dsv3_fused_routing_supported) -@flashinfer_api +@flashinfer_api(trace=fused_topk_deepseek_trace) def fused_topk_deepseek( scores: torch.Tensor, bias: torch.Tensor, diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py index 004271a1..91f37aa5 100644 --- a/flashinfer/fused_moe/utils.py +++ b/flashinfer/fused_moe/utils.py @@ -209,29 +209,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int: return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1]) -def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]: - """Return descending power-of-2 buckets from ``next_power_of_2(max_num_tokens)`` down to 1.""" - max_num_tokens = next_positive_power_of_2(max_num_tokens) - num_token_buckets = [] - m = max_num_tokens - while m >= 1: - num_token_buckets.append(m) - m //= 2 +_PHASE1_END = 256 -- @@ -106,7 +114,7 @@ TILE_V = 8 # pretranspose tile size # ============================================================================ -@flashinfer_api +@flashinfer_api(trace=gated_delta_rule_decode_trace) def gated_delta_rule_decode_pretranspose( q: torch.Tensor, k: torch.Tensor, @@ -394,7 +402,7 @@ def gated_delta_rule_decode_pretranspose( # ============================================================================ -@flashinfer_api +@flashinfer_api(trace=gated_delta_rule_decode_trace) def gated_delta_rule_decode( q: torch.Tensor, k: torch.Tensor, @@ -535,7 +543,7 @@ def gated_delta_rule_decode( # ============================================================================ -@flashinfer_api +@flashinfer_api(trace=gdn_mtp_trace) def gated_delta_rule_mtp( q: torch.Tensor, k: torch.Tensor, diff --git 
a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py index 68398d28..53fe44ce 100644 --- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py +++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py @@ -3333,8 +3333,7 @@ class GatedDeltaNetChunkedKernel: gate_handle = load_gate_consumer.wait_and_advance() - max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1] - cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index] + cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index] valid_state = not is_first_chunk or self.use_initial_state if cutlass.const_expr(valid_state): diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py index 82dcc72b..aafcc671 100644 --- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py -- register_custom_op, @@ -95,7 +96,7 @@ def get_gdn_prefill_module(): return SimpleNamespace(gdn_prefill=gdn_prefill) -@flashinfer_api +@flashinfer_api(trace=gdn_prefill_trace) def chunk_gated_delta_rule( q: torch.Tensor, k: torch.Tensor, diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index a7795beb..def82216 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -61,11 +61,11 @@ try: from flashinfer.cute_dsl.utils import is_cute_dsl_available if is_cute_dsl_available(): - from .kernels.dense_blockscaled_gemm_sm120 import ( - Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel, + from .kernels.dense_blockscaled_gemm_sm120_b12x import ( + Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel, ) - _cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel") + _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel") except ImportError: -- from ..utils import ( @@ -325,7 +339,7 @@ def _heuristic_func_mm_bf16( common_check=_check_mm_bf16_problem_size, heuristic_func=_heuristic_func_mm_bf16, ) -@flashinfer_api +@flashinfer_api(trace=mm_bf16_trace) def mm_bf16( a: torch.Tensor, b: torch.Tensor, @@ -514,7 +528,7 @@ def _heuristic_func_bmm_bf16( common_check=_check_bmm_bf16_problem_size, heuristic_func=_heuristic_func_bmm_bf16, ) -@flashinfer_api +@flashinfer_api(trace=bmm_bf16_trace) def bmm_bf16( A: torch.Tensor, B: torch.Tensor, @@ -815,8 +829,8 @@ _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig( DynamicTensorSpec( (0,), # a_tensor_index (-2,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, + get_hybrid_num_tokens_buckets, + map_to_hybrid_bucket_uncapped, ), ), constraint_specs=( @@ -871,8 +885,8 @@ _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig( DynamicTensorSpec( (0,), # a_tensor_index (-2,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, -- constraint_specs=( @@ -1095,7 +1109,7 @@ def get_tgv_gemm_sm10x_module( ) -@flashinfer_api +@flashinfer_api(trace=tgv_gemm_sm100_trace) def tgv_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -1173,8 +1187,8 @@ def tgv_gemm_sm100( DynamicTensorSpec( (a_tensor_index,), (-2,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, + get_hybrid_num_tokens_buckets, + map_to_hybrid_bucket_uncapped, ), ), constraint_specs=( @@ -1437,6 +1451,7 @@ class SegmentGEMMWrapper: True """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, backend: str = "auto" ) -> None: @@ -1469,7 +1484,7 @@ class SegmentGEMMWrapper: self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - @flashinfer_api 
+ @flashinfer_api(trace=segment_gemm_run_trace) def run( self, x: torch.Tensor, @@ -2084,6 +2099,8 @@ def build_cudnn_gemm_fp4_graph_override_shape( return graph +# Internal helper called from mm_fp4; the user-facing mm_fp4 is already +# decorated, so decorating here would double-log the same invocation. def execute_cudnn_gemm_fp4_graph_override_shape( graph, a, @@ -2319,6 +2336,8 @@ def build_cudnn_gemm_mxfp8_graph_override_shape( return graph +# Internal helper called from mm_mxfp8; the user-facing mm_mxfp8 is already +# decorated, so decorating here would double-log the same invocation. def execute_cudnn_gemm_mxfp8_graph_override_shape( graph, -- ): @@ -3161,7 +3184,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): return (tuple(block_scale_shape), tuple(block_scale_stride)) -@flashinfer_api +@flashinfer_api(trace=mm_fp8_trace) def mm_fp8( a: torch.Tensor, b: torch.Tensor, @@ -3990,7 +4013,7 @@ def _heuristic_func_mm_mxfp8( common_check=_check_mm_mxfp8_problem_size, heuristic_func=_heuristic_func_mm_mxfp8, # result stored in mm_mxfp8.suitable_auto_backends ) -@flashinfer_api +@flashinfer_api(trace=mm_mxfp8_trace) def mm_mxfp8( a: torch.Tensor, b: torch.Tensor, @@ -4858,8 +4881,8 @@ def _b12x_gemm_fp4_runner( """ import cutlass - from .kernels.dense_blockscaled_gemm_sm120 import ( - Sm120BlockScaledDenseGemmKernel, + from .kernels.dense_blockscaled_gemm_sm120_b12x import ( + Sm120B12xBlockScaledDenseGemmKernel, ) cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype) @@ -4905,7 +4928,7 @@ def _b12x_gemm_fp4_runner( ] swap_ab = False for mma_tiler_mn in sm120_mma_tiler_candidates: - if not Sm120BlockScaledDenseGemmKernel.can_implement( + if not Sm120B12xBlockScaledDenseGemmKernel.can_implement( -- constraint_specs=( @@ -5195,7 +5217,7 @@ _MM_MXFP8_TUNING_CONFIG = TuningConfig( common_check=_check_mm_fp4_problem_size, heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) -@flashinfer_api +@flashinfer_api(trace=mm_fp4_trace) def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -5449,7 +5471,7 @@ def _heuristic_func_bmm_fp8( common_check=_check_bmm_fp8_problem_size, heuristic_func=_heuristic_func_bmm_fp8, ) -@flashinfer_api +@flashinfer_api(trace=bmm_fp8_trace) def bmm_fp8( A: torch.Tensor, B: torch.Tensor, @@ -6862,7 +6884,7 @@ def _check_batch_deepgemm_fp8_nt_groupwise( {}, common_check=_check_batch_deepgemm_fp8_nt_groupwise, ) -@flashinfer_api +@flashinfer_api(trace=batch_deepgemm_fp8_nt_groupwise_trace) def batch_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (batch_size, m, k) b: torch.Tensor, # (batch_size, n, k) @@ -7006,7 +7028,7 @@ def get_fp8_blockscale_gemm_runner_sm90(): return module.init() -@flashinfer_api +@flashinfer_api(trace=fp8_blockscale_gemm_sm90_trace) def fp8_blockscale_gemm_sm90( input: torch.Tensor, weight: torch.Tensor, @@ -7588,7 +7610,7 @@ def _heuristic_func_bmm_mxfp8( common_check=_check_bmm_mxfp8_problem_size, heuristic_func=_heuristic_func_bmm_mxfp8, ) -@flashinfer_api +@flashinfer_api(trace=bmm_mxfp8_trace) def bmm_mxfp8( A: torch.Tensor, B: torch.Tensor, diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py similarity index 99% rename from flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py rename to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py index c49bc815..6eee27a7 100644 --- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py +++ 
b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py @@ -1550,7 +1550,7 @@ class DenseGemmKernel: # Alias for FlashInfer integration -Sm120BlockScaledDenseGemmKernel = DenseGemmKernel +Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel class _DenseGemmLaunch: diff --git a/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py -- get_cutlass_dtype, @@ -2951,7 +2952,7 @@ def get_cute_dsl_compiled_masked_gemm_kernel( return tensor_api -@flashinfer_api +@flashinfer_api(trace=grouped_gemm_nt_masked_trace) def grouped_gemm_nt_masked( lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], diff --git a/flashinfer/gemm/routergemm.py b/flashinfer/gemm/routergemm.py index cfde7d43..f83c8974 100644 --- a/flashinfer/gemm/routergemm.py +++ b/flashinfer/gemm/routergemm.py @@ -1,4 +1,8 @@ from ..api_logging import flashinfer_api +from ..trace.templates.gemm import ( + mm_M1_16_K7168_N256_trace, + tinygemm_bf16_trace, +) from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module import functools from types import SimpleNamespace @@ -176,7 +180,7 @@ def mm_M1_16_K7168_N128( @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks) -@flashinfer_api +@flashinfer_api(trace=mm_M1_16_K7168_N256_trace) def mm_M1_16_K7168_N256( mat_a: torch.Tensor, mat_b: torch.Tensor, @@ -324,7 +328,7 @@ def get_tinygemm2_module(): @backend_requirement({}, common_check=_tinygemm_bf16_shape_checks) -@flashinfer_api +@flashinfer_api(trace=tinygemm_bf16_trace) def tinygemm_bf16( input: torch.Tensor, weight: torch.Tensor, diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 7f36a314..8378e0ab 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -82,6 +82,7 @@ from .comm import gen_trtllm_mnnvl_comm_module as gen_trtllm_mnnvl_comm_module from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module from .comm import gen_vllm_comm_module as gen_vllm_comm_module from .comm import gen_moe_alltoall_module as gen_moe_alltoall_module +from .comm import gen_dcp_alltoall_module as gen_dcp_alltoall_module from .dsv3_optimizations import ( gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module, ) diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py index 46768eed..834f77f9 100644 --- a/flashinfer/jit/comm.py +++ b/flashinfer/jit/comm.py @@ -15,7 +15,13 @@ limitations under the License. -- gen_selective_state_update_sm100_module, @@ -99,7 +100,7 @@ def get_selective_state_update_module( ) -@flashinfer_api +@flashinfer_api(trace=selective_state_update_trace) def selective_state_update( state: torch.Tensor, x: torch.Tensor, diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 4e8bdd72..e27e3807 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -21,6 +21,11 @@ from typing import List, Literal, Optional, Tuple, Union, overload import torch from ..api_logging import flashinfer_api +from ..trace.templates.attention import ( + mla_paged_decode_trace, + trtllm_batch_decode_mla_trace, + xqa_batch_decode_mla_trace, +) from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader from ..jit.mla import gen_mla_module from ..utils import ( @@ -469,7 +474,7 @@ class BatchMLAPagedAttentionWrapper: return_lse_base_on_e: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ... 
- @flashinfer_api + @flashinfer_api(trace=mla_paged_decode_trace) def run( self, q_nope: torch.Tensor, @@ -588,7 +593,7 @@ class BatchMLAPagedAttentionWrapper: return (out, lse) if return_lse else out -@flashinfer_api +@flashinfer_api(trace=trtllm_batch_decode_mla_trace) def trtllm_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, @@ -856,7 +861,7 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError(f"Backend {backend} not supported") -@flashinfer_api +@flashinfer_api(trace=xqa_batch_decode_mla_trace) def xqa_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 0f9911a6..ba612b28 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -32,6 +32,16 @@ from typing import Optional, Union import torch from ..api_logging import flashinfer_api +from ..trace.templates.norm import ( + fused_add_rmsnorm_quant_trace, + fused_add_rmsnorm_trace, + fused_rmsnorm_silu_trace, + gemma_fused_add_rmsnorm_trace, + gemma_rmsnorm_trace, + layernorm_trace, + rmsnorm_quant_trace, + rmsnorm_trace, -- get_compute_capability, @@ -94,7 +104,7 @@ def _normalize_scale_tensor( return scale.contiguous() -@flashinfer_api +@flashinfer_api(trace=rmsnorm_trace) def rmsnorm( input: torch.Tensor, weight: torch.Tensor, @@ -165,7 +175,7 @@ def _rmsnorm_impl_fake( pass -@flashinfer_api +@flashinfer_api(trace=rmsnorm_quant_trace) @register_custom_op("flashinfer::rmsnorm_quant", mutates_args=("out",)) def rmsnorm_quant( out: torch.Tensor, @@ -219,7 +229,7 @@ def _rmsnorm_quant_fake( pass -@flashinfer_api +@flashinfer_api(trace=fused_add_rmsnorm_trace) @register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual")) def fused_add_rmsnorm( input: torch.Tensor, @@ -271,7 +281,7 @@ def _fused_add_rmsnorm_fake( pass -@flashinfer_api +@flashinfer_api(trace=fused_add_rmsnorm_quant_trace) @register_custom_op( "flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual") ) @@ -343,7 +353,7 @@ def _fused_add_rmsnorm_quant_fake( pass -@flashinfer_api +@flashinfer_api(trace=gemma_rmsnorm_trace) def gemma_rmsnorm( input: torch.Tensor, weight: torch.Tensor, @@ -414,7 +424,7 @@ def _gemma_rmsnorm_impl_fake( pass -@flashinfer_api +@flashinfer_api(trace=gemma_fused_add_rmsnorm_trace) @register_custom_op( "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual") ) @@ -470,7 +480,7 @@ def _gemma_fused_add_rmsnorm_fake( pass -@flashinfer_api +@flashinfer_api(trace=layernorm_trace) @register_custom_op("flashinfer::layernorm", mutates_args=()) def layernorm( input: torch.Tensor, @@ -590,7 +600,7 @@ def _torch_dtype_to_str(dtype): ) -@flashinfer_api +@flashinfer_api(trace=fused_rmsnorm_silu_trace) def fused_rmsnorm_silu( input: torch.Tensor, weight: torch.Tensor, diff --git a/flashinfer/page.py b/flashinfer/page.py index 12ea3613..7fb33cf3 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -20,6 +20,10 @@ from typing import Optional, Tuple, Union import torch from .api_logging import flashinfer_api +from .trace.templates.page import ( + append_paged_kv_cache_trace, + append_paged_mla_kv_cache_trace, +) from .jit.page import gen_page_module from .utils import ( TensorLayout, @@ -222,7 +226,7 @@ def get_seq_lens( ) -@flashinfer_api +@flashinfer_api(trace=append_paged_mla_kv_cache_trace) def append_paged_mla_kv_cache( append_ckv: torch.Tensor, append_kpe: torch.Tensor, @@ -272,7 +276,7 @@ def append_paged_mla_kv_cache( ) -@flashinfer_api 
+@flashinfer_api(trace=append_paged_kv_cache_trace) def append_paged_kv_cache( append_key: torch.Tensor, append_value: torch.Tensor, diff --git a/flashinfer/pod.py b/flashinfer/pod.py index fe2e36c1..4fa2d9bf 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -22,6 +22,10 @@ from typing import Any, List, Optional, Tuple, Union import torch from .api_logging import flashinfer_api +from .trace.templates.attention import ( + batch_pod_with_paged_kv_cache_run_trace, + pod_with_paged_kv_cache_run_trace, +) from .jit import gen_pod_module, gen_batch_pod_module from .page import get_seq_lens from .prefill import get_batch_prefill_module @@ -435,7 +439,7 @@ class PODWithPagedKVCacheWrapper: begin_forward = plan - @flashinfer_api + @flashinfer_api(trace=pod_with_paged_kv_cache_run_trace) def run( self, # Main params (prefill and decode) @@ -1015,7 +1019,7 @@ class BatchPODWithPagedKVCacheWrapper: begin_forward = plan - @flashinfer_api + @flashinfer_api(trace=batch_pod_with_paged_kv_cache_run_trace) def run( self, # Main params (prefill and decode) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 4ec6a29e..d491dd35 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -23,6 +23,17 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload import torch from .api_logging import flashinfer_api +from .trace.templates.attention import ( + gqa_paged_prefill_trace, + gqa_ragged_prefill_trace, + single_prefill_with_kv_cache_trace, + trtllm_batch_context_trace, +) +from .trace.templates.gemm import ( + fmha_v2_prefill_deepseek_trace, + trtllm_ragged_attention_deepseek_trace, -- gen_customize_batch_prefill_module, @@ -1099,7 +1110,7 @@ def single_prefill_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -@flashinfer_api +@flashinfer_api(trace=single_prefill_with_kv_cache_trace) def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -2132,7 +2143,7 @@ class BatchPrefillWithPagedKVCacheWrapper: skip_softmax_threshold_scale_factor: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_api + @flashinfer_api(trace=gqa_paged_prefill_trace) def run( self, q: torch.Tensor, @@ -3186,7 +3197,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_api + @flashinfer_api(trace=gqa_ragged_prefill_trace) def run( self, q: torch.Tensor, @@ -3669,7 +3680,7 @@ def get_trtllm_gen_fmha_module(): return op -@flashinfer_api +@flashinfer_api(trace=trtllm_ragged_attention_deepseek_trace) def trtllm_ragged_attention_deepseek( query: torch.Tensor, key: torch.Tensor, @@ -3692,6 +3703,7 @@ def trtllm_ragged_attention_deepseek( skip_softmax_threshold_scale_factor: Optional[float] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + backend: str = "trtllm-gen", ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Parameters @@ -3742,6 +3754,12 @@ def trtllm_ragged_attention_deepseek( output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]] lse : Optional[torch.Tensor] lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]] + backend : str + Attention backend to use. "trtllm-gen" (default) or "cute-dsl". + When backend="cute-dsl", query/key/value/out tensors must be + front-padded with max_seq_len rows of valid GPU memory before + index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details). 
-- "lse assumed not None beyond this point when return_lse is True" @@ -3839,7 +3917,7 @@ def trtllm_ragged_attention_deepseek( return out -@flashinfer_api +@flashinfer_api(trace=trtllm_batch_context_trace) def trtllm_batch_context_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -4138,7 +4216,7 @@ def get_trtllm_fmha_v2_sm120_module(): return gen_trtllm_fmha_v2_sm120_module().build_and_load() -@flashinfer_api +@flashinfer_api(trace=fmha_v2_prefill_deepseek_trace) def fmha_v2_prefill_deepseek( query: torch.Tensor, key: torch.Tensor, @@ -4228,7 +4306,7 @@ def get_trtllm_fmha_v2_module( return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load() -@flashinfer_api +@flashinfer_api(trace=trtllm_fmha_v2_prefill_trace) def trtllm_fmha_v2_prefill( qkv: Union[ torch.Tensor, diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py index 4cd5cd34..84f7ade6 100644 --- a/flashinfer/quantization/fp4_quantization.py +++ b/flashinfer/quantization/fp4_quantization.py @@ -21,6 +21,12 @@ from typing import List, Optional, Tuple import torch from ..api_logging import flashinfer_api +from ..trace.templates.quantize import ( + fp4_quantize_trace, + mxfp4_quantize_trace, + nvfp4_kv_quantize_trace, + nvfp4_quantize_trace, +) from ..jit import JitSpec from ..jit import env as jit_env from ..jit import ( @@ -648,7 +654,7 @@ def get_fp4_quantization_module(backend: str = "100"): ) -@flashinfer_api +@flashinfer_api(trace=fp4_quantize_trace) def fp4_quantize( input: torch.Tensor, global_scale: Optional[torch.Tensor] = None, @@ -923,7 +929,7 @@ def shuffle_matrix_sf_a( return block_scale_interleave(w_shuffled) -@flashinfer_api +@flashinfer_api(trace=nvfp4_quantize_trace) def nvfp4_quantize( a, a_global_sf, @@ -1024,7 +1030,7 @@ def nvfp4_quantize( return a_fp4, a_sf -@flashinfer_api +@flashinfer_api(trace=mxfp4_quantize_trace) def mxfp4_quantize( a: torch.Tensor, backend: str = "cuda", @@ -1441,7 +1447,7 @@ def _nvfp4_kv_quant_check(input, global_scale): @backend_requirement({}, common_check=_nvfp4_kv_quant_check) -@flashinfer_api +@flashinfer_api(trace=nvfp4_kv_quantize_trace) def nvfp4_kv_quantize( input: torch.Tensor, global_scale: torch.Tensor, diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py index f2c9f412..49e13a8b 100644 --- a/flashinfer/quantization/fp8_quantization.py +++ b/flashinfer/quantization/fp8_quantization.py @@ -5,6 +5,7 @@ from typing import Literal, Optional, Tuple import torch from ..api_logging import flashinfer_api +from ..trace.templates.quantize import mxfp8_quantize_trace from ..jit.fp8_quantization import gen_mxfp8_quantization_sm100_module from ..utils import ( device_support_pdl, @@ -158,7 +159,7 @@ def get_mxfp8_quantization_sm100_module(): ) -@flashinfer_api +@flashinfer_api(trace=mxfp8_quantize_trace) def mxfp8_quantize( input: torch.Tensor, is_sf_swizzled_layout: bool = True, diff --git a/flashinfer/rope.py b/flashinfer/rope.py index d39d2e07..df5c7d4d 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -20,6 +20,21 @@ from typing import Optional, Tuple import torch from .api_logging import flashinfer_api +from .trace.templates.rope import ( + apply_llama31_rope_inplace_trace, + apply_llama31_rope_pos_ids_inplace_trace, + apply_llama31_rope_pos_ids_trace, + apply_llama31_rope_trace, + apply_rope_inplace_trace, + apply_rope_pos_ids_inplace_trace, + apply_rope_pos_ids_trace, + apply_rope_trace, -- @@ 
-414,7 +429,7 @@ def _fake_apply_llama31_rope_pos_ids( pass -@flashinfer_api +@flashinfer_api(trace=apply_rope_inplace_trace) def apply_rope_inplace( q: torch.Tensor, k: torch.Tensor, @@ -502,7 +517,7 @@ def apply_rope_inplace( ) -@flashinfer_api +@flashinfer_api(trace=apply_rope_pos_ids_inplace_trace) def apply_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, @@ -561,7 +576,7 @@ def apply_rope_pos_ids_inplace( ) -@flashinfer_api +@flashinfer_api(trace=apply_llama31_rope_inplace_trace) def apply_llama31_rope_inplace( q: torch.Tensor, k: torch.Tensor, ... (truncated -- see full diff via the command above)Summary of API changes:
Decorator semantic addition (backward-compatible):

`@flashinfer_api` now accepts an optional `trace=<TraceTemplate>` keyword. Bare `@flashinfer_api` still works, and existing call sites of decorated functions are unaffected. Most of the diff above is mechanical rewrites of existing `@flashinfer_api` to `@flashinfer_api(trace=...)`, plus the new `flashinfer/trace/` package and `fi_trace.py` for flashinfer-bench JSON dumps.
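The backward-compatible dual-mode shape of the decorator (bare or with `trace=`) follows the standard optional-argument decorator pattern implied by the new `flashinfer_api(func=None, *, trace=None)` signature. A minimal stand-in sketch, not FlashInfer's actual implementation (the real decorator also handles API logging and attaches the trace via an internal helper):

```python
import functools
from typing import Any, Callable, Optional


def flashinfer_api(func: Optional[Callable] = None, *, trace: Any = None) -> Callable:
    """Toy stand-in showing why @flashinfer_api and @flashinfer_api(trace=...) both work."""

    def decorate(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # The real decorator would log the call and expose a trace hook here.
            return f(*args, **kwargs)

        wrapper.fi_trace = trace  # placeholder for the attached trace callable
        return wrapper

    if func is not None:
        return decorate(func)  # bare form: @flashinfer_api
    return decorate  # keyword form: @flashinfer_api(trace=...)
```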
New public APIs (7):

- `flashinfer.comm.dcp_alltoall.{decode_cp_a2a_workspace_size, decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace, decode_cp_a2a_alltoall}`: DCP all-to-all for context-parallel attention reduction (feat: Add DCP All-to-All kernel for context-parallel attention reduction, #2951); see the usage sketch after this list.
- `flashinfer.fused_moe.{interleave_moe_scales_for_sm90_mixed_gemm, interleave_moe_weights_for_sm90_mixed_gemm}`: SM90 mixed-input MoE GEMM helpers (perf: optimize MXFP4xBF16 & INT4xFP8 CUTLASS MoE backend for SM90, #3084).
- `flashinfer.comm.run_mixed_comm`: combinations of allreduce / allgather / reducescatter (Add support for the combinations of allreduce, allgather, and reducescatter, #2563).
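Based on the `dcp_alltoall` docstrings quoted in the diff above, a rough end-to-end usage sketch. The tensor shapes, the NCCL process-group setup, and the plain-device-memory allocation mode (no `mapping=`) are illustrative assumptions:

```python
import torch
import torch.distributed as dist

from flashinfer.comm import (
    decode_cp_a2a_allocate_workspace,
    decode_cp_a2a_alltoall,
    decode_cp_a2a_init_workspace,
)

dist.init_process_group("nccl")  # assumes a torchrun-style, one-GPU-per-rank launch
cp_size, cp_rank = dist.get_world_size(), dist.get_rank()
torch.cuda.set_device(cp_rank)

# 1. Allocate the per-rank workspace (pass mapping=... for MNNVL memory instead).
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank)

# 2. Initialize the local FIFO buffers, then barrier so no rank starts writing to a
#    peer whose workspace is still being reset (the docstring warns this can deadlock).
decode_cp_a2a_init_workspace(workspace, cp_rank, cp_size)
dist.barrier()

# 3. Exchange partial attention outputs and softmax statistics.
partial_o = torch.randn(8, cp_size, 128, dtype=torch.bfloat16, device="cuda")
softmax_stats = torch.randn(8, cp_size, 2, dtype=torch.float32, device="cuda")
out_o, out_stats = decode_cp_a2a_alltoall(
    partial_o, softmax_stats, workspace, cp_rank, cp_size
)
```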
Newly `@flashinfer_api`-decorated wrapper init:

- `SegmentGEMMWrapper.__init__` is now decorated. Previously the class itself was undecorated; `run()` already was. No call-site change.

Backward-compatible signature additions (defaults preserve old behavior):

- `top_k_page_table_transform`: + `dsa_graph_safe: bool = False`, + `row_starts: Optional[torch.Tensor] = None` (feat: Add `row_starts` and `dsa_graph_safe` to topk, #3133).
- `top_k_ragged_transform`: same two new params (#3133).
- `trtllm_ragged_attention_deepseek`: + `backend: str = "trtllm-gen"` (cute-dsl backend selection).

No breaking signature changes to any `@flashinfer_api` function. Net public surface delta: +7 functions, +1 newly-decorated `__init__`, 0 removals.

Module reorganization to flag (not an `@flashinfer_api` function, but in a public re-export):

- `flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` → `dense_blockscaled_gemm_sm120_b12x.py`
- `Sm120BlockScaledDenseGemmKernel` → `Sm120B12xBlockScaledDenseGemmKernel`
- `flashinfer/gemm/__init__.py` is updated to the new name only, so direct importers of the old name break. Decision needed: ship as breaking, or add a deprecation alias.

Internal autotuner helper rename (not public, but used by downstream extensions):

- `get_last_power_of_2_num_tokens_buckets` → `get_hybrid_num_tokens_buckets`
- `last_positive_power_of_2` → `map_to_hybrid_bucket` / `map_to_hybrid_bucket_uncapped`
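For downstream extensions that imported the renamed autotuner helpers directly, a guarded import keeps one code path working against both releases. This sketch assumes the helpers live in `flashinfer.fused_moe.utils` (as the `tuner.py` import in the diff suggests); only the bucket-list helper is bridged, since the mapping helpers take different arguments and need a real code change rather than an alias.

```python
try:  # flashinfer >= 0.6.10 (hybrid bucket strategy)
    from flashinfer.fused_moe.utils import (
        get_hybrid_num_tokens_buckets as get_num_tokens_buckets,
    )
except ImportError:  # flashinfer <= 0.6.9 (power-of-2 buckets)
    from flashinfer.fused_moe.utils import (
        get_last_power_of_2_num_tokens_buckets as get_num_tokens_buckets,
    )
```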