-
Notifications
You must be signed in to change notification settings - Fork 967
Expand file tree
/
Copy pathapi_logging.py
More file actions
1724 lines (1464 loc) · 65.2 KB
/
api_logging.py
File metadata and controls
1724 lines (1464 loc) · 65.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import enum
import fnmatch
import functools
import inspect
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Optional
import contextlib
import importlib
import torch
# Helper that expands the "%i" placeholder in file paths to the process ID.
def _substitute_process_id(path: str) -> str:
    """
    Expand every ``%i`` placeholder in ``path`` to the current process ID.

    Multi-process / multi-GPU jobs can use this to give each process its own
    log file. Paths without the placeholder are returned unchanged.
    """
    placeholder = "%i"
    if placeholder not in path:
        return path
    return path.replace(placeholder, str(os.getpid()))
def _env_number(name: str, default: str, cast: Callable[[str], Any]) -> Any:
    """
    Parse the numeric environment variable ``name`` with ``cast``.

    An unset or blank variable (e.g. ``FLASHINFER_LOGLEVEL=``) falls back to
    ``default`` instead of failing module import with an opaque conversion
    error. Any other unparsable value raises ``ValueError`` naming the
    offending variable.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        raw = default
    try:
        return cast(raw)
    except ValueError:
        raise ValueError(
            f"Invalid value {raw!r} for environment variable {name} "
            f"(expected {cast.__name__})"
        ) from None


def _split_patterns(raw: str) -> List[str]:
    """Split a comma-separated fnmatch pattern list, dropping blank entries."""
    return [p.strip() for p in raw.split(",") if p.strip()]


# Read environment variables once at module load time
_API_LOG_LEVEL = _env_number("FLASHINFER_LOGLEVEL", "0", int)
_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_LOGDEST", "stdout"))

# Configuration for Level 10 tensor dumping
_DUMP_DIR = os.environ.get("FLASHINFER_DUMP_DIR", "flashinfer_dumps")
_DUMP_MAX_SIZE_GB = _env_number("FLASHINFER_DUMP_MAX_SIZE_GB", "20", float)
_DUMP_MAX_COUNT = _env_number("FLASHINFER_DUMP_MAX_COUNT", "1000", int)

# Dump filtering: include/exclude patterns (fnmatch-style, comma-separated)
# Examples: "*decode*,*prefill*" or "BatchDecodeWrapper.run,mm_fp8"
_DUMP_INCLUDE = os.environ.get("FLASHINFER_DUMP_INCLUDE", "")
_DUMP_EXCLUDE = os.environ.get("FLASHINFER_DUMP_EXCLUDE", "")
_DUMP_INCLUDE_PATTERNS = _split_patterns(_DUMP_INCLUDE)
_DUMP_EXCLUDE_PATTERNS = _split_patterns(_DUMP_EXCLUDE)

# SafeTensors format option (default: use torch.save which preserves stride/contiguity)
_DUMP_SAFETENSORS = os.environ.get("FLASHINFER_DUMP_SAFETENSORS", "0") == "1"

# Global tracking for dump limits (module state, so reset per process)
_dump_count = 0
_dump_total_size_bytes = 0
_dump_call_counter: Dict[str, int] = {}  # Per-function call sequence used in dump dir names
_session_jsonl_initialized = False  # Track if session.jsonl header was written

# Create logger using Python's logging library
_logger = logging.getLogger("flashinfer.api")
def _setup_logger():
    """
    Configure the module logger (``_logger``) from the environment settings.

    With ``_API_LOG_LEVEL == 0`` a ``NullHandler`` is attached and the level
    is set above CRITICAL, so logging calls do nothing. Otherwise any
    existing handlers are removed, the level is set to DEBUG (verbosity is
    governed by FLASHINFER_LOGLEVEL, not by logging levels), and a single
    handler emitting bare messages is attached. That handler writes to
    stdout, to stderr, or appends to the file named by ``_API_LOG_DEST``;
    the file's parent directories are created if missing. Propagation to
    the root logger is disabled.
    """
    if _API_LOG_LEVEL == 0:
        # Completely disable logging for zero overhead
        _logger.addHandler(logging.NullHandler())
        _logger.setLevel(logging.CRITICAL + 1)  # Higher than any level
        return

    # All enabled levels use logging.DEBUG; verbosity is controlled by
    # FLASHINFER_LOGLEVEL instead.
    _logger.setLevel(logging.DEBUG)

    # Remove any existing handlers
    _logger.handlers.clear()

    # Create handler based on destination
    handler: logging.Handler
    if _API_LOG_DEST == "stdout":
        handler = logging.StreamHandler(sys.stdout)
    elif _API_LOG_DEST == "stderr":
        handler = logging.StreamHandler(sys.stderr)
    else:
        # Create missing parent directories (e.g. FLASHINFER_LOGDEST=logs/api_%i.log)
        # instead of failing with FileNotFoundError at import time.
        Path(_API_LOG_DEST).parent.mkdir(parents=True, exist_ok=True)
        handler = logging.FileHandler(_API_LOG_DEST, mode="a")

    # Use a simple formatter (timestamps are added manually to key lines)
    handler.setFormatter(logging.Formatter("%(message)s"))
    _logger.addHandler(handler)
    _logger.propagate = False  # Don't propagate to root logger
# Initialize the "flashinfer.api" logger at module load time, so logging is
# configured from FLASHINFER_LOGLEVEL / FLASHINFER_LOGDEST as soon as this
# module is imported.
_setup_logger()
def _get_timestamp() -> str:
    """Return the local wall-clock time rendered as ``[YYYY-MM-DD HH:MM:SS]``."""
    now = datetime.now()
    return f"[{now:%Y-%m-%d %H:%M:%S}]"
def _warn_dump():
    """
    Print a security banner to stdout when Level 10 (tensor dumping) is active.

    Below Level 10 nothing is printed. Otherwise the banner, framed by two
    80-character ``=`` rules, contains the dump warning, the dump directory,
    a note when SafeTensors mode is on, and any include/exclude filters.
    """
    if _API_LOG_LEVEL < 10:
        return

    rule = "=" * 80
    banner = [
        rule,
        "WARNING: FlashInfer API Logging is set to Level 10 (Tensor Dumping).\n"
        "This will dump ALL input and outputs including tensors for FlashInfer APIs to disk in\n"
        "the configured dump directory. Ensure that you are NOT processing sensitive data\n"
        "or that the dump directory is secure. To disable dumping, unset FLASHINFER_LOGLEVEL or\n"
        "set it to below 10. For more information, see https://docs.flashinfer.ai/logging.html",
        f"Current dump directory is: {_DUMP_DIR}",
    ]
    if _DUMP_SAFETENSORS:
        banner.append(
            "⚠️ SAFETENSORS mode enabled: tensor stride/non-contiguity will NOT be preserved.\n"
            " Tensors will be saved as contiguous. Use torch.save (default) to preserve strides."
        )
    if _DUMP_INCLUDE_PATTERNS:
        banner.append(f"Include filter: {_DUMP_INCLUDE_PATTERNS}")
    if _DUMP_EXCLUDE_PATTERNS:
        banner.append(f"Exclude filter: {_DUMP_EXCLUDE_PATTERNS}")
    banner.append(rule)

    for entry in banner:
        print(entry)
def _should_dump_function(func_name: str) -> bool:
    """
    Check if a function should be dumped based on include/exclude filters.

    Uses fnmatch-style patterns (wildcards: * for any chars, ? for single char).
    Matching is case-sensitive on every platform: ``fnmatch.fnmatchcase`` is
    used because plain ``fnmatch.fnmatch`` normalizes case on Windows.

    Parameters
    ----------
    func_name : str
        The function name to check. For class methods, this is formatted as
        "ClassName.method_name" (e.g., "BatchDecodeWrapper.run").

    Returns
    -------
    bool
        True if the function should be dumped, False otherwise.

    Filter Logic
    ------------
    1. If FLASHINFER_DUMP_INCLUDE is set:
       - Function must match at least one include pattern
       - If it doesn't match any, return False (skip dump)
    2. If FLASHINFER_DUMP_EXCLUDE is set:
       - If function matches any exclude pattern, return False (skip dump)
    3. Otherwise, return True (dump the function)
    """

    def _matches_any(patterns: List[str]) -> bool:
        return any(fnmatch.fnmatchcase(func_name, pat) for pat in patterns)

    # If include patterns are specified, func must match at least one
    if _DUMP_INCLUDE_PATTERNS and not _matches_any(_DUMP_INCLUDE_PATTERNS):
        return False
    # If exclude patterns are specified, func must not match any
    if _DUMP_EXCLUDE_PATTERNS and _matches_any(_DUMP_EXCLUDE_PATTERNS):
        return False
    return True
def _append_to_jsonl(filepath: Path, record: Dict[str, Any]) -> None:
    """
    Write ``record`` as one single-line JSON document at the end of ``filepath``.

    The file is opened in append mode (and created if missing), so earlier
    records are never rewritten.

    Parameters
    ----------
    filepath : Path
        JSONL file to extend.
    record : Dict[str, Any]
        JSON-serializable mapping, encoded with ``json.dumps`` defaults.
    """
    with open(filepath, "a") as sink:
        print(json.dumps(record), file=sink)
def _read_jsonl_last_record(filepath: Path) -> Optional[Dict[str, Any]]:
    """
    Read the last valid record from a JSONL file.

    For metadata.jsonl, this returns the most complete state (completed if
    available, otherwise inputs_saved). Blank lines are ignored. If trailing
    lines are not valid JSON (e.g. a record truncated by a crash mid-write),
    they are skipped with a warning and the previous valid record is returned.
    This keeps the "inputs_saved" record of a crashed call usable.

    Parameters
    ----------
    filepath : Path
        Path to the JSONL file

    Returns
    -------
    Optional[Dict[str, Any]]
        The last parseable record, or None if the file doesn't exist or
        contains no parseable record.
    """
    if not filepath.exists():
        return None
    with open(filepath, "r") as f:
        lines = [line.strip() for line in f]
    # Scan from the end so only the tail is parsed in the common case.
    for line in reversed(lines):
        if not line:
            continue
        try:
            return json.loads(line)
        except json.JSONDecodeError:
            _logger.warning(
                f"Skipping malformed JSONL record in {filepath}: {line[:200]!r}"
            )
    return None
def _get_tensor_size_bytes(tensor: torch.Tensor) -> int:
    """Return the byte size of ``tensor``'s elements (element count times bytes per element)."""
    per_element = tensor.element_size()
    return tensor.nelement() * per_element
