-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path_debug.py
More file actions
51 lines (41 loc) · 1.47 KB
/
_debug.py
File metadata and controls
51 lines (41 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import logging
import re
import sys
from typing import List
from sklearn._callbacks import BaseCallback
class DebugCallback(BaseCallback):
    """Callback that records the events it receives as text messages.

    Every message is appended to ``self.log``; when ``verbose`` is true it
    is also emitted at INFO level through the ``"sklearn"`` logger, which
    this class configures to write to ``sys.stdout``.

    Parameters
    ----------
    verbose : bool, default=True
        Whether recorded messages are also sent to the ``"sklearn"`` logger.

    Attributes
    ----------
    log : list of str
        Messages recorded so far, in the order they were added.
    """

    # Attribute set on handlers installed by this class so that later
    # instances can recognise and reuse them instead of stacking duplicates.
    _HANDLER_MARKER = "_sklearn_debug_callback_handler"

    def __init__(self, verbose=True):
        self.verbose = verbose
        self.formatter = logging.Formatter(
            fmt="%(asctime)s %(levelname)-8s %(message)s",
        )
        self.log = []
        # NOTE: "sklearn" is a shared, process-wide logger; the level set
        # here applies to every user of that logger, not only this callback.
        self.logger = logging.getLogger("sklearn")
        self.logger.setLevel(logging.DEBUG)
        self.handler = self._get_stdout_handler()

    def _get_stdout_handler(self):
        """Return a handler on ``self.logger`` that writes to ``sys.stdout``.

        A handler previously installed by this class and still bound to the
        current ``sys.stdout`` is reused; otherwise a new one is created and
        attached. Without this, every instance would add one more handler to
        the shared logger and each message would be printed once per
        instance ever created.
        """
        for existing in self.logger.handlers:
            if (
                getattr(existing, self._HANDLER_MARKER, False)
                and getattr(existing, "stream", None) is sys.stdout
            ):
                existing.setFormatter(self.formatter)
                return existing

        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setFormatter(self.formatter)
        setattr(handler, self._HANDLER_MARKER, True)
        self.logger.addHandler(handler)
        return handler

    def add_message(self, msg):
        """Record ``msg`` in ``self.log`` and, if verbose, log it at INFO level.

        Parameters
        ----------
        msg : str
            Message to record.
        """
        self.log.append(msg)
        if self.verbose:
            self.logger.info(msg)

    def on_fit_begin(self, estimator, X, y):
        """Record a ``"fit_begin"`` message containing ``str(estimator)``.

        ``X`` and ``y`` are accepted but not used.
        Presumably invoked by the callback machinery of ``BaseCallback``
        when fitting starts -- confirm against ``sklearn._callbacks``.
        """
        self.add_message("fit_begin " + str(estimator))

    def check_log_expected(self, log: List[str]):
        """Check that the recorded log matches expected values.

        Parameters
        ----------
        log
            list of regexp with the expected lines for each log entry.
            Each pattern is applied with ``re.match``, i.e. it only has to
            match at the beginning of the corresponding entry.

        Raises
        ------
        AssertionError
            If the number of recorded entries differs from ``len(log)`` or
            if any entry does not match its expected pattern.
        """
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; otherwise ``zip`` below would silently truncate.
        if len(self.log) != len(log):
            raise AssertionError(
                f"Expected {len(log)} log entries, got {len(self.log)}."
            )
        for val, expected in zip(self.log, log):
            if not re.match(expected, val):
                raise AssertionError(
                    f"Expected regexp {expected} does not match '{val}'."
                )

    def on_iter_end(self, **kwargs):
        """Record an ``"iter_end"`` message listing all keyword arguments.

        The arguments are rendered as comma-separated ``key=value`` pairs in
        the order they were passed.
        """
        self.add_message(
            "iter_end "
            + ", ".join(f"{key}={val}" for key, val in kwargs.items())
        )