Skip to content

Commit 9dc671a

Browse files
authored
Unhandled exception handler based on local ref counting (#14049)
1 parent ff1b262 commit 9dc671a

11 files changed

Lines changed: 209 additions & 68 deletions

File tree

BUILD.bazel

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,15 @@ cc_test(
702702
],
703703
)
704704

705+
cc_test(
706+
name = "memory_store_test",
707+
srcs = ["src/ray/core_worker/test/memory_store_test.cc"],
708+
deps = [
709+
":core_worker_lib",
710+
"@com_google_googletest//:gtest_main",
711+
],
712+
)
713+
705714
cc_test(
706715
name = "direct_actor_transport_test",
707716
srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],

python/ray/_raylet.pyx

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,20 @@ cdef void delete_spilled_objects_handler(
724724
job_id=None)
725725

726726

727+
cdef void unhandled_exception_handler(const CRayObject& error) nogil:
728+
with gil:
729+
worker = ray.worker.global_worker
730+
data = None
731+
metadata = None
732+
if error.HasData():
733+
data = Buffer.make(error.GetData())
734+
if error.HasMetadata():
735+
metadata = Buffer.make(error.GetMetadata()).to_pybytes()
736+
# TODO(ekl) why does passing an ObjectRef.nil() lead to shutdown errors?
737+
object_ids = [None]
738+
worker.raise_errors([(data, metadata)], object_ids)
739+
740+
727741
# This function introduces ~2-7us of overhead per call (i.e., it can be called
728742
# up to hundreds of thousands of times per second).
729743
cdef void get_py_stack(c_string* stack_out) nogil:
@@ -833,6 +847,7 @@ cdef class CoreWorker:
833847
options.spill_objects = spill_objects_handler
834848
options.restore_spilled_objects = restore_spilled_objects_handler
835849
options.delete_spilled_objects = delete_spilled_objects_handler
850+
options.unhandled_exception_handler = unhandled_exception_handler
836851
options.get_lang_stack = get_py_stack
837852
options.ref_counting_enabled = True
838853
options.is_local_mode = local_mode
@@ -1443,9 +1458,13 @@ cdef class CoreWorker:
14431458
object_ref.native())
14441459

14451460
def remove_object_ref_reference(self, ObjectRef object_ref):
1446-
# Note: faster to not release GIL for short-running op.
1447-
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
1448-
object_ref.native())
1461+
cdef:
1462+
CObjectID c_object_id = object_ref.native()
1463+
# We need to release the gil since object destruction may call the
1464+
# unhandled exception handler.
1465+
with nogil:
1466+
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
1467+
c_object_id)
14491468

14501469
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
14511470
cdef:

python/ray/includes/libcoreworker.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
250250
(void(
251251
const c_vector[c_string]&,
252252
CWorkerType) nogil) delete_spilled_objects
253+
(void(const CRayObject&) nogil) unhandled_exception_handler
253254
(void(c_string *stack_out) nogil) get_lang_stack
254255
c_bool ref_counting_enabled
255256
c_bool is_local_mode

python/ray/tests/test_failure.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,52 @@
2020
get_error_message, Semaphore)
2121

2222

23+
def test_unhandled_errors(ray_start_regular):
24+
@ray.remote
25+
def f():
26+
raise ValueError()
27+
28+
@ray.remote
29+
class Actor:
30+
def f(self):
31+
raise ValueError()
32+
33+
a = Actor.remote()
34+
num_exceptions = 0
35+
36+
def interceptor(e):
37+
nonlocal num_exceptions
38+
num_exceptions += 1
39+
40+
# Test we report unhandled exceptions.
41+
ray.worker._unhandled_error_handler = interceptor
42+
x1 = f.remote()
43+
x2 = a.f.remote()
44+
del x1
45+
del x2
46+
wait_for_condition(lambda: num_exceptions == 2)
47+
48+
# Test we don't report handled exceptions.
49+
x1 = f.remote()
50+
x2 = a.f.remote()
51+
with pytest.raises(ray.exceptions.RayError) as err: # noqa
52+
ray.get([x1, x2])
53+
del x1
54+
del x2
55+
time.sleep(1)
56+
assert num_exceptions == 2, num_exceptions
57+
58+
# Test suppression with env var works.
59+
try:
60+
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
61+
x1 = f.remote()
62+
del x1
63+
time.sleep(1)
64+
assert num_exceptions == 2, num_exceptions
65+
finally:
66+
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
67+
68+
2369
def test_failed_task(ray_start_regular, error_pubsub):
2470
@ray.remote
2571
def throw_exception_fct1():