def _serialize_value(value: Any) -> Any:
    """
    Convert a non-tensor value to a JSON-serializable format for metadata.

    Intended for the non-tensor arguments/return values recorded in API
    input or output metadata. Tensor arguments are not handled here.

    - ``torch.dtype`` -> ``{"type": "torch.dtype", "value": "torch.bfloat16"}``
    - ``enum.Enum`` -> ``{"type": "enum", "name": "Cls.MEMBER", "value": ...}``.
      The member's value is kept as-is when JSON-encodable; otherwise it is
      replaced by its (truncated) ``str``.
    - int/float/str/bool/None -> returned unchanged
    - list/tuple/dict -> ``{"type": ..., "value": str(value)[:1000]}``
    - anything else -> ``{"type": ..., "repr": str(value)[:1000]}``

    If conversion itself raises, ``{"type": ..., "repr": "<not serializable>"}``
    is returned.
    """
    try:
        if isinstance(value, torch.dtype):
            # Special handling for torch.dtype
            return {
                "type": "torch.dtype",
                "value": str(value),  # e.g., "torch.bfloat16"
            }
        elif isinstance(value, enum.Enum):
            enum_value = value.value
            try:
                json.dumps(enum_value)
            except (TypeError, ValueError):
                # The metadata record is json.dumps'ed as a whole later; a
                # non-encodable payload would make that write fail and lose
                # the entire record, so store a string instead.
                enum_value = str(enum_value)[:1000]
            return {
                "type": "enum",
                "name": f"{type(value).__name__}.{value.name}",
                "value": enum_value,
            }
        elif isinstance(value, (int, float, str, bool, type(None))):
            return value
        elif isinstance(value, (list, tuple, dict)):
            return {
                "type": type(value).__name__,
                "value": str(value)[:1000],
            }  # Truncate long structures
        else:
            return {
                "type": type(value).__name__,
                "repr": str(value)[:1000],
            }
    except Exception:
        return {
            "type": type(value).__name__,
            "repr": "<not serializable>",
        }
