Logging¶
FlashInfer provides a logging feature to help debug issues and reproduce crashes. This document describes all available logging levels and their features.
Quick Start¶
Enable logging using two environment variables:
# Set logging level (0-10)
export FLASHINFER_LOGLEVEL=3
# Set log destination (default is stdout)
export FLASHINFER_LOGDEST=stdout # or stderr, or a file path like "flashinfer.log"
Logging Levels¶
| Level | Name | Features | Use Case |
|---|---|---|---|
| 0 | Disabled (default) | No logging (zero overhead) | Production |
| 1 | Function Names | Function names only | Basic tracing |
| 3 | Inputs/Outputs | Function names + arguments + outputs with metadata | Standard debugging |
| 5 | Statistics | Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) | Numerical analysis |
| 10 | Flight Recorder (full input/output dumps) | Level 5 + dumps all input/output tensors to `FLASHINFER_DUMP_DIR` | Full reproducibility / debugging |
Environment Variables¶
Main Configuration¶
| Variable | Type | Default | Description |
|---|---|---|---|
| `FLASHINFER_LOGLEVEL` | int | 0 | Logging level (0, 1, 3, 5, 10) |
| `FLASHINFER_LOGDEST` | str | stdout | Log destination: stdout, stderr, or a file path |
Dump Configuration (Level 10)¶
When `FLASHINFER_LOGLEVEL` is set to 10, the following environment variables can be used to configure the dump behavior:

| Variable | Type | Default | Description |
|---|---|---|---|
| `FLASHINFER_DUMP_DIR` | str | | Directory to save dump files |
| | float | 20 | Maximum size of dump directory in GB |
| | int | 1000 | Maximum number of API calls to dump |
| `FLASHINFER_DUMP_INCLUDE` | str | (empty) | Comma-separated patterns to include (fnmatch-style) |
| `FLASHINFER_DUMP_EXCLUDE` | str | (empty) | Comma-separated patterns to exclude (fnmatch-style) |
| `FLASHINFER_DUMP_SAFETENSORS` | int | 0 | Set to 1 to use safetensors format (no pickle, but loses stride info) |
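These variables can also be set from Python before FlashInfer is imported. The sketch below assumes the settings are picked up when FlashInfer initializes its logging at import time:

```python
import os

# Configure Flight Recorder before importing flashinfer
# (assumption: logging settings are read at import time)
os.environ["FLASHINFER_LOGLEVEL"] = "10"
os.environ["FLASHINFER_DUMP_DIR"] = "./my_dumps"
os.environ["FLASHINFER_DUMP_INCLUDE"] = "*decode*"

import flashinfer  # noqa: E402
```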
SafeTensors Format (Optional)¶
By default, tensors are saved using torch.save(), which preserves tensor stride and contiguity information.
For faster, pickle-free serialization, you can enable the safetensors format:
export FLASHINFER_DUMP_SAFETENSORS=1
Warning
SafeTensors does NOT preserve tensor strides or non-contiguity.
All tensors are saved as contiguous. Use the default torch.save format
if stride preservation is important for your debugging.
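The difference is easy to see with a transposed tensor. The small demonstration below uses `torch.save` and the `safetensors` package directly; note that `safetensors` refuses non-contiguous tensors, so they must be made contiguous before saving:

```python
import torch
from safetensors.torch import save_file, load_file

t = torch.randn(4, 8).t()  # transposed view: non-contiguous

torch.save({"t": t}, "demo.pt")
save_file({"t": t.contiguous()}, "demo.safetensors")  # must be made contiguous

print(torch.load("demo.pt")["t"].is_contiguous())          # False: strides preserved
print(load_file("demo.safetensors")["t"].is_contiguous())  # True: stride info lost
```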
Comparison:

| Aspect | torch.save (default) | safetensors |
|---|---|---|
| Speed | Standard | Faster |
| Safety | Uses pickle | No pickle (safer) |
| Stride preservation | ✅ Yes | ❌ No (contiguous only) |
| File extension | `.pt` | `.safetensors` |
| Dependency | `torch` | Requires `safetensors` |
Replay is format-agnostic: The replay command automatically detects the format based on file extension.
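For example, a loader can branch on the file extension. This is a minimal sketch of such detection, not FlashInfer's actual implementation:

```python
import torch
from pathlib import Path

def load_tensor_dict(path: Path):
    """Load a dumped tensor dict, picking the loader by file extension."""
    if path.suffix == ".safetensors":
        from safetensors.torch import load_file
        return load_file(str(path))  # pickle-free; tensors are contiguous
    return torch.load(path, map_location="cuda")  # default torch.save format

# e.g. load_tensor_dict(Path("./my_dumps/<dump_dir>/inputs.pt"))
```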
Dump Filtering (Include/Exclude)¶
Use `FLASHINFER_DUMP_INCLUDE` and `FLASHINFER_DUMP_EXCLUDE` to control which API calls are dumped.
This is especially useful when an end-to-end inference run makes many API calls but only a few of them are of interest.
Pattern Syntax (fnmatch-style):

- `*` matches any number of characters
- `?` matches a single character
- Matching is case-sensitive

For class methods, the function name is formatted as `ClassName.method_name` (see the fnmatch check after the examples below).
Filter Logic:

- If `FLASHINFER_DUMP_INCLUDE` is set, only APIs matching at least one pattern are dumped
- If `FLASHINFER_DUMP_EXCLUDE` is set, APIs matching any pattern are skipped
- Both can be combined: the include filter is applied first, then the exclude filter
Examples:
# Only dump decode-related APIs
export FLASHINFER_DUMP_INCLUDE="*decode*"
# Dump everything except __init__ and plan methods
export FLASHINFER_DUMP_EXCLUDE="*.__init__,*.plan"
# Only dump run() methods from wrapper classes
export FLASHINFER_DUMP_INCLUDE="*Wrapper.run"
# Dump all single_* APIs except prefill
export FLASHINFER_DUMP_INCLUDE="single_*"
export FLASHINFER_DUMP_EXCLUDE="*prefill*"
# Only dump a specific wrapper's run method
export FLASHINFER_DUMP_INCLUDE="BatchDecodeWithPagedKVCacheWrapper.run"
# Dump FP8 APIs but not quantization steps
export FLASHINFER_DUMP_INCLUDE="*fp8*,*FP8*"
export FLASHINFER_DUMP_EXCLUDE="*quantize*"
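Since the matching is fnmatch-style and case-sensitive, you can sanity-check a pattern against an API name with Python's standard fnmatch module:

```python
from fnmatch import fnmatchcase

# Case-sensitive fnmatch checks, mirroring the filter semantics described above
print(fnmatchcase("BatchDecodeWithPagedKVCacheWrapper.run", "*Wrapper.run"))  # True
print(fnmatchcase("single_decode_with_kv_cache", "*decode*"))                 # True
print(fnmatchcase("single_prefill_with_kv_cache", "single_*"))                # True
print(fnmatchcase("mm_fp4", "*fp8*"))                                         # False
```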
Common Patterns:

| Pattern | Matches |
|---|---|
| `*decode*` | Any API with "decode" in its name |
| `single_*` | APIs whose names start with `single_` |
| `*Wrapper.run` | All wrapper `run` methods |
| `*Wrapper.plan` | All wrapper `plan` methods |
| `BatchDecodeWithPagedKVCacheWrapper.run` | Exact match for that wrapper's `run` method |
Process ID Substitution¶
Use %i in file paths for automatic process ID substitution (useful for multi-GPU training):
export FLASHINFER_LOGDEST="flashinfer_log_%i.txt" # → flashinfer_log_12345.txt
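The substitution amounts to replacing the `%i` token with the current process ID, as in this one-line sketch of the assumed behavior:

```python
import os

# Assumed behavior of %i expansion: the token is replaced by the process ID
log_dest = "flashinfer_log_%i.txt".replace("%i", str(os.getpid()))
print(log_dest)  # e.g. flashinfer_log_12345.txt
```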
Miscellaneous Notes and Examples¶
CUDA Graph Compatibility¶
Level 5 statistics are automatically skipped during CUDA graph capture to avoid synchronization issues.
# This works correctly - no synchronization errors
with torch.cuda.graph(cuda_graph):
result = mm_fp4(a, b, scales, ...) # Level 5 logging active
# Statistics automatically skipped during capture
Output shows: [statistics skipped: CUDA graph capture in progress]
Process IDs for Multi-GPU Environments¶
# Use %i for process ID substitution
export FLASHINFER_LOGLEVEL=3
export FLASHINFER_LOGDEST="logs/flashinfer_api_%i.log"
torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py
# Creates separate logs:
# logs/flashinfer_api_12345.log (rank 0)
# logs/flashinfer_api_12346.log (rank 1)
# ...
Level 0 has zero overhead¶
At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead.
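Conceptually, the logging decorator behaves like the following sketch (an assumed structure for illustration, not FlashInfer's actual code; `api_logging` is a hypothetical name):

```python
import os

LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))

def api_logging(func):
    # Hypothetical decorator sketch: at Level 0 the original function is
    # returned unchanged, so there is no wrapper and no per-call cost.
    if LOG_LEVEL == 0:
        return func

    def wrapper(*args, **kwargs):
        # ... log function name, arguments, statistics, dumps, etc. ...
        return func(*args, **kwargs)

    return wrapper
```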
Flight Recorder & Replay¶
FlashInfer includes a “Flight Recorder” mode (Level 10) that captures inputs/outputs for reproducibility.
Dump Directory Structure¶
When Level 10 logging is enabled, FlashInfer creates the following structure:
FLASHINFER_DUMP_DIR/
├── session.jsonl # Central log: one line per event (quick scanning)
├── 20250108_143216_802_pid12345_mm_fp8_call0001/
│ ├── metadata.jsonl # Per-dump metadata (JSONL format)
│ ├── inputs.pt # Input tensors (or .safetensors if enabled)
│ └── outputs.pt # Output tensors (or .safetensors if enabled)
├── 20250108_143216_868_pid12345_single_decode_call0001/
│ ├── metadata.jsonl
│ ├── inputs.pt # (or .safetensors)
│ └── outputs.pt # (or .safetensors)
└── ...
JSONL Format: Both session.jsonl and metadata.jsonl use JSON Lines format (one JSON object per line). This enables:

- Crash-safe logging: each API call appends two lines (inputs_saved, then completed)
- Quick scanning: use session.jsonl to browse all recorded calls without reading subdirectories
- Streaming reads: process records line-by-line for large sessions (see the sketch below)
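Because each record is a single line, a large session can be processed without loading the whole file into memory (a minimal sketch using the `execution_status` and `function_name` fields shown in the scanning example later in this document):

```python
import json

# Stream session.jsonl one record at a time; memory use stays constant
with open("my_dumps/session.jsonl") as f:
    for line in f:
        if not line.strip():
            continue
        record = json.loads(line)
        if record["execution_status"] == "completed":
            print(record["function_name"])
```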
Per-dump metadata.jsonl:

- Line 1: written before execution (execution_status: "inputs_saved")
- Line 2: appended after successful execution (execution_status: "completed")
If a crash occurs, only line 1 will be present, preserving the inputs for debugging.
Central session.jsonl:
One-stop log of all API calls. Use standard tools to filter and analyze:
# Enable Flight Recorder (Metadata + Tensors)
export FLASHINFER_LOGLEVEL=10
export FLASHINFER_DUMP_DIR=./my_dumps
# Run your application
python3 benchmarks/flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command --use_cupti --no_cuda_graph --num_iters 5
... output redacted ...
# Replay recorded calls
export FLASHINFER_LOGLEVEL=0  # set to 1 for more detailed replay output
flashinfer replay --dir ./my_dumps
# or
python -m flashinfer replay --dir ./my_dumps
[1] nvfp4_quantize (20251204_143216_802_pid12345_nvfp4_quantize_call0001): ✅ Passed
[2] fp4_quantize (20251204_143216_868_pid12345_fp4_quantize_call0001): ✅ Passed
[3] nvfp4_quantize (20251204_143216_949_pid12345_nvfp4_quantize_call0002): ✅ Passed
[4] fp4_quantize (20251204_143217_003_pid12345_fp4_quantize_call0002): ✅ Passed
[5] mm_fp4 (20251204_143217_178_pid12345_mm_fp4_call0001): ✅ Passed
[6] mm_fp4 (20251204_143217_346_pid12345_mm_fp4_call0002): ✅ Passed
[7] mm_fp4 (20251204_143217_427_pid12345_mm_fp4_call0003): ✅ Passed
[8] mm_fp4 (20251204_143217_475_pid12345_mm_fp4_call0004): ✅ Passed
[9] mm_fp4 (20251204_143217_510_pid12345_mm_fp4_call0005): ✅ Passed
[10] mm_fp4 (20251204_143217_551_pid12345_mm_fp4_call0006): ✅ Passed
[11] mm_fp4 (20251204_143217_591_pid12345_mm_fp4_call0007): ✅ Passed
[12] mm_fp4 (20251204_143217_631_pid12345_mm_fp4_call0008): ✅ Passed
[13] mm_fp4 (20251204_143217_672_pid12345_mm_fp4_call0009): ✅ Passed
[14] mm_fp4 (20251204_143217_708_pid12345_mm_fp4_call0010): ✅ Passed
[15] mm_fp4 (20251204_143217_769_pid12345_mm_fp4_call0011): ✅ Passed
[16] mm_fp4 (20251204_143217_812_pid12345_mm_fp4_call0012): ✅ Passed
[17] mm_fp4 (20251204_143217_852_pid12345_mm_fp4_call0013): ✅ Passed
[18] mm_fp4 (20251204_143217_904_pid12345_mm_fp4_call0014): ✅ Passed
[19] mm_fp4 (20251204_143218_153_pid12345_mm_fp4_call0015): ✅ Passed
[20] mm_fp4 (20251204_143218_390_pid12345_mm_fp4_call0016): ✅ Passed
[21] mm_fp4 (20251204_143218_627_pid12345_mm_fp4_call0017): ✅ Passed
[22] mm_fp4 (20251204_143218_862_pid12345_mm_fp4_call0018): ✅ Passed
Summary: 22 passed, 0 failed/mismatch
Python-Based Replay Examples¶
The following examples demonstrate how to use Level 10 logging to dump and replay API calls programmatically using Python.
Example 1: bmm_fp8 - Simple Function Call¶
Producer Script (bmm_fp8_producer.py):
This script initializes tensors, calls bmm_fp8, and dumps the inputs/outputs to disk.
"""
Producer script: Run bmm_fp8 with Level 10 logging to dump tensors.
Usage:
FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=./bmm_fp8_dumps python bmm_fp8_producer.py
"""
import torch
from flashinfer import bmm_fp8
def to_float8(x, dtype=torch.float8_e4m3fn):
"""Convert tensor to FP8 with per-tensor scaling."""
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
# Parameters
b, m, n, k = 4, 64, 128, 256
input_dtype = torch.float8_e4m3fn
mat2_dtype = torch.float8_e4m3fn
res_dtype = torch.bfloat16
# Create input tensors
input_bf16 = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
input_fp8, input_inv_s = to_float8(input_bf16, dtype=input_dtype)
# mat2: row major -> column major (transposed)
mat2_bf16 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
mat2_fp8, mat2_inv_s = to_float8(mat2_bf16, dtype=mat2_dtype)
# Pre-allocate output
res = torch.empty([b, m, n], device="cuda", dtype=res_dtype)
# Call bmm_fp8 - this will be logged/dumped at Level 10
bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res, backend="cublas")
# Print a small portion of the output for verification
print("Output shape:", res.shape)
print("Output[0, :3, :3]:")
print(res[0, :3, :3])
Reproducer Script (bmm_fp8_reproducer.py):
This script loads the dumped tensors and replays the bmm_fp8 call.
"""
Reproducer script: Load dumped tensors and replay bmm_fp8.
Usage:
python bmm_fp8_reproducer.py
"""
import torch
from pathlib import Path
from flashinfer import bmm_fp8
from flashinfer.api_logging import replay_from_dump
DUMP_DIR = "./bmm_fp8_dumps"
# Find the bmm_fp8 dump directory (should be the only one or the latest)
dump_path = Path(DUMP_DIR)
bmm_dumps = sorted([d for d in dump_path.iterdir() if d.is_dir() and "bmm_fp8" in d.name])
latest_dump = bmm_dumps[-1] # Use the latest dump
print(f"Loading dump from: {latest_dump}")
# Use replay_from_dump to load inputs and optionally execute
result = replay_from_dump(
str(latest_dump),
compare_outputs=True, # Load expected outputs for comparison
device="cuda",
run=False, # We'll call the function manually below
)
# Extract the loaded arguments - args contains all positional args including the output tensor
args = result["args"]
kwargs = result["kwargs"]
expected_tensors = result.get("expected_tensors", {})
# Replay the call - args already contains (input, mat2, input_inv_s, mat2_inv_s, dtype, out)
res = bmm_fp8(*args, **kwargs)
# Print the same portion for comparison
print("Replayed output shape:", res.shape)
print("Replayed output[0, :3, :3]:")
print(res[0, :3, :3])
# Compare with expected output if available
if "result" in expected_tensors:
expected = expected_tensors["result"]
if torch.allclose(res, expected, rtol=1e-3, atol=1e-3):
print("\n✅ Output matches expected result!")
else:
diff = (res - expected).abs().max().item()
print(f"\n❌ Output mismatch! Max diff: {diff}")
Example 2: BatchDecodeWithPagedKVCacheWrapper - Stateful Wrapper Class¶
Producer Script (batch_decode_producer.py):
This script demonstrates logging with a stateful wrapper class that requires __init__, plan, and run calls.
"""
Producer script: Run BatchDecodeWithPagedKVCacheWrapper with Level 10 logging.
Usage:
FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=./batch_decode_dumps python batch_decode_producer.py
"""
import torch
import flashinfer
# Parameters
batch_size = 4
kv_len = 512
page_size = 16
num_kv_heads = 4
num_qo_heads = 32
head_dim = 128
kv_layout = "NHD"
# Create query tensor
q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
# Create paged KV cache
num_pages_per_seq = (kv_len + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] # NHD layout
kv_data = torch.randn(*kv_shape, device="cuda", dtype=torch.float16)
# Create index tensors
kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * num_pages_per_seq
kv_indices = torch.arange(0, total_num_pages, device="cuda", dtype=torch.int32)
kv_last_page_len = torch.full(
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda"
)
# Create workspace and wrapper - __init__ will be logged
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout)
# Plan - will be logged
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
data_type=torch.float16,
q_data_type=torch.float16,
)
# Run - will be logged
output, lse = wrapper.run(q, kv_data, return_lse=True)
# Print a small portion of the output
print("Output shape:", output.shape)
print("Output[0, :3, :3]:")
print(output[0, :3, :3])
print("\nLSE shape:", lse.shape)
print("LSE[0, :5]:", lse[0, :5])
Reproducer Script (batch_decode_reproducer.py):
This script demonstrates replaying a sequence of stateful API calls.
"""
Reproducer script: Replay BatchDecodeWithPagedKVCacheWrapper calls.
Usage:
python batch_decode_reproducer.py
"""
import torch
from pathlib import Path
from flashinfer.api_logging import replay_sequence
DUMP_DIR = "./batch_decode_dumps"
# replay_sequence handles stateful objects automatically via object_registry
# It will:
# 1. Replay __init__ to create the wrapper instance
# 2. Replay plan() on the same instance
# 3. Replay run() on the same instance and compare outputs
results = replay_sequence(DUMP_DIR, device="cuda")
# Print summary
passed = 0
failed = 0
for i, res in enumerate(results):
func_name = res.get("metadata", {}).get("function_name", "unknown")
dump_dir = Path(res.get("dump_dir", "")).name
if "error" in res:
print(f"[{i+1}] {func_name} ({dump_dir}): ❌ Error: {res['error']}")
failed += 1
elif res.get("comparison_match", True):
print(f"[{i+1}] {func_name} ({dump_dir}): ✅ Passed")
passed += 1
else:
print(f"[{i+1}] {func_name} ({dump_dir}): ❌ Mismatch")
failed += 1
print(f"\nSummary: {passed} passed, {failed} failed")
# For manual inspection, you can also access individual results
# Find the 'run' call result (usually the last non-init, non-plan call)
for res in results:
func_name = res.get("metadata", {}).get("function_name", "")
if "run" in func_name and "execution_result" in res:
output = res["execution_result"]
if isinstance(output, tuple):
output_tensor, lse = output
print("\nReplayed output[0, :3, :3]:")
print(output_tensor[0, :3, :3])
print("Replayed LSE[0, :5]:", lse[0, :5])
break
Manual Replay Without replay_from_dump¶
For more control, you can manually load the dumped tensors:
Note
This example assumes the default torch.save format (.pt files).
If dumps were created with FLASHINFER_DUMP_SAFETENSORS=1, use
safetensors.torch.load_file() instead of torch.load().
"""
Manual replay: Load tensors directly from .pt files.
"""
import json
import torch
from pathlib import Path
from flashinfer import bmm_fp8
# Path is an example, replace with the actual path.
dump_dir = Path("./bmm_fp8_dumps/20250108_103217_012_pid12345_bmm_fp8_call0001")
# Load metadata from JSONL (read last line for most complete state)
with open(dump_dir / "metadata.jsonl") as f:
lines = [line.strip() for line in f if line.strip()]
metadata = json.loads(lines[-1]) # Last line has completed state
print(f"Function: {metadata['function_name']}")
print(f"Module: {metadata['module']}")
print(f"Status: {metadata['execution_status']}")
print(f"Input tensors: {metadata['tensor_info']['input_tensor_keys']}")
# Load input tensors
inputs = torch.load(dump_dir / "inputs.pt", map_location="cuda")
# Load expected outputs (if execution completed successfully)
outputs_path = dump_dir / "outputs.pt"
if outputs_path.exists():
expected = torch.load(outputs_path, map_location="cuda")
print(f"Output tensors: {list(expected.keys())}")
# Tensors are ready to use - reconstruct the call as needed
for key, tensor in inputs.items():
print(f" {key}: shape={tensor.shape}, dtype={tensor.dtype}")
Scanning Session History¶
Use the central session.jsonl to quickly scan all recorded API calls:
"""
Scan session.jsonl for quick overview of recorded calls.
"""
import json
from pathlib import Path
from collections import Counter
dump_root = Path("./my_dumps")
session_file = dump_root / "session.jsonl"
# Read all records
records = []
with open(session_file) as f:
for line in f:
if line.strip():
records.append(json.loads(line))
# Filter to completed calls only
completed = [r for r in records if r["execution_status"] == "completed"]
print(f"Total completed calls: {len(completed)}")
# Count by function name
func_counts = Counter(r["function_name"] for r in completed)
print("\nCalls by function:")
for func, count in func_counts.most_common():
print(f" {func}: {count}")
# Find calls that didn't complete (potential crashes)
inputs_only = [r for r in records if r["execution_status"] == "inputs_saved"]
# Group by dump_dir to find incomplete calls
completed_dirs = {r["dump_dir"] for r in completed}
incomplete = [r for r in inputs_only if r["dump_dir"] not in completed_dirs]
if incomplete:
print(f"\n⚠️ Found {len(incomplete)} incomplete calls (potential crashes):")
for r in incomplete:
print(f" - {r['function_name']} at {r['dump_dir']}")