Skip to content

Commit 9695271

Browse files
Fix issues with resuming async tasks awaiting a future (#1469)
Signed-off-by: Błażej Sowa <bsowa123@gmail.com> Signed-off-by: Nadav Elkabets <elnadav12@gmail.com> Co-authored-by: Nadav Elkabets <32939935+nadavelkabets@users.noreply.github.com>
1 parent 0ce60eb commit 9695271

File tree

6 files changed

+169
-98
lines changed

6 files changed

+169
-98
lines changed

rclpy/rclpy/executors.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import deque
1516
from concurrent.futures import ThreadPoolExecutor
1617
from contextlib import ExitStack
18+
from dataclasses import dataclass
1719
from functools import partial
1820
import inspect
1921
import os
@@ -26,6 +28,7 @@
2628
from typing import Callable
2729
from typing import ContextManager
2830
from typing import Coroutine
31+
from typing import Deque
2932
from typing import Dict
3033
from typing import Generator
3134
from typing import List
@@ -176,6 +179,12 @@ def timeout(self, timeout: float) -> None:
176179
self._timeout = timeout
177180

178181

182+
@dataclass
183+
class TaskData:
184+
source_node: 'Optional[Node]' = None
185+
source_entity: 'Optional[Entity]' = None
186+
187+
179188
class Executor(ContextManager['Executor']):
180189
"""
181190
The base class for an executor.
@@ -205,8 +214,10 @@ def __init__(self, *, context: Optional[Context] = None) -> None:
205214
self._context = get_default_context() if context is None else context
206215
self._nodes: Set[Node] = set()
207216
self._nodes_lock = RLock()
208-
# Tasks to be executed (oldest first) 3-tuple Task, Entity, Node
209-
self._tasks: List[Tuple[Task[Any], 'Optional[Entity]', Optional[Node]]] = []
217+
# all tasks that are not complete or canceled
218+
self._pending_tasks: Dict[Task, TaskData] = {}
219+
# tasks that are ready to execute
220+
self._ready_tasks: Deque[Task[Any]] = deque()
210221
self._tasks_lock = Lock()
211222
# This is triggered when wait_for_ready_callbacks should rebuild the wait list
212223
self._guard: Optional[GuardCondition] = GuardCondition(
@@ -276,11 +287,20 @@ def create_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any
276287
"""
277288
task = Task(callback, args, kwargs, executor=self)
278289
with self._tasks_lock:
279-
self._tasks.append((task, None, None))
290+
self._pending_tasks[task] = TaskData()
291+
self._call_task_in_next_spin(task)
292+
return task
293+
294+
def _call_task_in_next_spin(self, task: Task) -> None:
295+
"""
296+
Add a task to the executor to be executed in the next spin.
297+
298+
:param task: A task to be run in the executor.
299+
"""
300+
with self._tasks_lock:
301+
self._ready_tasks.append(task)
280302
if self._guard:
281303
self._guard.trigger()
282-
# Task inherits from Future
283-
return task
284304

285305
def create_future(self) -> Future:
286306
"""Create a Future object attached to the Executor."""
@@ -664,7 +684,10 @@ async def handler(entity: 'EntityT', gc: GuardCondition, is_shutdown: bool,
664684
handler, (entity, self._guard, self._is_shutdown, self._work_tracker),
665685
executor=self)
666686
with self._tasks_lock:
667-
self._tasks.append((task, entity, node))
687+
self._pending_tasks[task] = TaskData(
688+
source_entity=entity,
689+
source_node=node
690+
)
668691
return task
669692

670693
def can_execute(self, entity: 'Entity') -> bool:
@@ -709,21 +732,25 @@ def _wait_for_ready_callbacks(
709732
nodes_to_use = self.get_nodes()
710733

711734
# Yield tasks in-progress before waiting for new work
712-
tasks = None
713735
with self._tasks_lock:
714-
tasks = list(self._tasks)
715-
if tasks:
716-
for task, entity, node in tasks:
717-
if (not task.executing() and not task.done() and
718-
(node is None or node in nodes_to_use)):
719-
yielded_work = True
720-
yield task, entity, node
721-
with self._tasks_lock:
722-
# Get rid of any tasks that are done
723-
self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks))
724-
# Get rid of any tasks that are cancelled
725-
self._tasks = list(filter(lambda t_e_n: not t_e_n[0].cancelled(), self._tasks))
726-
736+
# Get rid of any tasks that are done or cancelled
737+
for task in list(self._pending_tasks.keys()):
738+
if task.done() or task.cancelled():
739+
del self._pending_tasks[task]
740+
741+
ready_tasks_count = len(self._ready_tasks)
742+
for _ in range(ready_tasks_count):
743+
task = self._ready_tasks.popleft()
744+
task_data = self._pending_tasks[task]
745+
node = task_data.source_node
746+
if node is None or node in nodes_to_use:
747+
entity = task_data.source_entity
748+
yielded_work = True
749+
yield task, entity, node
750+
else:
751+
# Asked not to execute these tasks, so don't do them yet
752+
with self._tasks_lock:
753+
self._ready_tasks.append(task)
727754
# Gather entities that can be waited on
728755
subscriptions: List[Subscription[Any, ]] = []
729756
guards: List[GuardCondition] = []