def _extract_tensors_and_metadata(
    args: tuple, kwargs: dict
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """
    Split call arguments into CPU tensors and JSON-ready metadata.

    Positional arguments are keyed ``"arg_<index>"`` and keyword arguments
    ``"kwarg_<name>"``. Positional entries come first, and each group keeps
    call order. Tensor values are moved to CPU with ``.cpu()``, preserving
    their stride/contiguity. Every other value goes through
    ``_serialize_value``.

    Returns
    -------
    tensors : Dict[str, torch.Tensor]
        Tensor arguments (on CPU), keyed as above.
    metadata : Dict[str, Any]
        Serialized non-tensor arguments, keyed as above.
    """
    labelled = [(f"arg_{idx}", item) for idx, item in enumerate(args)]
    labelled.extend((f"kwarg_{name}", item) for name, item in kwargs.items())

    tensors: Dict[str, torch.Tensor] = {}
    metadata: Dict[str, Any] = {}
    for label, item in labelled:
        if isinstance(item, torch.Tensor):
            tensors[label] = item.cpu()
        else:
            metadata[label] = _serialize_value(item)
    return tensors, metadata
def _describe_signature(func: Callable) -> str:
    """
    Return ``str(inspect.signature(func))``, or ``"<unavailable>"``.

    ``inspect.signature`` raises ValueError when no signature can be provided
    and TypeError for unsupported objects (e.g. some builtin/extension
    callables). A dump must not be lost over a missing signature.
    """
    try:
        return str(inspect.signature(func))
    except (TypeError, ValueError):
        return "<unavailable>"


def _dump_function_inputs(
    func: Callable,
    func_name: str,
    args: tuple,
    kwargs: dict,
    self_id: Optional[int] = None,
) -> Optional[str]:
    """
    Dump function inputs to disk BEFORE execution (crash-safe).

    This function:
    1. Applies the include/exclude filters and the dump-count limit
    2. Checks the cumulative size limit from tensor metadata alone, so an
       over-limit call creates no directory and copies nothing to CPU
    3. Creates a timestamped directory and saves inputs.pt (or
       inputs.safetensors when SafeTensors mode is enabled)
    4. Appends an "inputs_saved" record to the per-dump metadata.jsonl and
       to the central session.jsonl
    5. Updates the cumulative size and count trackers

    Parameters
    ----------
    func : Callable
        The function being called
    func_name : str
        Name of the function
    args : tuple
        Positional arguments
    kwargs : dict
        Keyword arguments
    self_id : Optional[int]
        The id() of the 'self' object if this is a method call

    Returns
    -------
    Optional[str]
        Path to the dump directory, or None if the dump was skipped or failed
        (failures are logged, not raised).
    """
    global _dump_count, _dump_total_size_bytes

    # Check include/exclude filters first (before any work is done)
    if not _should_dump_function(func_name):
        _logger.debug(
            f"Skipping dump for {func_name} (filtered by include/exclude patterns)"
        )
        return None

    if _dump_count >= _DUMP_MAX_COUNT:
        _logger.warning(
            f"Dump limit reached ({_DUMP_MAX_COUNT} dumps). Skipping dump for {func_name}. "
            f"Increase FLASHINFER_DUMP_MAX_COUNT if needed."
        )
        return None

    try:
        # Per-function call counter (part of the dump directory name)
        call_seq = _dump_call_counter.get(func_name, 0) + 1
        _dump_call_counter[func_name] = call_seq

        # Dump directory: <timestamp with milliseconds>_pid<pid>_<func>_call<seq>
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
        pid = os.getpid()
        dump_name = f"{timestamp}_pid{pid}_{func_name}_call{call_seq:04d}"
        dump_dir = Path(_DUMP_DIR) / dump_name

        # Size check (inputs only) from tensor metadata. Doing it before
        # mkdir/.cpu() avoids creating and then removing a directory, and
        # avoids device-to-host copies of tensors that would be discarded.
        input_size = sum(
            _get_tensor_size_bytes(v)
            for v in (*args, *kwargs.values())
            if isinstance(v, torch.Tensor)
        )
        max_size_bytes = _DUMP_MAX_SIZE_GB * 1024 * 1024 * 1024
        if _dump_total_size_bytes + input_size > max_size_bytes:
            _logger.warning(
                f"Dump size limit reached ({_DUMP_MAX_SIZE_GB} GB). Skipping dump for {func_name}. "
                f"Increase FLASHINFER_DUMP_MAX_SIZE_GB if needed."
            )
            return None

        dump_dir.mkdir(parents=True, exist_ok=True)

        # Extract tensors (moved to CPU) and non-tensor metadata from inputs
        input_tensors, input_metadata = _extract_tensors_and_metadata(args, kwargs)

        # Save input tensors
        if input_tensors:
            if _DUMP_SAFETENSORS:
                # SafeTensors format: faster, no pickle, but loses stride/contiguity
                try:
                    from safetensors.torch import save_file

                    # safetensors requires contiguous tensors
                    tensors_contiguous = {
                        k: v.contiguous() for k, v in input_tensors.items()
                    }
                    save_file(tensors_contiguous, str(dump_dir / "inputs.safetensors"))
                except ImportError:
                    _logger.error(
                        "safetensors package not installed. "
                        "Install with: pip install safetensors"
                    )
                    raise
            else:
                # torch.save format: preserves stride/contiguity
                torch.save(input_tensors, dump_dir / "inputs.pt")

        # Create partial metadata (inputs only, outputs will be added later)
        metadata: Dict[str, Any] = {
            "function_name": func_name,
            "module": getattr(func, "__module__", "<unknown>"),
            "call_sequence": call_seq,
            "timestamp": timestamp,
            "process_id": os.getpid(),
            "input_metadata": input_metadata,
            "output_metadata": {},  # Placeholder, will be updated after execution
            "tensor_info": {
                "input_tensor_keys": list(input_tensors.keys()),
                "output_tensor_keys": [],  # Placeholder, will be updated after execution
                "input_size_bytes": input_size,
                "input_size_mb": input_size / (1024 * 1024),
            },
            "tensor_details": {},  # Detailed shape/dtype/stride info for reconstruction
            "tensor_format": "safetensors" if _DUMP_SAFETENSORS else "torch",
            "function_signature": _describe_signature(func),
            "versions": {
                "torch": torch.__version__,
                "python": sys.version,
            },
            "execution_status": "inputs_saved",  # Will be updated to "completed" after outputs
        }

        # Add self_id to metadata if it is a class method call
        if self_id is not None:
            metadata["self_id"] = self_id

        # Add tensor details for random generation fallback
        for key, tensor in input_tensors.items():
            metadata["tensor_details"][key] = {
                "shape": list(tensor.shape),
                "dtype": str(tensor.dtype),
                "stride": list(tensor.stride()),
                "device": str(tensor.device),
            }

        # Try to get FlashInfer version
        try:
            from .version import __version__ as flashinfer_version

            metadata["versions"]["flashinfer"] = flashinfer_version  # type: ignore[index]
        except Exception:
            metadata["versions"]["flashinfer"] = "<unavailable>"  # type: ignore[index]

        # Add dump_dir to metadata for central session.jsonl reference
        metadata["dump_dir"] = str(dump_dir)

        # Save metadata to per-dump JSONL (first line: inputs_saved)
        _append_to_jsonl(dump_dir / "metadata.jsonl", metadata)

        # Append to central session.jsonl for quick scanning
        session_jsonl_path = Path(_DUMP_DIR) / "session.jsonl"
        _append_to_jsonl(session_jsonl_path, metadata)

        # Update global tracking (only input size for now)
        _dump_count += 1
        _dump_total_size_bytes += input_size

        _logger.debug(
            f"Dumped inputs to: {dump_dir} "
            f"(size: {input_size / (1024 * 1024):.2f} MB, "
            f"total: {_dump_count}/{_DUMP_MAX_COUNT} dumps)"
        )
        return str(dump_dir)
    except Exception as e:
        _logger.error(f"Failed to dump function call {func_name}: {e}")
        import traceback

        _logger.error(traceback.format_exc())
        return None
def _dump_function_outputs(dump_dir: str, result: Any) -> None:
    """
    Add function outputs to an existing dump directory (crash-safe).

    Called AFTER successful execution to complete the dump that
    ``_dump_function_inputs`` created before execution:

    - a tensor result is stored under key ``"result"``;
    - a tuple or list result is flattened element-wise. Tensor elements are
      stored as ``"result_{i}"``; other elements go to
      ``output_metadata["result_{i}"]``;
    - any other result is serialized into ``output_metadata["result"]``.

    Tensors are written to outputs.pt (or outputs.safetensors). A
    "completed" record is then appended to the per-dump metadata.jsonl and
    to the central session.jsonl; it is the last metadata.jsonl record
    updated with output info. Errors are logged, not raised.

    Parameters
    ----------
    dump_dir : str
        Path to the dump directory created by _dump_function_inputs
    result : Any
        Function return value
    """
    global _dump_total_size_bytes

    try:
        dump_path = Path(dump_dir)
        if not dump_path.exists():
            _logger.error(f"Dump directory not found: {dump_dir}")
            return

        # Extract tensors and metadata from outputs
        output_tensors = {}
        output_metadata = {}
        if isinstance(result, torch.Tensor):
            output_tensors["result"] = result.cpu()
        elif isinstance(result, (tuple, list)):
            # Lists are flattened like tuples: replay_from_dump flattens both
            # kinds of execution result, so both must be saved comparably.
            for i, item in enumerate(result):
                if isinstance(item, torch.Tensor):
                    output_tensors[f"result_{i}"] = item.cpu()
                else:
                    output_metadata[f"result_{i}"] = _serialize_value(item)
        else:
            output_metadata["result"] = _serialize_value(result)

        # Calculate output size
        output_size = sum(_get_tensor_size_bytes(t) for t in output_tensors.values())

        # Save output tensors
        if output_tensors:
            if _DUMP_SAFETENSORS:
                # SafeTensors format: faster, no pickle, but loses stride/contiguity
                from safetensors.torch import save_file

                tensors_contiguous = {
                    k: v.contiguous() for k, v in output_tensors.items()
                }
                save_file(tensors_contiguous, str(dump_path / "outputs.safetensors"))
            else:
                # torch.save format: preserves stride/contiguity
                torch.save(output_tensors, dump_path / "outputs.pt")

        # Load existing metadata from JSONL (last record) and update it
        metadata_jsonl_path = dump_path / "metadata.jsonl"
        metadata = _read_jsonl_last_record(metadata_jsonl_path)
        if metadata is not None:
            # Update with output information
            tensor_info = metadata["tensor_info"]
            metadata["output_metadata"] = output_metadata
            tensor_info["output_tensor_keys"] = list(output_tensors.keys())
            tensor_info["output_size_bytes"] = output_size
            tensor_info["output_size_mb"] = output_size / (1024 * 1024)
            tensor_info["total_size_bytes"] = tensor_info["input_size_bytes"] + output_size
            tensor_info["total_size_mb"] = tensor_info["total_size_bytes"] / (1024 * 1024)
            metadata["execution_status"] = "completed"

            # Add output tensor details
            if "tensor_details" not in metadata:
                metadata["tensor_details"] = {}
            for key, tensor in output_tensors.items():
                metadata["tensor_details"][key] = {
                    "shape": list(tensor.shape),
                    "dtype": str(tensor.dtype),
                    "stride": list(tensor.stride()),
                    "device": str(tensor.device),
                }

            # Append completion record to per-dump JSONL
            _append_to_jsonl(metadata_jsonl_path, metadata)

            # Append completion record to central session.jsonl
            session_jsonl_path = Path(_DUMP_DIR) / "session.jsonl"
            _append_to_jsonl(session_jsonl_path, metadata)

            # Update global size tracking
            _dump_total_size_bytes += output_size

            _logger.debug(
                f"Dumped outputs to: {dump_dir} "
                f"(output size: {output_size / (1024 * 1024):.2f} MB, "
                f"total dump size: {metadata['tensor_info']['total_size_mb']:.2f} MB)"
            )
        else:
            _logger.error(f"metadata.jsonl not found or empty in {dump_dir}")
    except Exception as e:
        _logger.error(f"Failed to dump outputs to {dump_dir}: {e}")
        import traceback

        _logger.error(traceback.format_exc())
def _reconstruct_value(value: Any) -> Any:
    """
    Reconstruct special types from their metadata form.

    Currently handles ``{"type": "torch.dtype", "value": "torch.<name>"}``
    records (as produced by ``_serialize_value``) and returns the matching
    ``torch.dtype``. If the name does not resolve to a ``torch.dtype``, a
    warning is logged and the record is returned unchanged. Enum records
    (reconstruction not implemented yet) and all other values are returned
    as-is.
    """
    if not (isinstance(value, dict) and value.get("type") == "torch.dtype"):
        return value

    dtype_str = value.get("value", "")
    if not isinstance(dtype_str, str):
        _logger.warning(f"Could not reconstruct dtype: {dtype_str}")
        return value
    # Strip only the leading "torch." qualifier, e.g. "torch.bfloat16" -> "bfloat16"
    prefix = "torch."
    dtype_name = dtype_str[len(prefix):] if dtype_str.startswith(prefix) else dtype_str
    candidate = getattr(torch, dtype_name, None) if dtype_name else None
    # Guard against names that exist on torch but are not dtypes (e.g. "torch.zeros")
    if isinstance(candidate, torch.dtype):
        return candidate
    _logger.warning(f"Could not reconstruct dtype: {dtype_str}")
    return value
def _resolve_function(module_name: str, function_name: str) -> Optional[Callable]:
    """
    Look up ``function_name`` inside the importable module ``module_name``.

    ``function_name`` may be dotted (e.g. ``"Wrapper.run"``); each component
    is resolved with ``getattr`` in turn. If the import or any lookup fails,
    a warning is logged and None is returned. If the target is found but is
    not callable, None is returned without a warning.
    """
    try:
        target: Any = importlib.import_module(module_name)
        for attr in function_name.split("."):
            target = getattr(target, attr)
    except Exception as e:
        _logger.warning(
            f"Could not resolve function {module_name}.{function_name}: {e}"
        )
        return None
    return target if callable(target) else None
def _max_abs_diff(actual: torch.Tensor, expected: torch.Tensor) -> float:
    """
    Return max |actual - expected| for reporting a tensor mismatch.

    Real-valued inputs are promoted to float64 first. Bool tensors (which do
    not support ``-``), unsigned integers (which would wrap around) and
    low-precision floats then all yield a meaningful number.
    """
    if actual.is_complex():
        return (actual - expected).abs().max().item()
    return (actual.double() - expected.double()).abs().max().item()


def _compare_results(
    actual: Any, expected: Any, rtol: float = 1e-3, atol: float = 1e-3
) -> bool:
    """
    Recursively compare a replayed result against the recorded one.

    - Tensors: shape and dtype must match exactly. Values are then compared
      with ``torch.allclose(rtol=rtol, atol=atol)``. 1-byte float (float8)
      tensors are upcast to float32 first.
    - list/tuple: same length and element-wise equal (a list and a tuple
      holding equal elements compare equal).
    - dict: same key set and value-wise equal.
    - anything else (including None): ``==``.

    The first mismatch found is logged as a warning and False is returned.
    """
    # torch.Tensor comparison
    if isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor):
        # Check shape
        if actual.shape != expected.shape:
            _logger.warning(
                f"Shape mismatch: actual {actual.shape} vs expected {expected.shape}"
            )
            return False
        # Check dtype
        if actual.dtype != expected.dtype:
            _logger.warning(
                f"Dtype mismatch: actual {actual.dtype} vs expected {expected.dtype}"
            )
            return False
        if actual.is_floating_point() and actual.element_size() == 1:
            # float8 variants lack most arithmetic kernels in PyTorch, so
            # allclose may raise on them; compare their exact float32 upcasts.
            actual, expected = actual.float(), expected.float()
        # Check values; apply relative and absolute tolerance.
        if not torch.allclose(actual, expected, rtol=rtol, atol=atol):
            diff = _max_abs_diff(actual, expected)
            _logger.warning(f"Value mismatch: max diff {diff}")
            return False
        return True

    # list/tuple comparison
    elif isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)):
        if len(actual) != len(expected):
            _logger.warning(
                f"Length mismatch: actual {len(actual)} vs expected {len(expected)}"
            )
            return False
        # Lengths are equal here, so a plain zip pairs every element (no need
        # for the Python 3.10-only strict=True).
        return all(
            _compare_results(a, e, rtol, atol) for a, e in zip(actual, expected)
        )

    # dict comparison
    elif isinstance(actual, dict) and isinstance(expected, dict):
        if actual.keys() != expected.keys():
            _logger.warning(
                f"Key mismatch: actual {actual.keys()} vs expected {expected.keys()}"
            )
            return False
        return all(_compare_results(actual[k], expected[k], rtol, atol) for k in actual)

    # fallback for other types (including None). Just do a naive comparison.
    else:
        if actual != expected:
            _logger.warning(f"Value mismatch: actual {actual} vs expected {expected}")
            return False
        return True
