|
9 | 9 | import logging |
10 | 10 | import os |
11 | 11 | import redis |
| 12 | +from six.moves import queue |
12 | 13 | import sys |
13 | 14 | import threading |
14 | 15 | import time |
|
68 | 69 | logger = logging.getLogger(__name__) |
69 | 70 |
|
70 | 71 |
|
71 | | -# Visible for testing. |
72 | | -def _unhandled_error_handler(e: Exception): |
73 | | - logger.error("Unhandled error (suppress with " |
74 | | - "RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e)) |
75 | | - |
76 | | - |
77 | 72 | class Worker: |
78 | 73 | """A class used to define the control flow of a worker process. |
79 | 74 |
|
@@ -282,14 +277,6 @@ def put_object(self, value, object_ref=None): |
282 | 277 | self.core_worker.put_serialized_object( |
283 | 278 | serialized_value, object_ref=object_ref)) |
284 | 279 |
|
285 | | - def raise_errors(self, data_metadata_pairs, object_refs): |
286 | | - context = self.get_serialization_context() |
287 | | - out = context.deserialize_objects(data_metadata_pairs, object_refs) |
288 | | - if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ: |
289 | | - return |
290 | | - for e in out: |
291 | | - _unhandled_error_handler(e) |
292 | | - |
293 | 280 | def deserialize_objects(self, data_metadata_pairs, object_refs): |
294 | 281 | context = self.get_serialization_context() |
295 | 282 | return context.deserialize_objects(data_metadata_pairs, object_refs) |
@@ -876,6 +863,13 @@ def custom_excepthook(type, value, tb): |
876 | 863 |
|
877 | 864 | sys.excepthook = custom_excepthook |
878 | 865 |
|
| 866 | +# The last time ray.get() encountered a RayError in this process. We use this |
| 867 | +# value to suppress redundant error messages pushed from the workers. |
| 868 | +last_task_error_raise_time = 0 |
| 869 | + |
| 870 | +# The maximum number of seconds to wait before printing out an uncaught error. |
| 871 | +UNCAUGHT_ERROR_GRACE_PERIOD = 5 |
| 872 | + |
879 | 873 |
|
880 | 874 | def print_logs(redis_client, threads_stopped, job_id): |
881 | 875 | """Prints log messages from workers on all of the nodes. |
@@ -1026,14 +1020,51 @@ def color_for(data: Dict[str, str]) -> str: |
1026 | 1020 | file=print_file) |
1027 | 1021 |
|
1028 | 1022 |
|
1029 | | -def listen_error_messages_raylet(worker, threads_stopped): |
def print_error_messages_raylet(task_error_queue, threads_stopped):
    """Drain queued worker errors and log them after a grace period.

    Each queue item is an ``(error, timestamp)`` pair. An item is held until
    ``UNCAUGHT_ERROR_GRACE_PERIOD`` seconds have passed since its timestamp
    (or until a stop is signalled), and is then logged at debug level if it
    falls within the grace period of ``last_task_error_raise_time``, or at
    error level otherwise.

    Args:
        task_error_queue (queue.Queue): A queue used to receive errors from
            the thread that listens to Redis.
        threads_stopped (threading.Event): A threading event used to signal
            to the thread that it should exit.
    """
    # Keep polling until we are told to stop.
    while not threads_stopped.is_set():
        try:
            message, pushed_at = task_error_queue.get(block=False)
        except queue.Empty:
            # Nothing pending; nap briefly (wakes early on a stop signal).
            threads_stopped.wait(timeout=0.01)
            continue

        # Hold the message back for the grace period so that an error raised
        # in this process in the meantime can mark it as redundant.
        while time.time() < pushed_at + UNCAUGHT_ERROR_GRACE_PERIOD:
            threads_stopped.wait(timeout=1)
            if threads_stopped.is_set():
                # Stop waiting; the message is still reported once below.
                break

        is_redundant = (
            pushed_at <
            last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD)
        if is_redundant:
            logger.debug(f"Suppressing error from worker: {message}")
        else:
            logger.error(f"Possible unhandled error from worker: {message}")
