Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 52 additions & 20 deletions rclpy/rclpy/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
import inspect
Expand All @@ -23,6 +24,8 @@
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import Deque
from typing import Dict
from typing import Generator
from typing import List
from typing import Optional
Expand Down Expand Up @@ -146,6 +149,16 @@ def timeout(self, timeout):
self._timeout = timeout


class TaskData:
def __init__(
self,
source_node: 'Optional[Node]' = None,
source_entity: 'Optional[WaitableEntityType]' = None,
):
self.source_node = source_node
self.source_entity = source_entity


class Executor:
"""
The base class for an executor.
Expand All @@ -165,8 +178,10 @@ def __init__(self, *, context: Context = None) -> None:
self._context = get_default_context() if context is None else context
self._nodes: Set[Node] = set()
self._nodes_lock = RLock()
# Tasks to be executed (oldest first) 3-tuple Task, Entity, Node
self._tasks: List[Tuple[Task, Optional[WaitableEntityType], Optional[Node]]] = []
# all tasks that are not complete or canceled
self._pending_tasks: Dict[Task, TaskData] = {}
# tasks that are ready to execute
self._ready_tasks: Deque[Task] = deque()
self._tasks_lock = Lock()
# This is triggered when wait_for_ready_callbacks should rebuild the wait list
self._guard = GuardCondition(
Expand Down Expand Up @@ -200,11 +215,21 @@ def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) ->
"""
task = Task(callback, args, kwargs, executor=self)
with self._tasks_lock:
self._tasks.append((task, None, None))
self._guard.trigger()
# Task inherits from Future
self._pending_tasks[task] = TaskData()
self._call_task_in_next_spin(task)
return task

def _call_task_in_next_spin(self, task: Task) -> None:
"""
Add a task to the executor to be executed in the next spin.

:param task: A task to be run in the executor.
"""
with self._tasks_lock:
self._ready_tasks.append(task)
if self._guard:
self._guard.trigger()