def _load_dumped_tensors(
    pt_path: Path, safetensors_path: Path, device: str
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load a dumped tensor dict, auto-detecting torch.save vs SafeTensors format.

    ``pt_path`` is preferred when both exist. Tensors are loaded on CPU and
    then moved to ``device``. Returns None when neither file exists. Raises
    ImportError if only the SafeTensors file exists and the ``safetensors``
    package is not installed.
    """
    if pt_path.exists():
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch's weights_only default; only replay trusted dumps.
        tensors = torch.load(str(pt_path), map_location="cpu")
    elif safetensors_path.exists():
        try:
            from safetensors.torch import load_file
        except ImportError:
            raise ImportError(
                "Dump was saved with safetensors but package not installed. "
                "Install with: pip install safetensors"
            ) from None
        tensors = load_file(str(safetensors_path), device="cpu")
    else:
        return None
    return {key: tensor.to(device) for key, tensor in tensors.items()}


def replay_from_dump(
    dump_dir: str,
    compare_outputs: bool = False,
    device: str = "cuda",
    run: bool = False,
    object_registry: Optional[Dict[Tuple[int, int], Any]] = None,
) -> Any:
    """
    Replay a function call from a dumped directory.

    This function:
    1. Loads metadata.jsonl (last record) to get function info
    2. Loads inputs.pt or inputs.safetensors to get input tensors
    3. Moves tensors to specified device (default: cuda)
    4. Reconstructs the positional and keyword arguments
    5. Optionally executes the function (if run=True)
    6. Optionally compares with saved outputs

    Parameters
    ----------
    dump_dir : str
        Path to the dump directory
    compare_outputs : bool
        If True, load and compare with saved outputs
    device : str
        Target device for tensors. Options:
        - "cuda" (default): Load to cuda:0
        - "cpu": Load to CPU
        - "cuda:N": Load to specific CUDA device
    run : bool
        If True, try to resolve and execute the function
    object_registry : Optional[Dict[Tuple[int, int], Any]]
        Registry of stateful objects mapped by (process_id, self_id) tuple.
        This composite key ensures objects from different processes don't collide
        in multi-GPU environments where different processes may have objects
        at the same memory address.

    Returns
    -------
    result : dict
        Dictionary containing:
        - 'args': Positional arguments (tensors on specified device)
        - 'kwargs': Keyword arguments (tensors on specified device)
        - 'metadata': Full metadata
        - 'execution_result': Result of execution (if run=True)
        - 'comparison_match': Boolean indicating if result matched expected (if run=True and compare_outputs=True)
        If compare_outputs=True, also includes:
        - 'expected_tensors': Expected output tensors
        - 'expected_metadata': Expected output metadata

    Raises
    ------
    FileNotFoundError
        If the directory, metadata.jsonl, or the input tensor file is missing.
    ValueError
        If metadata.jsonl contains no record.
    ImportError
        If a dump was saved with safetensors and the package is missing.
    """
    dump_path = Path(dump_dir)
    if not dump_path.exists():
        raise FileNotFoundError(f"Dump directory not found: {dump_dir}")

    # Load metadata from JSONL (last record has most complete state)
    metadata_jsonl_path = dump_path / "metadata.jsonl"
    if not metadata_jsonl_path.exists():
        raise FileNotFoundError(f"metadata.jsonl not found in {dump_dir}")
    metadata = _read_jsonl_last_record(metadata_jsonl_path)
    if metadata is None:
        raise ValueError(f"metadata.jsonl is empty in {dump_dir}")
    func_name = metadata["function_name"]

    # Load input tensors - auto-detect format (torch.save or safetensors)
    input_tensors = _load_dumped_tensors(
        dump_path / "inputs.pt", dump_path / "inputs.safetensors", device
    )
    if input_tensors is None:
        raise FileNotFoundError(
            f"Neither inputs.pt nor inputs.safetensors found in {dump_dir}"
        )

    # Reconstruct args and kwargs
    args: List[Any] = []
    kwargs: Dict[str, Any] = {}
    input_metadata = metadata.get("input_metadata", {})

    # Get max arg index from both tensors and metadata
    max_arg_idx = -1
    for key in (*input_tensors.keys(), *input_metadata.keys()):
        if key.startswith("arg_"):
            max_arg_idx = max(max_arg_idx, int(key.split("_")[1]))

    # Reconstruct positional args in order so that we can replay
    # the function call exactly as it was logged.
    for i in range(max_arg_idx + 1):
        key = f"arg_{i}"
        if key in input_tensors:
            args.append(input_tensors[key])
        elif key in input_metadata:
            args.append(_reconstruct_value(input_metadata[key]))
        else:
            # Should not happen if dump is consistent, but safeguard
            _logger.warning(f"Missing argument {i} in dump.")
            args.append(None)

    # Add keyword arguments (ordering is not important). Strip only the
    # leading "kwarg_" tag: str.replace would also remove it from inside
    # names such as "my_kwarg_flag".
    kwarg_prefix = "kwarg_"
    for key, tensor in input_tensors.items():
        if key.startswith(kwarg_prefix):
            kwargs[key[len(kwarg_prefix):]] = tensor
    for key, value in input_metadata.items():
        if key.startswith(kwarg_prefix):
            kwarg_name = key[len(kwarg_prefix):]
            if kwarg_name not in kwargs:  # Don't override tensor kwargs
                kwargs[kwarg_name] = _reconstruct_value(value)

    _logger.info(f"Replaying {func_name} from {dump_dir}")
    _logger.info(f" Args: {len(args)}, Kwargs: {list(kwargs.keys())}")

    result_dict: Dict[str, Any] = {"args": args, "kwargs": kwargs, "metadata": metadata}

    # Load expected outputs if needed - auto-detect format
    expected_outputs: Dict[str, torch.Tensor] = {}
    output_metadata: Dict[str, Any] = {}
    if compare_outputs:
        expected_outputs = (
            _load_dumped_tensors(
                dump_path / "outputs.pt", dump_path / "outputs.safetensors", device
            )
            or {}
        )
        output_metadata = metadata.get("output_metadata", {})
        result_dict["expected_tensors"] = expected_outputs
        result_dict["expected_metadata"] = output_metadata

    if run:
        module_name = metadata.get("module")
        self_id = metadata.get("self_id")
        process_id = metadata.get("process_id")
        func = None
        obj = None

        # Stateful replay logic for class methods calls.
        # Necessary for wrapped classes like BatchDecodeWithPagedKVCacheWrapper.
        # Use (process_id, self_id) as composite key to avoid collisions across processes.
        # In multi-GPU environments, different processes may have objects with the same
        # memory address (self_id), so we need to scope by process_id.
        if self_id is not None:
            registry_key = (process_id, self_id)

            if func_name.endswith(".__init__"):
                # Constructor call: resolve the class keeping any outer
                # qualifiers ("Outer.Wrapper.__init__" -> "Outer.Wrapper";
                # _resolve_function walks dotted names) and instantiate it.
                class_name = func_name[: -len(".__init__")]
                cls_obj = _resolve_function(module_name, class_name)
                if cls_obj and callable(cls_obj):
                    # args[0] is the 'self' placeholder recorded for __init__; skip it
                    real_args = args[1:]
                    try:
                        _logger.info(
                            f"Instantiating {class_name} (PID: {process_id}, ID: {self_id})..."
                        )
                        # We assume the logged args match the constructor args.
                        obj = cls_obj(*real_args, **kwargs)
                        if object_registry is not None:
                            object_registry[registry_key] = obj

                        # __init__ returns None; the instance lives in the registry
                        result_dict["execution_result"] = None
                        if compare_outputs:
                            result_dict["comparison_match"] = True  # Trivial pass for __init__
                        return result_dict
                    except Exception as e:
                        _logger.error(f"Failed to instantiate {class_name}: {e}")
                        result_dict["execution_error"] = str(e)
                        return result_dict
            else:
                # Instance method call
                if object_registry is not None and registry_key in object_registry:
                    obj = object_registry[registry_key]
                    method_name = func_name.split(".")[-1]
                    if hasattr(obj, method_name):
                        func = getattr(obj, method_name)
                        # args[0] is 'self' placeholder, skip it
                        args = args[1:]
                    else:
                        _logger.warning(f"Object {obj} has no method {method_name}")
                else:
                    _logger.warning(
                        f"Object (PID: {process_id}, ID: {self_id}) not found in registry."
                    )

        if func is None:
            func = _resolve_function(module_name, func_name)

        if func:
            try:
                _logger.info(f"Executing {module_name}.{func_name}...")
                execution_result = func(*args, **kwargs)
                result_dict["execution_result"] = execution_result

                if compare_outputs:
                    # Flatten execution result to dict for comparison
                    actual_outputs = {}
                    if isinstance(execution_result, torch.Tensor):
                        actual_outputs["result"] = execution_result
                    elif isinstance(execution_result, (tuple, list)):
                        for i, item in enumerate(execution_result):
                            if isinstance(item, torch.Tensor):
                                actual_outputs[f"result_{i}"] = item
                    elif isinstance(execution_result, dict):
                        # If result is already a dict of tensors? Unlikely for FlashInfer but possible
                        actual_outputs = execution_result

                    # Compare tensors
                    match = True
                    if expected_outputs:
                        match = _compare_results(actual_outputs, expected_outputs)

                    result_dict["comparison_match"] = match
                    if match:
                        _logger.info("Replay comparison passed!")
                    else:
                        _logger.warning("Replay comparison FAILED.")

            except Exception as e:
                _logger.error(f"Execution failed: {e}")
                import traceback

                _logger.error(traceback.format_exc())
                result_dict["execution_error"] = str(e)
        else:
            _logger.warning(
                f"Skipping execution: could not resolve {module_name}.{func_name}"
            )
    elif not compare_outputs:
        _logger.warning(
            "Automatic function resolution disabled. "
            "Pass run=True to execute, or manually call function."
        )

    return result_dict
def replay_sequence(root_dir: str, device: str = "cuda") -> list:
"""
Replay a sequence of API calls from a root dump directory.
This function iterates through all dump directories in the root directory,
sorted by timestamp/sequence number, and replays them in order.
Parameters
----------
root_dir : str
Path to the root directory containing dump subdirectories
device : str
Target device for execution (default: "cuda")