rclpy/rclpy/task.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,13 @@ def __del__(self) -> None:
6666
'The following exception was never retrieved: ' + str(self._exception),
6767
file=sys.stderr)
6868

69-
def __await__(self) -> Generator[None, None, Optional[T]]:
69+
def __await__(self) -> Generator['Future[T]', None, Optional[T]]:
7070
# Yield if the task is not finished
71-
while self._pending():
72-
yield
71+
if self._pending():
72+
# This tells the task to suspend until the future is done
73+
yield self
74+
if self._pending():
75+
raise RuntimeError('Future awaited a second time before it was done')
7376
return self.result()
7477

7578
def _pending(self) -> bool:
@@ -298,17 +301,7 @@ def __call__(self) -> None:
298301
self._executing = True
299302

300303
if inspect.iscoroutine(self._handler):
301-
# Execute a coroutine
302-
handler = self._handler
303-
try:
304-
handler.send(None)
305-
except StopIteration as e:
306-
# The coroutine finished; store the result
307-
self.set_result(e.value)
308-
self._complete_task()
309-
except Exception as e:
310-
self.set_exception(e)
311-
self._complete_task()
304+
self._execute_coroutine_step(self._handler)
312305
else:
313306
# Execute a normal function
314307
try:
@@ -322,6 +315,47 @@ def __call__(self) -> None:
322315
finally:
323316
self._task_lock.release()
324317

318+
def _execute_coroutine_step(self, coro: Coroutine[Any, Any, T]) -> None:
319+
"""Execute or resume a coroutine task."""
320+
try:
321+
result = coro.send(None)
322+
except StopIteration as e:
323+
# The coroutine finished; store the result
324+
self.set_result(e.value)
325+
self._complete_task()
326+
except Exception as e:
327+
# The coroutine raised; store the exception
328+
self.set_exception(e)
329+
self._complete_task()
330+
else:
331+
# The coroutine yielded; suspend the task until it is resumed
332+
executor = self._executor()
333+
if executor is None:
334+
raise RuntimeError(
335+
'Task tried to reschedule but no executor was set: '
336+
'tasks should only be initialized through executor.create_task()')
337+
elif isinstance(result, Future):
338+
# Schedule the task to resume when the future is done
339+
self._add_resume_callback(result, executor)
340+
elif result is None:
341+
# The coroutine yielded None, schedule the task to resume in the next spin
342+
executor._call_task_in_next_spin(self)
343+
else:
344+
raise TypeError(
345+
f'Expected coroutine to yield a Future or None, got: {type(result)}')
346+
347+
def _add_resume_callback(self, future: Future[T], executor: 'Executor') -> None:
348+
future_executor = future._executor()
349+
if future_executor is None:
350+
# The future is not associated with an executor yet, so associate it with ours
351+
future._set_executor(executor)
352+
elif future_executor is not executor:
353+
raise RuntimeError('A task can only await futures associated with the same executor')
354+
355+
# The future is associated with the same executor, so we can resume the task directly
356+
# in the done callback
357+
future.add_done_callback(lambda _: self.__call__())
358+
325359
def _complete_task(self) -> None:
326360
"""Cleanup after task finished."""
327361
self._handler = None

rclpy/src/rclpy/events_executor/events_executor.cpp

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,15 @@ pybind11::object EventsExecutor::create_task(
7575
// manual refcounting on it instead.
7676
py::handle cb_task_handle = task;
7777
cb_task_handle.inc_ref();
78-
events_queue_.Enqueue(std::bind(&EventsExecutor::IterateTask, this, cb_task_handle));
78+
call_task_in_next_spin(task);
7979
return task;
8080
}
8181

82+
void EventsExecutor::call_task_in_next_spin(pybind11::handle task)
83+
{
84+
events_queue_.Enqueue(std::bind(&EventsExecutor::IterateTask, this, task));
85+
}
86+
8287
pybind11::object EventsExecutor::create_future()
8388
{
8489
using py::literals::operator""_a;
@@ -164,8 +169,6 @@ void EventsExecutor::spin(std::optional<double> timeout_sec, bool stop_after_use
164169
throw std::runtime_error("Attempt to spin an already-spinning Executor");
165170
}
166171
stop_after_user_callback_ = stop_after_user_callback;
167-
// Any blocked tasks may have become unblocked while we weren't looking.
168-
PostOutstandingTasks();
169172
// Release the GIL while we block. Any callbacks on the events queue that want to touch Python
170173
// will need to reacquire it though.
171174
py::gil_scoped_release gil_release;
@@ -354,8 +357,6 @@ void EventsExecutor::HandleSubscriptionReady(py::handle subscription, size_t num
354357
got_none = true;
355358
}
356359
}
357-
358-
PostOutstandingTasks();
359360
}
360361