| 1056 | + |
| 1057 | + |
| 1058 | +def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): |
1030 | 1059 | """Listen to error messages in the background on the driver. |
1031 | 1060 |
|
1032 | 1061 | This runs in a separate thread on the driver and pushes (error, time) |
1033 | 1062 | tuples to the output queue. |
1034 | 1063 |
|
1035 | 1064 | Args: |
1036 | 1065 | worker: The worker class that this thread belongs to. |
| 1066 | + task_error_queue (queue.Queue): A queue used to communicate with the |
| 1067 | + thread that prints the errors found by this thread. |
1037 | 1068 | threads_stopped (threading.Event): A threading event used to signal to |
1038 | 1069 | the thread that it should exit. |
1039 | 1070 | """ |
@@ -1072,9 +1103,8 @@ def listen_error_messages_raylet(worker, threads_stopped): |
1072 | 1103 |
|
1073 | 1104 | error_message = error_data.error_message |
1074 | 1105 | if (error_data.type == ray_constants.TASK_PUSH_ERROR): |
1075 | | - # TODO(ekl) remove task push errors entirely now that we have |
1076 | | - # the separate unhandled exception handler. |
1077 | | - pass |
| 1106 | + # Delay it a bit to see if we can suppress it |
| 1107 | + task_error_queue.put((error_message, time.time())) |
1078 | 1108 | else: |
1079 | 1109 | logger.warning(error_message) |
1080 | 1110 | except (OSError, redis.exceptions.ConnectionError) as e: |
@@ -1237,12 +1267,19 @@ def connect(node, |
1237 | 1267 | # temporarily using this implementation which constantly queries the |
1238 | 1268 | # scheduler for new error messages. |
1239 | 1269 | if mode == SCRIPT_MODE: |
| 1270 | + q = queue.Queue() |
1240 | 1271 | worker.listener_thread = threading.Thread( |
1241 | 1272 | target=listen_error_messages_raylet, |
1242 | 1273 | name="ray_listen_error_messages", |
1243 | | - args=(worker, worker.threads_stopped)) |
| 1274 | + args=(worker, q, worker.threads_stopped)) |
| 1275 | + worker.printer_thread = threading.Thread( |
| 1276 | + target=print_error_messages_raylet, |
| 1277 | + name="ray_print_error_messages", |
| 1278 | + args=(q, worker.threads_stopped)) |
1244 | 1279 | worker.listener_thread.daemon = True |
1245 | 1280 | worker.listener_thread.start() |
| 1281 | + worker.printer_thread.daemon = True |
| 1282 | + worker.printer_thread.start() |
1246 | 1283 | if log_to_driver: |
1247 | 1284 | global_worker_stdstream_dispatcher.add_handler( |
1248 | 1285 | "ray_print_logs", print_to_stdstream) |
@@ -1295,6 +1332,8 @@ def disconnect(exiting_interpreter=False): |
1295 | 1332 | worker.import_thread.join_import_thread() |
1296 | 1333 | if hasattr(worker, "listener_thread"): |
1297 | 1334 | worker.listener_thread.join() |
| 1335 | + if hasattr(worker, "printer_thread"): |
| 1336 | + worker.printer_thread.join() |
1298 | 1337 | if hasattr(worker, "logger_thread"): |
1299 | 1338 | worker.logger_thread.join() |
1300 | 1339 | worker.threads_stopped.clear() |
@@ -1406,11 +1445,13 @@ def get(object_refs, *, timeout=None): |
1406 | 1445 | raise ValueError("'object_refs' must either be an object ref " |
1407 | 1446 | "or a list of object refs.") |
1408 | 1447 |
|
| 1448 | + global last_task_error_raise_time |
1409 | 1449 | # TODO(ujvl): Consider how to allow user to retrieve the ready objects. |
1410 | 1450 | values, debugger_breakpoint = worker.get_objects( |
1411 | 1451 | object_refs, timeout=timeout) |
1412 | 1452 | for i, value in enumerate(values): |
1413 | 1453 | if isinstance(value, RayError): |
| 1454 | + last_task_error_raise_time = time.time() |
1414 | 1455 | if isinstance(value, ray.exceptions.ObjectLostError): |
1415 | 1456 | worker.core_worker.dump_object_store_memory_usage() |
1416 | 1457 | if isinstance(value, RayTaskError): |
|
0 commit comments