|
9 | 9 | import logging |
10 | 10 | import os |
11 | 11 | import redis |
12 | | -from six.moves import queue |
13 | 12 | import sys |
14 | 13 | import threading |
15 | 14 | import time |
|
69 | 68 | logger = logging.getLogger(__name__) |
70 | 69 |
|
71 | 70 |
|
| 71 | +# Visible for testing. |
| 72 | +def _unhandled_error_handler(e: Exception): |
| 73 | + logger.error("Unhandled error (suppress with " |
| 74 | + "RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e)) |
| 75 | + |
| 76 | + |
72 | 77 | class Worker: |
73 | 78 | """A class used to define the control flow of a worker process. |
74 | 79 |
|
@@ -277,6 +282,14 @@ def put_object(self, value, object_ref=None): |
277 | 282 | self.core_worker.put_serialized_object( |
278 | 283 | serialized_value, object_ref=object_ref)) |
279 | 284 |
|
| 285 | + def raise_errors(self, data_metadata_pairs, object_refs): |
| 286 | + context = self.get_serialization_context() |
| 287 | + out = context.deserialize_objects(data_metadata_pairs, object_refs) |
| 288 | + if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ: |
| 289 | + return |
| 290 | + for e in out: |
| 291 | + _unhandled_error_handler(e) |
| 292 | + |
280 | 293 | def deserialize_objects(self, data_metadata_pairs, object_refs): |
281 | 294 | context = self.get_serialization_context() |
282 | 295 | return context.deserialize_objects(data_metadata_pairs, object_refs) |
@@ -863,13 +876,6 @@ def custom_excepthook(type, value, tb): |
863 | 876 |
|
864 | 877 | sys.excepthook = custom_excepthook |
865 | 878 |
|
866 | | -# The last time we raised a TaskError in this process. We use this value to |
867 | | -# suppress redundant error messages pushed from the workers. |
868 | | -last_task_error_raise_time = 0 |
869 | | - |
870 | | -# The max amount of seconds to wait before printing out an uncaught error. |
871 | | -UNCAUGHT_ERROR_GRACE_PERIOD = 5 |
872 | | - |
873 | 879 |
|
874 | 880 | def print_logs(redis_client, threads_stopped, job_id): |
875 | 881 | """Prints log messages from workers on all of the nodes. |
@@ -1020,51 +1026,14 @@ def color_for(data: Dict[str, str]) -> str: |
1020 | 1026 | file=print_file) |
1021 | 1027 |
|
1022 | 1028 |
|
1023 | | -def print_error_messages_raylet(task_error_queue, threads_stopped): |
1024 | | - """Prints message received in the given output queue. |
1025 | | -
|
1026 | | - This checks periodically if any un-raised errors occurred in the |
1027 | | - background. |
1028 | | -
|
1029 | | - Args: |
1030 | | - task_error_queue (queue.Queue): A queue used to receive errors from the |
1031 | | - thread that listens to Redis. |
1032 | | - threads_stopped (threading.Event): A threading event used to signal to |
1033 | | - the thread that it should exit. |
1034 | | - """ |
1035 | | - |
1036 | | - while True: |
1037 | | - # Exit if we received a signal that we should stop. |
1038 | | - if threads_stopped.is_set(): |
1039 | | - return |
1040 | | - |
1041 | | - try: |
1042 | | - error, t = task_error_queue.get(block=False) |
1043 | | - except queue.Empty: |
1044 | | - threads_stopped.wait(timeout=0.01) |
1045 | | - continue |
1046 | | - # Delay errors a little bit of time to attempt to suppress redundant |
1047 | | - # messages originating from the worker. |
1048 | | - while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time(): |
1049 | | - threads_stopped.wait(timeout=1) |
1050 | | - if threads_stopped.is_set(): |
1051 | | - break |
1052 | | - if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD: |
1053 | | - logger.debug(f"Suppressing error from worker: {error}") |
1054 | | - else: |
1055 | | - logger.error(f"Possible unhandled error from worker: {error}") |
1056 | | - |
1057 | | - |
1058 | | -def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): |
| 1029 | +def listen_error_messages_raylet(worker, threads_stopped): |
1059 | 1030 | """Listen to error messages in the background on the driver. |
1060 | 1031 |
|
1061 | 1032 | This runs in a separate thread on the driver and pushes (error, time) |
1062 | 1033 | tuples to the output queue. |
1063 | 1034 |
|
1064 | 1035 | Args: |
1065 | 1036 | worker: The worker class that this thread belongs to. |
1066 | | - task_error_queue (queue.Queue): A queue used to communicate with the |
1067 | | - thread that prints the errors found by this thread. |
1068 | 1037 | threads_stopped (threading.Event): A threading event used to signal to |
1069 | 1038 | the thread that it should exit. |
1070 | 1039 | """ |
@@ -1103,8 +1072,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): |
1103 | 1072 |
|
1104 | 1073 | error_message = error_data.error_message |
1105 | 1074 | if (error_data.type == ray_constants.TASK_PUSH_ERROR): |
1106 | | - # Delay it a bit to see if we can suppress it |
1107 | | - task_error_queue.put((error_message, time.time())) |
| 1075 | + # TODO(ekl) remove task push errors entirely now that we have |
| 1076 | + # the separate unhandled exception handler. |
| 1077 | + pass |
1108 | 1078 | else: |
1109 | 1079 | logger.warning(error_message) |
1110 | 1080 | except (OSError, redis.exceptions.ConnectionError) as e: |
@@ -1267,19 +1237,12 @@ def connect(node, |
1267 | 1237 | # temporarily using this implementation which constantly queries the |
1268 | 1238 | # scheduler for new error messages. |
1269 | 1239 | if mode == SCRIPT_MODE: |
1270 | | - q = queue.Queue() |
1271 | 1240 | worker.listener_thread = threading.Thread( |
1272 | 1241 | target=listen_error_messages_raylet, |
1273 | 1242 | name="ray_listen_error_messages", |
1274 | | - args=(worker, q, worker.threads_stopped)) |
1275 | | - worker.printer_thread = threading.Thread( |
1276 | | - target=print_error_messages_raylet, |
1277 | | - name="ray_print_error_messages", |
1278 | | - args=(q, worker.threads_stopped)) |
| 1243 | + args=(worker, worker.threads_stopped)) |
1279 | 1244 | worker.listener_thread.daemon = True |
1280 | 1245 | worker.listener_thread.start() |
1281 | | - worker.printer_thread.daemon = True |
1282 | | - worker.printer_thread.start() |
1283 | 1246 | if log_to_driver: |
1284 | 1247 | global_worker_stdstream_dispatcher.add_handler( |
1285 | 1248 | "ray_print_logs", print_to_stdstream) |
@@ -1332,8 +1295,6 @@ def disconnect(exiting_interpreter=False): |
1332 | 1295 | worker.import_thread.join_import_thread() |
1333 | 1296 | if hasattr(worker, "listener_thread"): |
1334 | 1297 | worker.listener_thread.join() |
1335 | | - if hasattr(worker, "printer_thread"): |
1336 | | - worker.printer_thread.join() |
1337 | 1298 | if hasattr(worker, "logger_thread"): |
1338 | 1299 | worker.logger_thread.join() |
1339 | 1300 | worker.threads_stopped.clear() |
@@ -1445,13 +1406,11 @@ def get(object_refs, *, timeout=None): |
1445 | 1406 | raise ValueError("'object_refs' must either be an object ref " |
1446 | 1407 | "or a list of object refs.") |
1447 | 1408 |
|
1448 | | - global last_task_error_raise_time |
1449 | 1409 | # TODO(ujvl): Consider how to allow user to retrieve the ready objects. |
1450 | 1410 | values, debugger_breakpoint = worker.get_objects( |
1451 | 1411 | object_refs, timeout=timeout) |
1452 | 1412 | for i, value in enumerate(values): |
1453 | 1413 | if isinstance(value, RayError): |
1454 | | - last_task_error_raise_time = time.time() |
1455 | 1414 | if isinstance(value, ray.exceptions.ObjectLostError): |
1456 | 1415 | worker.core_worker.dump_object_store_memory_usage() |
1457 | 1416 | if isinstance(value, RayTaskError): |
|
0 commit comments