361362
void EventsExecutor::HandleAddedTimer(py::handle timer) {timers_manager_.AddTimer(timer);}
@@ -397,7 +398,6 @@ void EventsExecutor::HandleTimerReady(py::handle timer, const rcl_timer_call_inf
397398
} else if (stop_after_user_callback_) {
398399
events_queue_.Stop();
399400
}
400-
PostOutstandingTasks();
401401
}
402402

403403
void EventsExecutor::HandleAddedClient(py::handle client)
@@ -468,8 +468,6 @@ void EventsExecutor::HandleClientReady(py::handle client, size_t number_of_event
468468
}
469469
}
470470
}
471-
472-
PostOutstandingTasks();
473471
}
474472

475473
void EventsExecutor::HandleAddedService(py::handle service)
@@ -543,8 +541,6 @@ void EventsExecutor::HandleServiceReady(py::handle service, size_t number_of_eve
543541
}
544542
}
545543
}
546-
547-
PostOutstandingTasks();
548544
}
549545

550546
void EventsExecutor::HandleAddedWaitable(py::handle waitable)
@@ -810,8 +806,6 @@ void EventsExecutor::HandleWaitableReady(
810806
// execute() is an async method, we need a Task to run it
811807
create_task(execute(data));
812808
}
813-
814-
PostOutstandingTasks();
815809
}
816810

817811
void EventsExecutor::IterateTask(py::handle task)
@@ -844,26 +838,7 @@ void EventsExecutor::IterateTask(py::handle task)
844838
throw;
845839
}
846840
}
847-
} else {
848-
// Task needs more iteration. Store the handle and revisit it later after the next ready
849-
// entity which may unblock it.
850-
// TODO(bmartin427) This matches the behavior of SingleThreadedExecutor and avoids busy
851-
// looping, but I don't love it because if the task is waiting on something other than an rcl
852-
// entity (e.g. an asyncio sleep, or a Future triggered from another thread, or even another
853-
// Task), there can be arbitrarily long latency before some rcl activity causes us to go
854-
// revisit that Task.
855-
blocked_tasks_.push_back(task);
856-
}
857-
}
858-
859-
void EventsExecutor::PostOutstandingTasks()
860-
{
861-
for (auto & task : blocked_tasks_) {
862-
events_queue_.Enqueue(std::bind(&EventsExecutor::IterateTask, this, task));
863841
}
864-
// Clear the entire outstanding tasks list. Any tasks that need further iteration will re-add
865-
// themselves during IterateTask().
866-
blocked_tasks_.clear();
867842
}
868843

869844
void EventsExecutor::HandleCallbackExceptionInNodeEntity(
@@ -922,6 +897,7 @@ void define_events_executor(py::object module)
922897
.def(py::init<py::object>(), py::arg("context"))
923898
.def_property_readonly("context", &EventsExecutor::get_context)
924899
.def("create_task", &EventsExecutor::create_task, py::arg("callback"))
900+
.def("_call_task_in_next_spin", &EventsExecutor::call_task_in_next_spin, py::arg("task"))
925901
.def("create_future", &EventsExecutor::create_future)
926902
.def("shutdown", &EventsExecutor::shutdown, py::arg("timeout_sec") = py::none())
927903
.def("add_node", &EventsExecutor::add_node, py::arg("node"))

rclpy/src/rclpy/events_executor/events_executor.hpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class EventsExecutor
6767
pybind11::object get_context() const {return rclpy_context_;}
6868
pybind11::object create_task(
6969
pybind11::object callback, pybind11::args args = {}, const pybind11::kwargs & kwargs = {});
70+
void call_task_in_next_spin(pybind11::handle task);
7071
pybind11::object create_future();
7172
bool shutdown(std::optional<double> timeout_sec = {});
7273
bool add_node(pybind11::object node);
@@ -149,11 +150,6 @@ class EventsExecutor
149150
/// create_task() implementation for details.
150151
void IterateTask(pybind11::handle task);
151152

152-
/// Posts a call to IterateTask() for every outstanding entry in tasks_; should be invoked from
153-
/// other Handle*Ready() methods to check if any asynchronous Tasks have been unblocked by the
154-
/// newly-handled event.
155-
void PostOutstandingTasks();
156-
157153
void HandleCallbackExceptionInNodeEntity(
158154
const pybind11::error_already_set &, pybind11::handle entity,
159155
const std::string & node_entity_attr);
@@ -190,9 +186,6 @@ class EventsExecutor
190186
pybind11::set services_;
191187
pybind11::set waitables_;
192188

193-
/// Collection of asynchronous Tasks awaiting new events to further iterate.
194-
std::vector<pybind11::handle> blocked_tasks_;
195-
196189
/// Cache for rcl pointers underlying each waitables_ entry, because those are harder to retrieve
197190
/// than the other entity types.
198191
std::unordered_map<pybind11::handle, WaitableSubEntities, PythonHasher,

0 commit comments

Comments
 (0)