diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 581e1546c..c0ee37328 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import deque from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack import inspect @@ -23,6 +24,8 @@ from typing import Any from typing import Callable from typing import Coroutine +from typing import Deque +from typing import Dict from typing import Generator from typing import List from typing import Optional @@ -146,6 +149,16 @@ def timeout(self, timeout): self._timeout = timeout +class TaskData: + def __init__( + self, + source_node: 'Optional[Node]' = None, + source_entity: 'Optional[WaitableEntityType]' = None, + ): + self.source_node = source_node + self.source_entity = source_entity + + class Executor: """ The base class for an executor. @@ -165,8 +178,10 @@ def __init__(self, *, context: Context = None) -> None: self._context = get_default_context() if context is None else context self._nodes: Set[Node] = set() self._nodes_lock = RLock() - # Tasks to be executed (oldest first) 3-tuple Task, Entity, Node - self._tasks: List[Tuple[Task, Optional[WaitableEntityType], Optional[Node]]] = [] + # all tasks that are not complete or canceled + self._pending_tasks: Dict[Task, TaskData] = {} + # tasks that are ready to execute + self._ready_tasks: Deque[Task] = deque() self._tasks_lock = Lock() # This is triggered when wait_for_ready_callbacks should rebuild the wait list self._guard = GuardCondition( @@ -200,11 +215,21 @@ def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) -> """ task = Task(callback, args, kwargs, executor=self) with self._tasks_lock: - self._tasks.append((task, None, None)) - self._guard.trigger() - # Task inherits from Future + self._pending_tasks[task] = TaskData() + self._call_task_in_next_spin(task) return task + def _call_task_in_next_spin(self, task: Task) -> None: + """ + Add a task to the executor to be executed in the next spin. + + :param task: A task to be run in the executor. + """ + with self._tasks_lock: + self._ready_tasks.append(task) + if self._guard: + self._guard.trigger() + def shutdown(self, timeout_sec: float = None) -> bool: """ Stop executing callbacks and wait for their completion. @@ -473,7 +498,10 @@ async def handler(entity, gc, is_shutdown, work_tracker): handler, (entity, self._guard, self._is_shutdown, self._work_tracker), executor=self) with self._tasks_lock: - self._tasks.append((task, entity, node)) + self._pending_tasks[task] = TaskData( + source_entity=entity, + source_node=node + ) return task def can_execute(self, entity: WaitableEntityType) -> bool: @@ -517,21 +545,25 @@ def _wait_for_ready_callbacks( nodes_to_use = self.get_nodes() # Yield tasks in-progress before waiting for new work - tasks = None with self._tasks_lock: - tasks = list(self._tasks) - if tasks: - for task, entity, node in reversed(tasks): - if (not task.executing() and not task.done() and - (node is None or node in nodes_to_use)): - yielded_work = True - yield task, entity, node - with self._tasks_lock: - # Get rid of any tasks that are done - self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) - # Get rid of any tasks that are cancelled - self._tasks = list(filter(lambda t_e_n: not t_e_n[0].cancelled(), self._tasks)) - + # Get rid of any tasks that are done or cancelled + for task in list(self._pending_tasks.keys()): + if task.done() or task.cancelled(): + del self._pending_tasks[task] + + ready_tasks_count = len(self._ready_tasks) + for _ in range(ready_tasks_count): + task = self._ready_tasks.popleft() + task_data = self._pending_tasks[task] + node = task_data.source_node + if node is None or node in nodes_to_use: + entity = task_data.source_entity + yielded_work = True + yield task, entity, node + else: + # Asked not to execute these tasks, so don't do them yet + with self._tasks_lock: + self._ready_tasks.append(task) # Gather entities that can be waited on subscriptions: List[Subscription] = [] guards: List[GuardCondition] = [] diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py index e6da94752..4ef811b4e 100644 --- a/rclpy/rclpy/task.py +++ b/rclpy/rclpy/task.py @@ -16,6 +16,7 @@ import inspect import sys import threading +from typing import Any, Coroutine, Generator, Optional import warnings import weakref @@ -57,10 +58,13 @@ def __del__(self): 'The following exception was never retrieved: ' + str(self._exception), file=sys.stderr) - def __await__(self): + def __await__(self) -> Generator['Future', None, Optional[Any]]: # Yield if the task is not finished - while self._pending(): - yield + if self._pending(): + # This tells the task to suspend until the future is done + yield self + if self._pending(): + raise RuntimeError('Future awaited a second time before it was done') return self.result() def _pending(self): @@ -249,16 +253,7 @@ def __call__(self): self._executing = True if inspect.iscoroutine(self._handler): - # Execute a coroutine - try: - self._handler.send(None) - except StopIteration as e: - # The coroutine finished; store the result - self.set_result(e.value) - self._complete_task() - except Exception as e: - self.set_exception(e) - self._complete_task() + self._execute_coroutine_step(self._handler) else: # Execute a normal function try: @@ -271,7 +266,48 @@ def __call__(self): finally: self._task_lock.release() - def _complete_task(self): + def _execute_coroutine_step(self, coro: Coroutine) -> None: + """Execute or resume a coroutine task.""" + try: + result = coro.send(None) + except StopIteration as e: + # The coroutine finished; store the result + self.set_result(e.value) + self._complete_task() + except Exception as e: + # The coroutine raised; store the exception + self.set_exception(e) + self._complete_task() + else: + # The coroutine yielded; suspend the task until it is resumed + executor = self._executor() + if executor is None: + raise RuntimeError( + 'Task tried to reschedule but no executor was set: ' + 'tasks should only be initialized through executor.create_task()') + elif isinstance(result, Future): + # Schedule the task to resume when the future is done + self._add_resume_callback(result, executor) + elif result is None: + # The coroutine yielded None, schedule the task to resume in the next spin + executor._call_task_in_next_spin(self) + else: + raise TypeError( + f'Expected coroutine to yield a Future or None, got: {type(result)}') + + def _add_resume_callback(self, future: Future, executor) -> None: + future_executor = future._executor() + if future_executor is None: + # The future is not associated with an executor yet, so associate it with ours + future._set_executor(executor) + elif future_executor is not executor: + raise RuntimeError('A task can only await futures associated with the same executor') + + # The future is associated with the same executor, so we can resume the task directly + # in the done callback + future.add_done_callback(lambda _: self.__call__()) + + def _complete_task(self) -> None: """Cleanup after task finished.""" self._handler = None self._args = None diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index df873c595..1dd1b2b2c 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -225,6 +225,39 @@ async def coroutine(): self.assertTrue(future.done()) self.assertEqual('Sentinel Result', future.result()) + def test_create_task_coroutine_yield(self) -> None: + self.assertIsNotNone(self.node.handle) + + executor = SingleThreadedExecutor(context=self.context) + executor.add_node(self.node) + + called1 = False + called2 = False + + async def coroutine() -> str: + nonlocal called1 + nonlocal called2 + called1 = True + await asyncio.sleep(0) + called2 = True + return 'Sentinel Result' + + future = executor.create_task(coroutine) + self.assertFalse(future.done()) + self.assertFalse(called1) + self.assertFalse(called2) + + executor.spin_once(timeout_sec=0) + self.assertFalse(future.done()) + self.assertTrue(called1) + self.assertFalse(called2) + + executor.spin_once(timeout_sec=1) + self.assertTrue(future.done()) + self.assertTrue(called1) + self.assertTrue(called2) + self.assertEqual('Sentinel Result', future.result()) + def test_create_task_coroutine_cancel(self) -> None: self.assertIsNotNone(self.node.handle) executor = SingleThreadedExecutor(context=self.context) @@ -245,7 +278,39 @@ async def coroutine(): self.assertTrue(future.cancelled()) self.assertEqual(None, future.result()) - def test_create_task_normal_function(self): + def test_create_task_coroutine_wake_from_another_thread(self) -> None: + self.assertIsNotNone(self.node.handle) + + for cls in [SingleThreadedExecutor, MultiThreadedExecutor]: + with self.subTest(cls=cls): + executor = cls(context=self.context) + thread_future = Future(executor=executor) + + async def coroutine(): + await thread_future + + def future_thread(): + time.sleep(0.1) # Simulate some work + thread_future.set_result(None) + + t = threading.Thread(target=future_thread) + + coroutine_future = executor.create_task(coroutine) + + start_time = time.perf_counter() + + t.start() + executor.spin_until_future_complete(coroutine_future, timeout_sec=1.0) + + end_time = time.perf_counter() + + self.assertTrue(coroutine_future.done()) + + # The coroutine should take at least 0.1 seconds to complete because it waits for + # the thread to set the future but nowhere near the 1 second timeout + assert 0.1 <= end_time - start_time < 0.2 + + def test_create_task_normal_function(self) -> None: self.assertIsNotNone(self.node.handle) executor = SingleThreadedExecutor(context=self.context) executor.add_node(self.node) @@ -266,30 +331,30 @@ def test_create_task_dependent_coroutines(self): executor.add_node(self.node) async def coro1(): + nonlocal future2 + await future2 return 'Sentinel Result 1' future1 = executor.create_task(coro1) async def coro2(): - nonlocal future1 - await future1 return 'Sentinel Result 2' future2 = executor.create_task(coro2) - # Coro2 is newest task, so it gets to await future1 in this spin + # Coro1 is the 1st task, so it gets to await future2 in this spin executor.spin_once(timeout_sec=0) - # Coro1 execs in this spin + # Coro2 execs in this spin executor.spin_once(timeout_sec=0) - self.assertTrue(future1.done()) - self.assertEqual('Sentinel Result 1', future1.result()) - self.assertFalse(future2.done()) - - # Coro2 passes the await step here (timeout change forces new generator) - executor.spin_once(timeout_sec=1) + self.assertFalse(future1.done()) self.assertTrue(future2.done()) self.assertEqual('Sentinel Result 2', future2.result()) + # Coro1 passes the await step here (timeout change forces new generator) + executor.spin_once(timeout_sec=1) + self.assertTrue(future1.done()) + self.assertEqual('Sentinel Result 1', future1.result()) + def test_create_task_during_spin(self): self.assertIsNotNone(self.node.handle) executor = SingleThreadedExecutor(context=self.context) diff --git a/rclpy/test/test_task.py b/rclpy/test/test_task.py index d94a74764..a8bbc6f0f 100644 --- a/rclpy/test/test_task.py +++ b/rclpy/test/test_task.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import unittest from rclpy.task import Future @@ -50,31 +49,7 @@ def func(): self.assertTrue(t.done()) self.assertEqual('Sentinel Result', t.result()) - def test_coroutine(self): - called1 = False - called2 = False - - async def coro(): - nonlocal called1 - nonlocal called2 - called1 = True - await asyncio.sleep(0) - called2 = True - return 'Sentinel Result' - - t = Task(coro) - t() - self.assertTrue(called1) - self.assertFalse(called2) - - called1 = False - t() - self.assertFalse(called1) - self.assertTrue(called2) - self.assertTrue(t.done()) - self.assertEqual('Sentinel Result', t.result()) - - def test_done_callback_scheduled(self): + def test_done_callback_scheduled(self) -> None: executor = DummyExecutor() t = Task(lambda: None, executor=executor)