python/ray/worker.py

Lines changed: 19 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import logging
1010
import os
1111
import redis
12-
from six.moves import queue
1312
import sys
1413
import threading
1514
import time
@@ -69,6 +68,12 @@
6968
logger = logging.getLogger(__name__)
7069

7170

71+
# Visible for testing.
72+
def _unhandled_error_handler(e: Exception):
73+
logger.error("Unhandled error (suppress with "
74+
"RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))
75+
76+
7277
class Worker:
7378
"""A class used to define the control flow of a worker process.
7479
@@ -277,6 +282,14 @@ def put_object(self, value, object_ref=None):
277282
self.core_worker.put_serialized_object(
278283
serialized_value, object_ref=object_ref))
279284

285+
def raise_errors(self, data_metadata_pairs, object_refs):
286+
context = self.get_serialization_context()
287+
out = context.deserialize_objects(data_metadata_pairs, object_refs)
288+
if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
289+
return
290+
for e in out:
291+
_unhandled_error_handler(e)
292+
280293
def deserialize_objects(self, data_metadata_pairs, object_refs):
281294
context = self.get_serialization_context()
282295
return context.deserialize_objects(data_metadata_pairs, object_refs)
@@ -863,13 +876,6 @@ def custom_excepthook(type, value, tb):
863876

864877
sys.excepthook = custom_excepthook
865878

866-
# The last time we raised a TaskError in this process. We use this value to
867-
# suppress redundant error messages pushed from the workers.
868-
last_task_error_raise_time = 0
869-
870-
# The max amount of seconds to wait before printing out an uncaught error.
871-
UNCAUGHT_ERROR_GRACE_PERIOD = 5
872-
873879

874880
def print_logs(redis_client, threads_stopped, job_id):
875881
"""Prints log messages from workers on all of the nodes.
@@ -1020,51 +1026,14 @@ def color_for(data: Dict[str, str]) -> str:
10201026
file=print_file)
10211027

10221028

1023-
def print_error_messages_raylet(task_error_queue, threads_stopped):
1024-
"""Prints message received in the given output queue.
1025-
1026-
This checks periodically if any un-raised errors occurred in the
1027-
background.
1028-
1029-
Args:
1030-
task_error_queue (queue.Queue): A queue used to receive errors from the
1031-
thread that listens to Redis.
1032-
threads_stopped (threading.Event): A threading event used to signal to
1033-
the thread that it should exit.
1034-
"""
1035-
1036-
while True:
1037-
# Exit if we received a signal that we should stop.
1038-
if threads_stopped.is_set():
1039-
return
1040-
1041-
try:
1042-
error, t = task_error_queue.get(block=False)
1043-
except queue.Empty:
1044-
threads_stopped.wait(timeout=0.01)
1045-
continue
1046-
# Delay errors a little bit of time to attempt to suppress redundant
1047-
# messages originating from the worker.
1048-
while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
1049-
threads_stopped.wait(timeout=1)
1050-
if threads_stopped.is_set():
1051-
break
1052-
if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
1053-
logger.debug(f"Suppressing error from worker: {error}")
1054-
else:
1055-
logger.error(f"Possible unhandled error from worker: {error}")
1056-
1057-
1058-
def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
1029+
def listen_error_messages_raylet(worker, threads_stopped):
10591030
"""Listen to error messages in the background on the driver.
10601031
10611032
This runs in a separate thread on the driver and logs error messages
10621033
received from Redis.
10631034
10641035
Args:
10651036
worker: The worker class that this thread belongs to.
1066-
task_error_queue (queue.Queue): A queue used to communicate with the
1067-
thread that prints the errors found by this thread.
10681037
threads_stopped (threading.Event): A threading event used to signal to
10691038
the thread that it should exit.
10701039
"""
@@ -1103,8 +1072,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
11031072

