diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 3e099d8fa..5d45b431d 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import deque from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack +from dataclasses import dataclass import inspect import os from threading import Condition @@ -25,6 +27,8 @@ from typing import Callable from typing import ContextManager from typing import Coroutine +from typing import Deque +from typing import Dict from typing import Generator from typing import List from typing import Optional @@ -150,6 +154,12 @@ def timeout(self, timeout): self._timeout = timeout +@dataclass +class TaskData: + source_node: 'Optional[Node]' = None + source_entity: 'Optional[WaitableEntityType]' = None + + class Executor(ContextManager['Executor']): """ The base class for an executor. @@ -179,8 +189,10 @@ def __init__(self, *, context: Optional[Context] = None) -> None: self._context = get_default_context() if context is None else context self._nodes: Set[Node] = set() self._nodes_lock = RLock() - # Tasks to be executed (oldest first) 3-tuple Task, Entity, Node - self._tasks: List[Tuple[Task, Optional[WaitableEntityType], Optional[Node]]] = [] + # all tasks that are not complete or canceled + self._pending_tasks: Dict[Task, TaskData] = {} + # tasks that are ready to execute + self._ready_tasks: Deque[Task] = deque() self._tasks_lock = Lock() # This is triggered when wait_for_ready_callbacks should rebuild the wait list self._guard = GuardCondition( @@ -214,11 +226,21 @@ def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) -> """ task = Task(callback, args, kwargs, executor=self) with self._tasks_lock: - self._tasks.append((task, None, None)) - self._guard.trigger() - # Task inherits from Future + self._pending_tasks[task] = TaskData() + self._call_task_in_next_spin(task) return task + def _call_task_in_next_spin(self, task: Task) -> None: + """ + Add a task to the executor to be executed in the next spin. + + :param task: A task to be run in the executor. + """ + with self._tasks_lock: + self._ready_tasks.append(task) + if self._guard: + self._guard.trigger() + def create_future(self) -> Future: """Create a Future object attached to the Executor.""" return Future(executor=self) @@ -544,7 +566,10 @@ async def handler(entity, gc, is_shutdown, work_tracker): handler, (entity, self._guard, self._is_shutdown, self._work_tracker), executor=self) with self._tasks_lock: - self._tasks.append((task, entity, node)) + self._pending_tasks[task] = TaskData( + source_entity=entity, + source_node=node + ) return task def can_execute(self, entity: WaitableEntityType) -> bool: @@ -588,21 +613,25 @@ def _wait_for_ready_callbacks( nodes_to_use = self.get_nodes() # Yield tasks in-progress before waiting for new work - tasks = None with self._tasks_lock: - tasks = list(self._tasks) - if tasks: - for task, entity, node in reversed(tasks): - if (not task.executing() and not task.done() and - (node is None or node in nodes_to_use)): - yielded_work = True - yield task, entity, node - with self._tasks_lock: - # Get rid of any tasks that are done - self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) - # Get rid of any tasks that are cancelled - self._tasks = list(filter(lambda t_e_n: not t_e_n[0].cancelled(), self._tasks)) - + # Get rid of any tasks that are done or cancelled + for task in list(self._pending_tasks.keys()): + if task.done() or task.cancelled(): + del self._pending_tasks[task] + + ready_tasks_count = len(self._ready_tasks) + for _ in range(ready_tasks_count): + task = self._ready_tasks.popleft() + task_data = self._pending_tasks[task] + node = task_data.source_node + if node is None or node in nodes_to_use: + entity = task_data.source_entity + yielded_work = True + yield task, entity, node + else: + # Asked not to execute these tasks, so don't do them yet + with self._tasks_lock: + self._ready_tasks.append(task) # Gather entities that can be waited on subscriptions: List[Subscription] = [] guards: List[GuardCondition] = [] diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py index 9fd1504b7..1686da0e2 100644 --- a/rclpy/rclpy/task.py +++ b/rclpy/rclpy/task.py @@ -16,7 +16,7 @@ import inspect import sys import threading -from typing import Callable +from typing import Any, Callable, Coroutine, Generator, Optional import warnings import weakref @@ -58,10 +58,13 @@ def __del__(self): 'The following exception was never retrieved: ' + str(self._exception), file=sys.stderr) - def __await__(self): + def __await__(self) -> Generator['Future', None, Optional[Any]]: # Yield if the task is not finished - while self._pending(): - yield + if self._pending(): + # This tells the task to suspend until the future is done + yield self + if self._pending(): + raise RuntimeError('Future awaited a second time before it was done') return self.result() def _pending(self) -> bool: @@ -264,16 +267,7 @@ def __call__(self): self._executing = True if inspect.iscoroutine(self._handler): - # Execute a coroutine - try: - self._handler.send(None) - except StopIteration as e: - # The coroutine finished; store the result - self.set_result(e.value) - self._complete_task() - except Exception as e: - self.set_exception(e) - self._complete_task() + self._execute_coroutine_step(self._handler) else: # Execute a normal function try: @@ -286,7 +280,48 @@ def __call__(self): finally: self._task_lock.release() - def _complete_task(self): + def _execute_coroutine_step(self, coro: Coroutine) -> None: + """Execute or resume a coroutine task.""" + try: + result = coro.send(None) + except StopIteration as e: + # The coroutine finished; store the result + self.set_result(e.value) + self._complete_task() + except Exception as e: + # The coroutine raised; store the exception + self.set_exception(e) + self._complete_task() + else: + # The coroutine yielded; suspend the task until it is resumed + executor = self._executor() + if executor is None: + raise RuntimeError( + 'Task tried to reschedule but no executor was set: ' + 'tasks should only be initialized through executor.create_task()') + elif isinstance(result, Future): + # Schedule the task to resume when the future is done + self._add_resume_callback(result, executor) + elif result is None: + # The coroutine yielded None, schedule the task to resume in the next spin + executor._call_task_in_next_spin(self) + else: + raise TypeError( + f'Expected coroutine to yield a Future or None, got: {type(result)}') + + def _add_resume_callback(self, future: Future, executor) -> None: + future_executor = future._executor() + if future_executor is None: + # The future is not associated with an executor yet, so associate it with ours + future._set_executor(executor) + elif future_executor is not executor: + raise RuntimeError('A task can only await futures associated with the same executor') + + # The future is associated with the same executor, so we can resume the task directly + # in the done callback + future.add_done_callback(lambda _: self.__call__()) + + def _complete_task(self) -> None: """Cleanup after task finished.""" self._handler = None self._args = None diff --git a/rclpy/src/rclpy/events_executor/events_executor.cpp b/rclpy/src/rclpy/events_executor/events_executor.cpp index 99a4e911c..36fdb1eec 100644 --- a/rclpy/src/rclpy/events_executor/events_executor.cpp +++ b/rclpy/src/rclpy/events_executor/events_executor.cpp @@ -74,10 +74,15 @@ pybind11::object EventsExecutor::create_task( // manual refcounting on it instead. py::handle cb_task_handle = task; cb_task_handle.inc_ref(); - events_queue_.Enqueue(std::bind(&EventsExecutor::IterateTask, this, cb_task_handle)); + call_task_in_next_spin(task); return task; } +void EventsExecutor::call_task_in_next_spin(pybind11::handle task) +{ + events_queue_.Enqueue(std::bind(&EventsExecutor::IterateTask, this, task)); +} + pybind11::object EventsExecutor::create_future() { using py::literals::operator""_a; @@ -163,8 +168,6 @@ void EventsExecutor::spin(std::optional timeout_sec, bool stop_after_use throw std::runtime_error("Attempt to spin an already-spinning Executor"); } stop_after_user_callback_ = stop_after_user_callback; - // Any blocked tasks may have become unblocked while we weren't looking. - PostOutstandingTasks(); // Release the GIL while we block. Any callbacks on the events queue that want to touch Python // will need to reacquire it though. py::gil_scoped_release gil_release; @@ -346,8 +349,6 @@ void EventsExecutor::HandleSubscriptionReady(py::handle subscription, size_t num got_none = true; } } - - PostOutstandingTasks(); } void EventsExecutor::HandleAddedTimer(py::handle timer) {timers_manager_.AddTimer(timer);} @@ -374,7 +375,6 @@ void EventsExecutor::HandleTimerReady(py::handle timer) } else if (stop_after_user_callback_) { events_queue_.Stop(); } - PostOutstandingTasks(); } void EventsExecutor::HandleAddedClient(py::handle client) @@ -445,8 +445,6 @@ void EventsExecutor::HandleClientReady(py::handle client, size_t number_of_event } } } - - PostOutstandingTasks(); } void EventsExecutor::HandleAddedService(py::handle service) @@ -509,8 +507,6 @@ void EventsExecutor::HandleServiceReady(py::handle service, size_t number_of_eve send_response(response, header); } } - - PostOutstandingTasks(); } void EventsExecutor::HandleAddedWaitable(py::handle waitable) @@ -776,8 +772,6 @@ void EventsExecutor::HandleWaitableReady( // execute() is an async method, we need a Task to run it create_task(execute(data)); } - - PostOutstandingTasks(); } void EventsExecutor::IterateTask(py::handle task) @@ -810,26 +804,7 @@ void EventsExecutor::IterateTask(py::handle task) throw; } } - } else { - // Task needs more iteration. Store the handle and revisit it later after the next ready - // entity which may unblock it. - // TODO(bmartin427) This matches the behavior of SingleThreadedExecutor and avoids busy - // looping, but I don't love it because if the task is waiting on something other than an rcl - // entity (e.g. an asyncio sleep, or a Future triggered from another thread, or even another - // Task), there can be arbitrarily long latency before some rcl activity causes us to go - // revisit that Task. - blocked_tasks_.push_back(task); - } -} - -void EventsExecutor::PostOutstandingTasks() -{ - for (auto & task : blocked_tasks_) { - events_queue_.Enqueue(std::bind(&EventsExecutor::IterateTask, this, task)); } - // Clear the entire outstanding tasks list. Any tasks that need further iteration will re-add - // themselves during IterateTask(). - blocked_tasks_.clear(); } void EventsExecutor::HandleCallbackExceptionInNodeEntity( @@ -888,6 +863,7 @@ void define_events_executor(py::object module) .def(py::init(), py::arg("context")) .def_property_readonly("context", &EventsExecutor::get_context) .def("create_task", &EventsExecutor::create_task, py::arg("callback")) + .def("_call_task_in_next_spin", &EventsExecutor::call_task_in_next_spin, py::arg("task")) .def("create_future", &EventsExecutor::create_future) .def("shutdown", &EventsExecutor::shutdown, py::arg("timeout_sec") = py::none()) .def("add_node", &EventsExecutor::add_node, py::arg("node")) diff --git a/rclpy/src/rclpy/events_executor/events_executor.hpp b/rclpy/src/rclpy/events_executor/events_executor.hpp index 7d67d20ad..6cd28845e 100644 --- a/rclpy/src/rclpy/events_executor/events_executor.hpp +++ b/rclpy/src/rclpy/events_executor/events_executor.hpp @@ -66,6 +66,7 @@ class EventsExecutor pybind11::object get_context() const {return rclpy_context_;} pybind11::object create_task( pybind11::object callback, pybind11::args args = {}, const pybind11::kwargs & kwargs = {}); + void call_task_in_next_spin(pybind11::handle task); pybind11::object create_future(); bool shutdown(std::optional timeout_sec = {}); bool add_node(pybind11::object node); @@ -148,11 +149,6 @@ class EventsExecutor /// create_task() implementation for details. void IterateTask(pybind11::handle task); - /// Posts a call to IterateTask() for every outstanding entry in tasks_; should be invoked from - /// other Handle*Ready() methods to check if any asynchronous Tasks have been unblocked by the - /// newly-handled event. - void PostOutstandingTasks(); - void HandleCallbackExceptionInNodeEntity( const pybind11::error_already_set &, pybind11::handle entity, const std::string & node_entity_attr); @@ -188,9 +184,6 @@ class EventsExecutor pybind11::set services_; pybind11::set waitables_; - /// Collection of asynchronous Tasks awaiting new events to further iterate. - std::vector blocked_tasks_; - /// Cache for rcl pointers underlying each waitables_ entry, because those are harder to retrieve /// than the other entity types. std::unordered_map waitable_entities_; diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index 1c7afb66c..b23e95858 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -277,6 +277,40 @@ async def coroutine(): self.assertTrue(future.done()) self.assertEqual('Sentinel Result', future.result()) + def test_create_task_coroutine_yield(self) -> None: + self.assertIsNotNone(self.node.handle) + for cls in [SingleThreadedExecutor, EventsExecutor]: + with self.subTest(cls=cls): + executor = cls(context=self.context) + executor.add_node(self.node) + + called1 = False + called2 = False + + async def coroutine() -> str: + nonlocal called1 + nonlocal called2 + called1 = True + await asyncio.sleep(0) + called2 = True + return 'Sentinel Result' + + future = executor.create_task(coroutine) + self.assertFalse(future.done()) + self.assertFalse(called1) + self.assertFalse(called2) + + executor.spin_once(timeout_sec=0) + self.assertFalse(future.done()) + self.assertTrue(called1) + self.assertFalse(called2) + + executor.spin_once(timeout_sec=1) + self.assertTrue(future.done()) + self.assertTrue(called1) + self.assertTrue(called2) + self.assertEqual('Sentinel Result', future.result()) + def test_create_task_coroutine_cancel(self) -> None: self.assertIsNotNone(self.node.handle) for cls in [SingleThreadedExecutor, EventsExecutor]: @@ -299,7 +333,39 @@ async def coroutine(): self.assertTrue(future.cancelled()) self.assertEqual(None, future.result()) - def test_create_task_normal_function(self): + def test_create_task_coroutine_wake_from_another_thread(self) -> None: + self.assertIsNotNone(self.node.handle) + + for cls in [SingleThreadedExecutor, MultiThreadedExecutor, EventsExecutor]: + with self.subTest(cls=cls): + executor = cls(context=self.context) + thread_future = executor.create_future() + + async def coroutine(): + await thread_future + + def future_thread(): + time.sleep(0.1) # Simulate some work + thread_future.set_result(None) + + t = threading.Thread(target=future_thread) + + coroutine_future = executor.create_task(coroutine) + + start_time = time.perf_counter() + + t.start() + executor.spin_until_future_complete(coroutine_future, timeout_sec=1.0) + + end_time = time.perf_counter() + + self.assertTrue(coroutine_future.done()) + + # The coroutine should take at least 0.1 seconds to complete because it waits for + # the thread to set the future but nowhere near the 1 second timeout + assert 0.1 <= end_time - start_time < 0.2 + + def test_create_task_normal_function(self) -> None: self.assertIsNotNone(self.node.handle) for cls in [SingleThreadedExecutor, EventsExecutor]: with self.subTest(cls=cls): @@ -328,18 +394,12 @@ async def coro1(): await future2 return 'Sentinel Result 1' + future1 = executor.create_task(coro1) + async def coro2(): return 'Sentinel Result 2' - # We need to swap the order of the coroutines depending on the executor type - # This is nessessary because https://github.com/ros2/rclpy/pull/1304 - # won't be backported to jazzy - if cls is SingleThreadedExecutor: - future2 = executor.create_task(coro2) - future1 = executor.create_task(coro1) - else: - future1 = executor.create_task(coro1) - future2 = executor.create_task(coro2) + future2 = executor.create_task(coro2) # Coro1 is the 1st task, so it gets to await future2 in this spin executor.spin_once(timeout_sec=0) diff --git a/rclpy/test/test_task.py b/rclpy/test/test_task.py index d94a74764..a8bbc6f0f 100644 --- a/rclpy/test/test_task.py +++ b/rclpy/test/test_task.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import unittest from rclpy.task import Future @@ -50,31 +49,7 @@ def func(): self.assertTrue(t.done()) self.assertEqual('Sentinel Result', t.result()) - def test_coroutine(self): - called1 = False - called2 = False - - async def coro(): - nonlocal called1 - nonlocal called2 - called1 = True - await asyncio.sleep(0) - called2 = True - return 'Sentinel Result' - - t = Task(coro) - t() - self.assertTrue(called1) - self.assertFalse(called2) - - called1 = False - t() - self.assertFalse(called1) - self.assertTrue(called2) - self.assertTrue(t.done()) - self.assertEqual('Sentinel Result', t.result()) - - def test_done_callback_scheduled(self): + def test_done_callback_scheduled(self) -> None: executor = DummyExecutor() t = Task(lambda: None, executor=executor)