def shutdown(self, timeout_sec: float = None) -> bool:
"""
Stop executing callbacks and wait for their completion.
Expand Down Expand Up @@ -473,7 +498,10 @@ async def handler(entity, gc, is_shutdown, work_tracker):
handler, (entity, self._guard, self._is_shutdown, self._work_tracker),
executor=self)
with self._tasks_lock:
self._tasks.append((task, entity, node))
self._pending_tasks[task] = TaskData(
source_entity=entity,
source_node=node
)
return task

def can_execute(self, entity: WaitableEntityType) -> bool:
Expand Down Expand Up @@ -517,21 +545,25 @@ def _wait_for_ready_callbacks(
nodes_to_use = self.get_nodes()

# Yield tasks in-progress before waiting for new work
tasks = None
with self._tasks_lock:
tasks = list(self._tasks)
if tasks:
for task, entity, node in reversed(tasks):
if (not task.executing() and not task.done() and
(node is None or node in nodes_to_use)):
yielded_work = True
yield task, entity, node
with self._tasks_lock:
# Get rid of any tasks that are done
self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks))
# Get rid of any tasks that are cancelled
self._tasks = list(filter(lambda t_e_n: not t_e_n[0].cancelled(), self._tasks))

# Get rid of any tasks that are done or cancelled
for task in list(self._pending_tasks.keys()):
if task.done() or task.cancelled():
del self._pending_tasks[task]

ready_tasks_count = len(self._ready_tasks)
for _ in range(ready_tasks_count):
task = self._ready_tasks.popleft()
task_data = self._pending_tasks[task]
node = task_data.source_node
if node is None or node in nodes_to_use:
entity = task_data.source_entity
yielded_work = True
yield task, entity, node
else:
# Asked not to execute these tasks, so don't do them yet
with self._tasks_lock:
self._ready_tasks.append(task)
# Gather entities that can be waited on
subscriptions: List[Subscription] = []
guards: List[GuardCondition] = []
Expand Down
64 changes: 50 additions & 14 deletions rclpy/rclpy/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import inspect
import sys
import threading
from typing import Any, Coroutine, Generator, Optional
import warnings
import weakref

Expand Down Expand Up @@ -57,10 +58,13 @@ def __del__(self):
'The following exception was never retrieved: ' + str(self._exception),
file=sys.stderr)

def __await__(self):
def __await__(self) -> Generator['Future', None, Optional[Any]]:
# Yield if the task is not finished
while self._pending():
yield
if self._pending():
# This tells the task to suspend until the future is done
yield self
if self._pending():
raise RuntimeError('Future awaited a second time before it was done')
return self.result()

def _pending(self):
Expand Down Expand Up @@ -249,16 +253,7 @@ def __call__(self):
self._executing = True

if inspect.iscoroutine(self._handler):
# Execute a coroutine
try:
self._handler.send(None)
except StopIteration as e:
# The coroutine finished; store the result
self.set_result(e.value)
self._complete_task()
except Exception as e:
self.set_exception(e)
self._complete_task()
self._execute_coroutine_step(self._handler)
else:
# Execute a normal function
try:
Expand All @@ -271,7 +266,48 @@ def __call__(self):
finally:
self._task_lock.release()

def _complete_task(self):
def _execute_coroutine_step(self, coro: Coroutine) -> None:
"""Execute or resume a coroutine task."""
try:
result = coro.send(None)
except StopIteration as e:
# The coroutine finished; store the result
self.set_result(e.value)
self._complete_task()
except Exception as e:
# The coroutine raised; store the exception
self.set_exception(e)
self._complete_task()
else:
# The coroutine yielded; suspend the task until it is resumed
executor = self._executor()
if executor is None:
raise RuntimeError(
'Task tried to reschedule but no executor was set: '
'tasks should only be initialized through executor.create_task()')
elif isinstance(result, Future):
# Schedule the task to resume when the future is done
self._add_resume_callback(result, executor)
elif result is None:
# The coroutine yielded None, schedule the task to resume in the next spin
executor._call_task_in_next_spin(self)
else:
raise TypeError(
f'Expected coroutine to yield a Future or None, got: {type(result)}')

def _add_resume_callback(self, future: Future, executor) -> None:
future_executor = future._executor()
if future_executor is None:
# The future is not associated with an executor yet, so associate it with ours
future._set_executor(executor)
elif future_executor is not executor:
raise RuntimeError('A task can only await futures associated with the same executor')

# The future is associated with the same executor, so we can resume the task directly
# in the done callback
future.add_done_callback(lambda _: self.__call__())

def _complete_task(self) -> None:
"""Cleanup after task finished."""
self._handler = None
self._args = None
Expand Down
87 changes: 76 additions & 11 deletions rclpy/test/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,39 @@ async def coroutine():
self.assertTrue(future.done())
self.assertEqual('Sentinel Result', future.result())

def test_create_task_coroutine_yield(self) -> None:
self.assertIsNotNone(self.node.handle)

executor = SingleThreadedExecutor(context=self.context)
executor.add_node(self.node)

called1 = False
called2 = False

async def coroutine() -> str:
nonlocal called1
nonlocal called2
called1 = True
await asyncio.sleep(0)
called2 = True
return 'Sentinel Result'

future = executor.create_task(coroutine)
self.assertFalse(future.done())
self.assertFalse(called1)
self.assertFalse(called2)

executor.spin_once(timeout_sec=0)
self.assertFalse(future.done())
self.assertTrue(called1)
self.assertFalse(called2)

executor.spin_once(timeout_sec=1)
self.assertTrue(future.done())
self.assertTrue(called1)
self.assertTrue(called2)
self.assertEqual('Sentinel Result', future.result())

def test_create_task_coroutine_cancel(self) -> None:
self.assertIsNotNone(self.node.handle)
executor = SingleThreadedExecutor(context=self.context)
Expand All @@ -245,7 +278,39 @@ async def coroutine():
self.assertTrue(future.cancelled())
self.assertEqual(None, future.result())

def test_create_task_normal_function(self):
def test_create_task_coroutine_wake_from_another_thread(self) -> None:
self.assertIsNotNone(self.node.handle)

for cls in [SingleThreadedExecutor, MultiThreadedExecutor]:
with self.subTest(cls=cls):
executor = cls(context=self.context)
thread_future = Future(executor=executor)

async def coroutine():
await thread_future

def future_thread():
time.sleep(0.1) # Simulate some work
thread_future.set_result(None)

t = threading.Thread(target=future_thread)

coroutine_future = executor.create_task(coroutine)

start_time = time.perf_counter()

t.start()
executor.spin_until_future_complete(coroutine_future, timeout_sec=1.0)

end_time = time.perf_counter()

self.assertTrue(coroutine_future.done())

# The coroutine should take at least 0.1 seconds to complete because it waits for
# the thread to set the future but nowhere near the 1 second timeout
assert 0.1 <= end_time - start_time < 0.2

def test_create_task_normal_function(self) -> None:
self.assertIsNotNone(self.node.handle)
executor = SingleThreadedExecutor(context=self.context)
executor.add_node(self.node)
Expand All @@ -266,30 +331,30 @@ def test_create_task_dependent_coroutines(self):
executor.add_node(self.node)

async def coro1():
nonlocal future2
await future2
return 'Sentinel Result 1'

future1 = executor.create_task(coro1)

async def coro2():
nonlocal future1
await future1
return 'Sentinel Result 2'

future2 = executor.create_task(coro2)

# Coro2 is newest task, so it gets to await future1 in this spin
# Coro1 is the 1st task, so it gets to await future2 in this spin
executor.spin_once(timeout_sec=0)
# Coro1 execs in this spin
# Coro2 execs in this spin
executor.spin_once(timeout_sec=0)
self.assertTrue(future1.done())
self.assertEqual('Sentinel Result 1', future1.result())
self.assertFalse(future2.done())

# Coro2 passes the await step here (timeout change forces new generator)
executor.spin_once(timeout_sec=1)
self.assertFalse(future1.done())
self.assertTrue(future2.done())
self.assertEqual('Sentinel Result 2', future2.result())

# Coro1 passes the await step here (timeout change forces new generator)
executor.spin_once(timeout_sec=1)
self.assertTrue(future1.done())
self.assertEqual('Sentinel Result 1', future1.result())

def test_create_task_during_spin(self):
self.assertIsNotNone(self.node.handle)
executor = SingleThreadedExecutor(context=self.context)
Expand Down
27 changes: 1 addition & 26 deletions rclpy/test/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import unittest

from rclpy.task import Future
Expand Down Expand Up @@ -50,31 +49,7 @@ def func():
self.assertTrue(t.done())
self.assertEqual('Sentinel Result', t.result())

def test_coroutine(self):
called1 = False
called2 = False

async def coro():
nonlocal called1
nonlocal called2
called1 = True
await asyncio.sleep(0)
called2 = True
return 'Sentinel Result'

t = Task(coro)
t()
self.assertTrue(called1)
self.assertFalse(called2)

called1 = False
t()
self.assertFalse(called1)
self.assertTrue(called2)
self.assertTrue(t.done())
self.assertEqual('Sentinel Result', t.result())

def test_done_callback_scheduled(self):
def test_done_callback_scheduled(self) -> None:
executor = DummyExecutor()

t = Task(lambda: None, executor=executor)
Expand Down