11041073
error_message = error_data.error_message
11051074
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
1106-
# Delay it a bit to see if we can suppress it
1107-
task_error_queue.put((error_message, time.time()))
1075+
# TODO(ekl) remove task push errors entirely now that we have
1076+
# the separate unhandled exception handler.
1077+
pass
11081078
else:
11091079
logger.warning(error_message)
11101080
except (OSError, redis.exceptions.ConnectionError) as e:
@@ -1267,19 +1237,12 @@ def connect(node,
12671237
# temporarily using this implementation which constantly queries the
12681238
# scheduler for new error messages.
12691239
if mode == SCRIPT_MODE:
1270-
q = queue.Queue()
12711240
worker.listener_thread = threading.Thread(
12721241
target=listen_error_messages_raylet,
12731242
name="ray_listen_error_messages",
1274-
args=(worker, q, worker.threads_stopped))
1275-
worker.printer_thread = threading.Thread(
1276-
target=print_error_messages_raylet,
1277-
name="ray_print_error_messages",
1278-
args=(q, worker.threads_stopped))
1243+
args=(worker, worker.threads_stopped))
12791244
worker.listener_thread.daemon = True
12801245
worker.listener_thread.start()
1281-
worker.printer_thread.daemon = True
1282-
worker.printer_thread.start()
12831246
if log_to_driver:
12841247
global_worker_stdstream_dispatcher.add_handler(
12851248
"ray_print_logs", print_to_stdstream)
@@ -1332,8 +1295,6 @@ def disconnect(exiting_interpreter=False):
13321295
worker.import_thread.join_import_thread()
13331296
if hasattr(worker, "listener_thread"):
13341297
worker.listener_thread.join()
1335-
if hasattr(worker, "printer_thread"):
1336-
worker.printer_thread.join()
13371298
if hasattr(worker, "logger_thread"):
13381299
worker.logger_thread.join()
13391300
worker.threads_stopped.clear()
@@ -1445,13 +1406,11 @@ def get(object_refs, *, timeout=None):
14451406
raise ValueError("'object_refs' must either be an object ref "
14461407
"or a list of object refs.")
14471408

1448-
global last_task_error_raise_time
14491409
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
14501410
values, debugger_breakpoint = worker.get_objects(
14511411
object_refs, timeout=timeout)
14521412
for i, value in enumerate(values):
14531413
if isinstance(value, RayError):
1454-
last_task_error_raise_time = time.time()
14551414
if isinstance(value, ray.exceptions.ObjectLostError):
14561415
worker.core_worker.dump_object_store_memory_usage()
14571416
if isinstance(value, RayTaskError):

src/ray/common/ray_object.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,20 @@ class RayObject {
9292
/// large to return directly as part of a gRPC response).
9393
bool IsInPlasmaError() const;
9494

95+
/// Mark this object as accessed before.
96+
void SetAccessed() { accessed_ = true; };
97+
98+
/// Check if this object was accessed before.
99+
bool WasAccessed() const { return accessed_; }
100+
95101
private:
96102
std::shared_ptr<Buffer> data_;
97103
std::shared_ptr<Buffer> metadata_;
98104
const std::vector<ObjectID> nested_ids_;
99105
/// Whether this class holds a data copy.
100106
bool has_data_copy_;
107+
/// Whether this object was accessed.
108+
bool accessed_ = false;
101109
};
102110

103111
} // namespace ray

src/ray/core_worker/core_worker.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
422422
return Status::OK();
423423
},
424424
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
425-
options_.check_signals));
425+
options_.check_signals, options_.unhandled_exception_handler));
426426

427427
auto check_node_alive_fn = [this](const NodeID &node_id) {
428428
auto node = gcs_client_->Nodes().Get(node_id);

src/ray/core_worker/core_worker.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ struct CoreWorkerOptions {
8282
spill_objects(nullptr),
8383
restore_spilled_objects(nullptr),
8484
delete_spilled_objects(nullptr),
85+
unhandled_exception_handler(nullptr),
8586
get_lang_stack(nullptr),
8687
kill_main(nullptr),
8788
ref_counting_enabled(false),
@@ -146,6 +147,8 @@ struct CoreWorkerOptions {
146147
/// Application-language callback to delete objects from external storage.
147148
std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
148149
delete_spilled_objects;
150+
/// Function to call on error objects never retrieved.
151+
std::function<void(const RayObject &error)> unhandled_exception_handler;
149152
/// Language worker callback to get the current call stack.
150153
std::function<void(std::string *)> get_lang_stack;
151154
// Function that tries to interrupt the currently running Python thread.

0 commit comments

Comments
 (0)