Logging

FlashInfer provides a logging facility to help debug issues and reproduce crashes. This document describes the available logging levels, the environment variables that control them, and the Flight Recorder dump/replay workflow.

Quick Start

Enable logging using two environment variables:

# Set logging level (0, 1, 3, 5, or 10)
export FLASHINFER_LOGLEVEL=3

# Set log destination (default is stdout)
export FLASHINFER_LOGDEST=stdout  # or stderr, or a file path like "flashinfer.log"
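
The same configuration can also be applied from Python by setting the variables with os.environ, provided this happens before flashinfer is imported (an assumption worth verifying for your setup; exporting them in the shell as above is the safest option):

import os

# Assumption: the variables are read when flashinfer is imported / first used
os.environ["FLASHINFER_LOGLEVEL"] = "3"
os.environ["FLASHINFER_LOGDEST"] = "flashinfer.log"

import flashinfer  # import only after the environment is configured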

Logging Levels

Level | Name | Features | Use Case
----- | ---- | -------- | --------
0 | Disabled (default) | No logging (zero overhead) | Production
1 | Function Names | Function names only | Basic tracing
3 | Inputs/Outputs | Function names + arguments + outputs with metadata | Standard debugging
5 | Statistics | Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) | Numerical analysis
10 | Flight Recorder (Full Input/Output Dumps) | Level 5 + dumps of all input/output tensors to .pt (or .safetensors) files | Full reproducibility / debugging

Environment Variables

Main Configuration

Variable | Type | Default | Description
-------- | ---- | ------- | -----------
FLASHINFER_LOGLEVEL | int | 0 | Logging level (0, 1, 3, 5, or 10)
FLASHINFER_LOGDEST | str | stdout | Log destination: stdout, stderr, or a file path

Dump Configuration (Level 10)

When FLASHINFER_LOGLEVEL is set to 10, the following environment variables can be used to configure the dump behavior:

Variable | Type | Default | Description
-------- | ---- | ------- | -----------
FLASHINFER_DUMP_DIR | str | flashinfer_dumps | Directory in which dump files are saved
FLASHINFER_DUMP_MAX_SIZE_GB | float | 20 | Maximum total size of the dump directory in GB
FLASHINFER_DUMP_MAX_COUNT | int | 1000 | Maximum number of API calls to dump
FLASHINFER_DUMP_INCLUDE | str | (empty) | Comma-separated patterns to include (fnmatch-style)
FLASHINFER_DUMP_EXCLUDE | str | (empty) | Comma-separated patterns to exclude (fnmatch-style)
FLASHINFER_DUMP_SAFETENSORS | int | 0 | Set to 1 to use the safetensors format (no pickle, but loses stride info)

SafeTensors Format (Optional)

By default, tensors are saved with torch.save(), which preserves tensor stride and contiguity information. For faster, pickle-free serialization, you can enable the safetensors format:

export FLASHINFER_DUMP_SAFETENSORS=1

Warning

SafeTensors does NOT preserve tensor strides or non-contiguity. All tensors are saved as contiguous. Use the default torch.save format if stride preservation is important for your debugging.
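
To see concretely what "contiguous only" means, the following standalone snippet (using the safetensors package directly, independent of FlashInfer) round-trips a non-contiguous tensor and shows that the loaded copy comes back contiguous:

import torch
from safetensors.torch import save_file, load_file

x = torch.randn(4, 8)[:, ::2]                # a non-contiguous strided view
# safetensors can only serialize contiguous tensors, so stride info is lost
save_file({"x": x.contiguous()}, "demo.safetensors")
y = load_file("demo.safetensors")["x"]
print(x.is_contiguous(), y.is_contiguous())  # False True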

Comparison:

Aspect | torch.save (default) | safetensors
------ | -------------------- | -----------
Speed | Standard | Faster
Safety | Uses pickle | No pickle (safer)
Stride preservation | ✅ Yes | ❌ No (contiguous only)
File extension | .pt | .safetensors
Dependency | torch | Requires pip install safetensors

Replay is format-agnostic: The replay command automatically detects the format based on file extension.
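
If you load dump files by hand, the same dispatch-on-extension logic is easy to reproduce; this is a minimal sketch (not FlashInfer's internal loader), assuming the safetensors package is installed when .safetensors files are involved:

from pathlib import Path
import torch

def load_dump_file(path, device="cuda"):
    """Load a dict of tensors from a dump file, dispatching on the file extension."""
    path = Path(path)
    if path.suffix == ".safetensors":
        from safetensors.torch import load_file  # requires `pip install safetensors`
        return load_file(str(path), device=device)
    return torch.load(path, map_location=device)  # default torch.save (.pt) format

# Works for either inputs.pt or inputs.safetensors inside a dump directory
# inputs = load_dump_file("flashinfer_dumps/<dump_subdir>/inputs.pt")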

Dump Filtering (Include/Exclude)

Use FLASHINFER_DUMP_INCLUDE and FLASHINFER_DUMP_EXCLUDE to control which API calls are dumped. This is especially useful when an end-to-end inference run makes many API calls but only a few of them are of interest.

Pattern Syntax (fnmatch-style):

  • * matches any number of characters

  • ? matches a single character

  • Matching is case-sensitive

  • For class methods, the function name is formatted as ClassName.method_name

Filter Logic:

  1. If FLASHINFER_DUMP_INCLUDE is set, only APIs matching at least one pattern are dumped

  2. If FLASHINFER_DUMP_EXCLUDE is set, APIs matching any pattern are skipped

  3. Both can be combined: the include filter is applied first, then the exclude filter (see the sketch below)
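
The pattern semantics are those of Python's standard fnmatch module. The sketch below is purely illustrative (it is not FlashInfer's internal filter), but it shows how an include/exclude pair could be evaluated against a function name:

from fnmatch import fnmatchcase

def should_dump(func_name, include_patterns, exclude_patterns):
    """Illustrative include/exclude filtering using case-sensitive fnmatch patterns."""
    if include_patterns and not any(fnmatchcase(func_name, p) for p in include_patterns):
        return False  # an include list is set and nothing matched
    if any(fnmatchcase(func_name, p) for p in exclude_patterns):
        return False  # an exclude pattern matched
    return True

print(should_dump("single_decode_with_kv_cache", ["*decode*"], []))                  # True
print(should_dump("BatchDecodeWithPagedKVCacheWrapper.plan", ["*Wrapper.run"], []))  # False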

Examples:

# Only dump decode-related APIs
export FLASHINFER_DUMP_INCLUDE="*decode*"

# Dump everything except __init__ and plan methods
export FLASHINFER_DUMP_EXCLUDE="*.__init__,*.plan"

# Only dump run() methods from wrapper classes
export FLASHINFER_DUMP_INCLUDE="*Wrapper.run"

# Dump all single_* APIs except prefill
export FLASHINFER_DUMP_INCLUDE="single_*"
export FLASHINFER_DUMP_EXCLUDE="*prefill*"

# Only dump a specific wrapper's run method
export FLASHINFER_DUMP_INCLUDE="BatchDecodeWithPagedKVCacheWrapper.run"

# Dump FP8 APIs but not quantization steps
export FLASHINFER_DUMP_INCLUDE="*fp8*,*FP8*"
export FLASHINFER_DUMP_EXCLUDE="*quantize*"

Common Patterns:

Pattern | Matches
------- | -------
*decode* | single_decode_with_kv_cache, BatchDecodeWithPagedKVCacheWrapper.run
*Wrapper.run | BatchDecodeWithPagedKVCacheWrapper.run, BatchPrefillWithPagedKVCacheWrapper.run
*.__init__ | All wrapper __init__ methods
*.plan | All wrapper plan methods
mm_fp8 | Exact match for the mm_fp8 function
single_* | single_decode_with_kv_cache, single_prefill_with_kv_cache

Process ID Substitution

Use %i in file paths for automatic process ID substitution (useful for multi-GPU training):

export FLASHINFER_LOGDEST="flashinfer_log_%i.txt"  # → flashinfer_log_12345.txt

Miscellaneous Notes and Examples

CUDA Graph Compatibility

Level 5 statistics are automatically skipped during CUDA graph capture to avoid synchronization issues.

# This works correctly - no synchronization errors
with torch.cuda.graph(cuda_graph):
    result = mm_fp4(a, b, scales, ...)  # Level 5 logging active
    # Statistics automatically skipped during capture

Output shows: [statistics skipped: CUDA graph capture in progress]

Process IDs for Multi-GPU Environments

# Use %i for process ID substitution
export FLASHINFER_LOGLEVEL=3
export FLASHINFER_LOGDEST="logs/flashinfer_api_%i.log"

torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py

# Creates separate logs:
# logs/flashinfer_api_12345.log (rank 0)
# logs/flashinfer_api_12346.log (rank 1)
# ...

Level 0 has zero overhead

At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead.
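
This is the standard pattern for opt-in instrumentation; a minimal sketch of the idea (the decorator name and log format here are made up for illustration, not FlashInfer's actual implementation):

import functools
import os

_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))

def log_api_call(func):
    """Hypothetical logging decorator: at level 0, hand back the original function untouched."""
    if _LOG_LEVEL == 0:
        return func  # no wrapper, no per-call checks, zero overhead

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Level >= 1: at minimum, record the function name
        print(f"[flashinfer] {func.__qualname__}")
        return func(*args, **kwargs)

    return wrapper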

Flight Recorder & Replay

FlashInfer includes a “Flight Recorder” mode (Level 10) that captures inputs/outputs for reproducibility.

Dump Directory Structure

When Level 10 logging is enabled, FlashInfer creates the following structure:

FLASHINFER_DUMP_DIR/
├── session.jsonl                    # Central log: one line per event (quick scanning)
├── 20250108_143216_802_pid12345_mm_fp8_call0001/
│   ├── metadata.jsonl               # Per-dump metadata (JSONL format)
│   ├── inputs.pt                    # Input tensors (or .safetensors if enabled)
│   └── outputs.pt                   # Output tensors (or .safetensors if enabled)
├── 20250108_143216_868_pid12345_single_decode_call0001/
│   ├── metadata.jsonl
│   ├── inputs.pt                    # (or .safetensors)
│   └── outputs.pt                   # (or .safetensors)
└── ...

JSONL Format: Both session.jsonl and metadata.jsonl use JSON Lines format (one JSON object per line). This enables:

  • Crash-safe logging: Each API call appends two lines (inputs_saved, then completed)

  • Quick scanning: Use session.jsonl to browse all recorded calls without reading subdirectories

  • Streaming reads: Process records line-by-line for large sessions

Per-dump metadata.jsonl:

  • Line 1: Written before execution (execution_status: "inputs_saved")

  • Line 2: Appended after successful execution (execution_status: "completed")

If a crash occurs, only line 1 will be present, preserving the inputs for debugging.
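
Given this two-line format, a per-dump completeness check takes only a few lines of Python (a convenience sketch, not a FlashInfer API):

import json
from pathlib import Path

def dump_completed(dump_dir):
    """Return True if the dump's metadata.jsonl records a completed execution."""
    lines = [ln for ln in (Path(dump_dir) / "metadata.jsonl").read_text().splitlines() if ln.strip()]
    return json.loads(lines[-1])["execution_status"] == "completed"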

Central session.jsonl:

A one-stop log of all API calls; it can be filtered and analyzed with standard tools (see Scanning Session History below). The following end-to-end example records a run and then replays it:

# Enable Flight Recorder (Metadata + Tensors)
export FLASHINFER_LOGLEVEL=10
export FLASHINFER_DUMP_DIR=./my_dumps

# Run your application
python3 benchmarks/flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command --use_cupti --no_cuda_graph --num_iters 5
... output redacted ...

# Replay recorded calls
export FLASHINFER_LOGLEVEL=0 # 1 for more detailed replay results.
flashinfer replay --dir ./my_dumps
# or
python -m flashinfer replay --dir ./my_dumps

[1] nvfp4_quantize (20251204_143216_802_pid12345_nvfp4_quantize_call0001):  Passed
[2] fp4_quantize (20251204_143216_868_pid12345_fp4_quantize_call0001):  Passed
[3] nvfp4_quantize (20251204_143216_949_pid12345_nvfp4_quantize_call0002):  Passed
[4] fp4_quantize (20251204_143217_003_pid12345_fp4_quantize_call0002):  Passed
[5] mm_fp4 (20251204_143217_178_pid12345_mm_fp4_call0001):  Passed
[6] mm_fp4 (20251204_143217_346_pid12345_mm_fp4_call0002):  Passed
[7] mm_fp4 (20251204_143217_427_pid12345_mm_fp4_call0003):  Passed
[8] mm_fp4 (20251204_143217_475_pid12345_mm_fp4_call0004):  Passed
[9] mm_fp4 (20251204_143217_510_pid12345_mm_fp4_call0005):  Passed
[10] mm_fp4 (20251204_143217_551_pid12345_mm_fp4_call0006):  Passed
[11] mm_fp4 (20251204_143217_591_pid12345_mm_fp4_call0007):  Passed
[12] mm_fp4 (20251204_143217_631_pid12345_mm_fp4_call0008):  Passed
[13] mm_fp4 (20251204_143217_672_pid12345_mm_fp4_call0009):  Passed
[14] mm_fp4 (20251204_143217_708_pid12345_mm_fp4_call0010):  Passed
[15] mm_fp4 (20251204_143217_769_pid12345_mm_fp4_call0011):  Passed
[16] mm_fp4 (20251204_143217_812_pid12345_mm_fp4_call0012):  Passed
[17] mm_fp4 (20251204_143217_852_pid12345_mm_fp4_call0013):  Passed
[18] mm_fp4 (20251204_143217_904_pid12345_mm_fp4_call0014):  Passed
[19] mm_fp4 (20251204_143218_153_pid12345_mm_fp4_call0015):  Passed
[20] mm_fp4 (20251204_143218_390_pid12345_mm_fp4_call0016):  Passed
[21] mm_fp4 (20251204_143218_627_pid12345_mm_fp4_call0017):  Passed
[22] mm_fp4 (20251204_143218_862_pid12345_mm_fp4_call0018):  Passed

Summary: 22 passed, 0 failed/mismatch

Python-Based Replay Examples

The following examples demonstrate how to use Level 10 logging to dump and replay API calls programmatically using Python.

Example 1: bmm_fp8 - Simple Function Call

Producer Script (bmm_fp8_producer.py):

This script initializes tensors, calls bmm_fp8, and dumps the inputs/outputs to disk.

"""
Producer script: Run bmm_fp8 with Level 10 logging to dump tensors.

Usage:
    FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=./bmm_fp8_dumps python bmm_fp8_producer.py
"""
import torch
from flashinfer import bmm_fp8

def to_float8(x, dtype=torch.float8_e4m3fn):
    """Convert tensor to FP8 with per-tensor scaling."""
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()

# Parameters
b, m, n, k = 4, 64, 128, 256
input_dtype = torch.float8_e4m3fn
mat2_dtype = torch.float8_e4m3fn
res_dtype = torch.bfloat16

# Create input tensors
input_bf16 = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
input_fp8, input_inv_s = to_float8(input_bf16, dtype=input_dtype)

# mat2: row major -> column major (transposed)
mat2_bf16 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
mat2_fp8, mat2_inv_s = to_float8(mat2_bf16, dtype=mat2_dtype)

# Pre-allocate output
res = torch.empty([b, m, n], device="cuda", dtype=res_dtype)

# Call bmm_fp8 - this will be logged/dumped at Level 10
bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res, backend="cublas")

# Print a small portion of the output for verification
print("Output shape:", res.shape)
print("Output[0, :3, :3]:")
print(res[0, :3, :3])

Reproducer Script (bmm_fp8_reproducer.py):

This script loads the dumped tensors and replays the bmm_fp8 call.

"""
Reproducer script: Load dumped tensors and replay bmm_fp8.

Usage:
    python bmm_fp8_reproducer.py
"""
import torch
from pathlib import Path
from flashinfer import bmm_fp8
from flashinfer.api_logging import replay_from_dump

DUMP_DIR = "./bmm_fp8_dumps"

# Find the bmm_fp8 dump directory (should be the only one or the latest)
dump_path = Path(DUMP_DIR)
bmm_dumps = sorted([d for d in dump_path.iterdir() if d.is_dir() and "bmm_fp8" in d.name])
latest_dump = bmm_dumps[-1]  # Use the latest dump
print(f"Loading dump from: {latest_dump}")

# Use replay_from_dump to load inputs and optionally execute
result = replay_from_dump(
    str(latest_dump),
    compare_outputs=True,  # Load expected outputs for comparison
    device="cuda",
    run=False,  # We'll call the function manually below
)

# Extract the loaded arguments - args contains all positional args including the output tensor
args = result["args"]
kwargs = result["kwargs"]
expected_tensors = result.get("expected_tensors", {})

# Replay the call - args already contains (input, mat2, input_inv_s, mat2_inv_s, dtype, out)
res = bmm_fp8(*args, **kwargs)

# Print the same portion for comparison
print("Replayed output shape:", res.shape)
print("Replayed output[0, :3, :3]:")
print(res[0, :3, :3])

# Compare with expected output if available
if "result" in expected_tensors:
    expected = expected_tensors["result"]
    if torch.allclose(res, expected, rtol=1e-3, atol=1e-3):
        print("\n✅ Output matches expected result!")
    else:
        diff = (res - expected).abs().max().item()
        print(f"\n❌ Output mismatch! Max diff: {diff}")

Example 2: BatchDecodeWithPagedKVCacheWrapper - Stateful Wrapper Class

Producer Script (batch_decode_producer.py):

This script demonstrates logging with a stateful wrapper class that requires __init__, plan, and run calls.

"""
Producer script: Run BatchDecodeWithPagedKVCacheWrapper with Level 10 logging.

Usage:
    FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=./batch_decode_dumps python batch_decode_producer.py
"""
import torch
import flashinfer

# Parameters
batch_size = 4
kv_len = 512
page_size = 16
num_kv_heads = 4
num_qo_heads = 32
head_dim = 128
kv_layout = "NHD"

# Create query tensor
q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)

# Create paged KV cache
num_pages_per_seq = (kv_len + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]  # NHD layout
kv_data = torch.randn(*kv_shape, device="cuda", dtype=torch.float16)

# Create index tensors
kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * num_pages_per_seq
kv_indices = torch.arange(0, total_num_pages, device="cuda", dtype=torch.int32)
kv_last_page_len = torch.full(
    (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda"
)

# Create workspace and wrapper - __init__ will be logged
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout)

# Plan - will be logged
wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    data_type=torch.float16,
    q_data_type=torch.float16,
)

# Run - will be logged
output, lse = wrapper.run(q, kv_data, return_lse=True)

# Print a small portion of the output
print("Output shape:", output.shape)
print("Output[0, :3, :3]:")
print(output[0, :3, :3])
print("\nLSE shape:", lse.shape)
print("LSE[0, :5]:", lse[0, :5])

Reproducer Script (batch_decode_reproducer.py):

This script demonstrates replaying a sequence of stateful API calls.

"""
Reproducer script: Replay BatchDecodeWithPagedKVCacheWrapper calls.

Usage:
    python batch_decode_reproducer.py
"""
import torch
from pathlib import Path
from flashinfer.api_logging import replay_sequence

DUMP_DIR = "./batch_decode_dumps"

# replay_sequence handles stateful objects automatically via object_registry
# It will:
# 1. Replay __init__ to create the wrapper instance
# 2. Replay plan() on the same instance
# 3. Replay run() on the same instance and compare outputs
results = replay_sequence(DUMP_DIR, device="cuda")

# Print summary
passed = 0
failed = 0
for i, res in enumerate(results):
    func_name = res.get("metadata", {}).get("function_name", "unknown")
    dump_dir = Path(res.get("dump_dir", "")).name

    if "error" in res:
        print(f"[{i+1}] {func_name} ({dump_dir}): ❌ Error: {res['error']}")
        failed += 1
    elif res.get("comparison_match", True):
        print(f"[{i+1}] {func_name} ({dump_dir}): ✅ Passed")
        passed += 1
    else:
        print(f"[{i+1}] {func_name} ({dump_dir}): ❌ Mismatch")
        failed += 1

print(f"\nSummary: {passed} passed, {failed} failed")

# For manual inspection, you can also access individual results
# Find the 'run' call result (usually the last non-init, non-plan call)
for res in results:
    func_name = res.get("metadata", {}).get("function_name", "")
    if "run" in func_name and "execution_result" in res:
        output = res["execution_result"]
        if isinstance(output, tuple):
            output_tensor, lse = output
            print("\nReplayed output[0, :3, :3]:")
            print(output_tensor[0, :3, :3])
            print("Replayed LSE[0, :5]:", lse[0, :5])
        break

Manual Replay Without replay_from_dump

For more control, you can manually load the dumped tensors:

Note

This example assumes the default torch.save format (.pt files). If dumps were created with FLASHINFER_DUMP_SAFETENSORS=1, use safetensors.torch.load_file() instead of torch.load().

"""
Manual replay: Load tensors directly from .pt files.
"""
import json
import torch
from pathlib import Path
from flashinfer import bmm_fp8

# Path is an example, replace with the actual path.
dump_dir = Path("./bmm_fp8_dumps/20250108_103217_012_pid12345_bmm_fp8_call0001")

# Load metadata from JSONL (read last line for most complete state)
with open(dump_dir / "metadata.jsonl") as f:
    lines = [line.strip() for line in f if line.strip()]
    metadata = json.loads(lines[-1])  # Last line has completed state

print(f"Function: {metadata['function_name']}")
print(f"Module: {metadata['module']}")
print(f"Status: {metadata['execution_status']}")
print(f"Input tensors: {metadata['tensor_info']['input_tensor_keys']}")

# Load input tensors
inputs = torch.load(dump_dir / "inputs.pt", map_location="cuda")

# Load expected outputs (if execution completed successfully)
outputs_path = dump_dir / "outputs.pt"
if outputs_path.exists():
    expected = torch.load(outputs_path, map_location="cuda")
    print(f"Output tensors: {list(expected.keys())}")

# Tensors are ready to use - reconstruct the call as needed
for key, tensor in inputs.items():
    print(f"  {key}: shape={tensor.shape}, dtype={tensor.dtype}")

Scanning Session History

Use the central session.jsonl to quickly scan all recorded API calls:

"""
Scan session.jsonl for quick overview of recorded calls.
"""
import json
from pathlib import Path
from collections import Counter

dump_root = Path("./my_dumps")
session_file = dump_root / "session.jsonl"

# Read all records
records = []
with open(session_file) as f:
    for line in f:
        if line.strip():
            records.append(json.loads(line))

# Filter to completed calls only
completed = [r for r in records if r["execution_status"] == "completed"]
print(f"Total completed calls: {len(completed)}")

# Count by function name
func_counts = Counter(r["function_name"] for r in completed)
print("\nCalls by function:")
for func, count in func_counts.most_common():
    print(f"  {func}: {count}")

# Find calls that didn't complete (potential crashes)
inputs_only = [r for r in records if r["execution_status"] == "inputs_saved"]
# Group by dump_dir to find incomplete calls
completed_dirs = {r["dump_dir"] for r in completed}
incomplete = [r for r in inputs_only if r["dump_dir"] not in completed_dirs]
if incomplete:
    print(f"\n⚠️  Found {len(incomplete)} incomplete calls (potential crashes):")
    for r in incomplete:
        print(f"  - {r['function_name']} at {r['dump_dir']}")