|
9 | 9 | import logging |
10 | 10 | import os |
11 | 11 | import redis |
| 12 | +from six.moves import queue |
12 | 13 | import sys |
13 | 14 | import threading |
14 | 15 | import time |
|
68 | 69 | logger = logging.getLogger(__name__) |
69 | 70 |
|
70 | 71 |
|
71 | | -# Visible for testing. |
72 | | -def _unhandled_error_handler(e: Exception): |
73 | | - logger.error("Unhandled error (suppress with " |
74 | | - "RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e)) |
75 | | - |
76 | | - |
77 | 72 | class Worker: |
78 | 73 | """A class used to define the control flow of a worker process. |
79 | 74 |
|
@@ -282,14 +277,6 @@ def put_object(self, value, object_ref=None): |
282 | 277 | self.core_worker.put_serialized_object( |
283 | 278 | serialized_value, object_ref=object_ref)) |
284 | 279 |
|
285 | | - def raise_errors(self, data_metadata_pairs, object_refs): |
286 | | - context = self.get_serialization_context() |
287 | | - out = context.deserialize_objects(data_metadata_pairs, object_refs) |
288 | | - if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ: |
289 | | - return |
290 | | - for e in out: |
291 | | - _unhandled_error_handler(e) |
292 | | - |
293 | 280 | def deserialize_objects(self, data_metadata_pairs, object_refs): |
294 | 281 | context = self.get_serialization_context() |
295 | 282 | return context.deserialize_objects(data_metadata_pairs, object_refs) |
@@ -867,6 +854,13 @@ def custom_excepthook(type, value, tb): |
867 | 854 |
|
868 | 855 | sys.excepthook = custom_excepthook |
869 | 856 |
|
| 857 | +# The last time we raised a TaskError in this process. We use this value to |
| 858 | +# suppress redundant error messages pushed from the workers. |
| 859 | +last_task_error_raise_time = 0 |
| 860 | + |
| 861 | +# The max amount of seconds to wait before printing out an uncaught error. |
| 862 | +UNCAUGHT_ERROR_GRACE_PERIOD = 5 |
| 863 | + |
870 | 864 |
|
871 | 865 | def print_logs(redis_client, threads_stopped, job_id): |
872 | 866 | """Prints log messages from workers on all of the nodes. |
@@ -1017,14 +1011,51 @@ def color_for(data: Dict[str, str]) -> str: |
1017 | 1011 | file=print_file) |
1018 | 1012 |
|
1019 | 1013 |
|
1020 | | -def listen_error_messages_raylet(worker, threads_stopped): |
| 1014 | +def print_error_messages_raylet(task_error_queue, threads_stopped): |
| 1015 | + """Prints message received in the given output queue. |
| 1016 | +
|
| 1017 | + This checks periodically if any un-raised errors occurred in the |
| 1018 | + background. |
| 1019 | +
|
| 1020 | + Args: |
| 1021 | + task_error_queue (queue.Queue): A queue used to receive errors from the |
| 1022 | + thread that listens to Redis. |
| 1023 | + threads_stopped (threading.Event): A threading event used to signal to |
| 1024 | + the thread that it should exit. |
| 1025 | + """ |
| 1026 | + |
| 1027 | + while True: |
| 1028 | + # Exit if we received a signal that we should stop. |
| 1029 | + if threads_stopped.is_set(): |
| 1030 | + return |
| 1031 | + |
| 1032 | + try: |
| 1033 | + error, t = task_error_queue.get(block=False) |
| 1034 | + except queue.Empty: |
| 1035 | + threads_stopped.wait(timeout=0.01) |
| 1036 | + continue |
| 1037 | + # Delay errors a little bit of time to attempt to suppress redundant |
| 1038 | + # messages originating from the worker. |
| 1039 | + while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time(): |
| 1040 | + threads_stopped.wait(timeout=1) |
| 1041 | + if threads_stopped.is_set(): |
| 1042 | + break |
| 1043 | + if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD: |
| 1044 | + logger.debug(f"Suppressing error from worker: {error}") |
| 1045 | + else: |
| 1046 | + logger.error(f"Possible unhandled error from worker: {error}") |
| 1047 | + |
| 1048 | + |
| 1049 | +def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): |
1021 | 1050 | """Listen to error messages in the background on the driver. |
1022 | 1051 |
|
1023 | 1052 | This runs in a separate thread on the driver and pushes (error, time) |
1024 | 1053 | tuples to the output queue. |
1025 | 1054 |
|
1026 | 1055 | Args: |
1027 | 1056 | worker: The worker class that this thread belongs to. |
| 1057 | + task_error_queue (queue.Queue): A queue used to communicate with the |
| 1058 | + thread that prints the errors found by this thread. |
1028 | 1059 | threads_stopped (threading.Event): A threading event used to signal to |
1029 | 1060 | the thread that it should exit. |
1030 | 1061 | """ |
@@ -1063,9 +1094,8 @@ def listen_error_messages_raylet(worker, threads_stopped): |
1063 | 1094 |
|
1064 | 1095 | error_message = error_data.error_message |
1065 | 1096 | if (error_data.type == ray_constants.TASK_PUSH_ERROR): |
1066 | | - # TODO(ekl) remove task push errors entirely now that we have |
1067 | | - # the separate unhandled exception handler. |
1068 | | - pass |
| 1097 | + # Delay it a bit to see if we can suppress it |
| 1098 | + task_error_queue.put((error_message, time.time())) |
1069 | 1099 | else: |
1070 | 1100 | logger.warning(error_message) |
1071 | 1101 | except (OSError, redis.exceptions.ConnectionError) as e: |
@@ -1228,12 +1258,19 @@ def connect(node, |
1228 | 1258 | # temporarily using this implementation which constantly queries the |
1229 | 1259 | # scheduler for new error messages. |
1230 | 1260 | if mode == SCRIPT_MODE: |
| 1261 | + q = queue.Queue() |
1231 | 1262 | worker.listener_thread = threading.Thread( |
1232 | 1263 | target=listen_error_messages_raylet, |
1233 | 1264 | name="ray_listen_error_messages", |
1234 | | - args=(worker, worker.threads_stopped)) |
| 1265 | + args=(worker, q, worker.threads_stopped)) |
| 1266 | + worker.printer_thread = threading.Thread( |
| 1267 | + target=print_error_messages_raylet, |
| 1268 | + name="ray_print_error_messages", |
| 1269 | + args=(q, worker.threads_stopped)) |
1235 | 1270 | worker.listener_thread.daemon = True |
1236 | 1271 | worker.listener_thread.start() |
| 1272 | + worker.printer_thread.daemon = True |
| 1273 | + worker.printer_thread.start() |
1237 | 1274 | if log_to_driver: |
1238 | 1275 | global_worker_stdstream_dispatcher.add_handler( |
1239 | 1276 | "ray_print_logs", print_to_stdstream) |
@@ -1286,6 +1323,8 @@ def disconnect(exiting_interpreter=False): |
1286 | 1323 | worker.import_thread.join_import_thread() |
1287 | 1324 | if hasattr(worker, "listener_thread"): |
1288 | 1325 | worker.listener_thread.join() |
| 1326 | + if hasattr(worker, "printer_thread"): |
| 1327 | + worker.printer_thread.join() |
1289 | 1328 | if hasattr(worker, "logger_thread"): |
1290 | 1329 | worker.logger_thread.join() |
1291 | 1330 | worker.threads_stopped.clear() |
@@ -1397,11 +1436,13 @@ def get(object_refs, *, timeout=None): |
1397 | 1436 | raise ValueError("'object_refs' must either be an object ref " |
1398 | 1437 | "or a list of object refs.") |
1399 | 1438 |
|
| 1439 | + global last_task_error_raise_time |
1400 | 1440 | # TODO(ujvl): Consider how to allow user to retrieve the ready objects. |
1401 | 1441 | values, debugger_breakpoint = worker.get_objects( |
1402 | 1442 | object_refs, timeout=timeout) |
1403 | 1443 | for i, value in enumerate(values): |
1404 | 1444 | if isinstance(value, RayError): |
| 1445 | + last_task_error_raise_time = time.time() |
1405 | 1446 | if isinstance(value, ray.exceptions.ObjectLostError): |
1406 | 1447 | worker.core_worker.dump_object_store_memory_usage() |
1407 | 1448 | if isinstance(value, RayTaskError): |
|
0 commit comments