diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 2d4295d7c..fd097a4f2 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -219,6 +219,10 @@ def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) -> # Task inherits from Future return task + def create_future(self) -> Future: + """Create a Future object attached to the Executor.""" + return Future(executor=self) + def shutdown(self, timeout_sec: Optional[float] = None) -> bool: """ Stop executing callbacks and wait for their completion. diff --git a/rclpy/src/rclpy/events_executor/events_executor.cpp b/rclpy/src/rclpy/events_executor/events_executor.cpp index 7e3dd74ff..526d2cad8 100644 --- a/rclpy/src/rclpy/events_executor/events_executor.cpp +++ b/rclpy/src/rclpy/events_executor/events_executor.cpp @@ -52,6 +52,7 @@ EventsExecutor::EventsExecutor(py::object context) inspect_iscoroutine_(py::module_::import("inspect").attr("iscoroutine")), inspect_signature_(py::module_::import("inspect").attr("signature")), rclpy_task_(py::module_::import("rclpy.task").attr("Task")), + rclpy_future_(py::module_::import("rclpy.task").attr("Future")), signal_callback_([this]() {events_queue_.Stop();}), rcl_callback_manager_(&events_queue_), timers_manager_( @@ -77,6 +78,12 @@ pybind11::object EventsExecutor::create_task( return task; } +pybind11::object EventsExecutor::create_future() +{ + using py::literals::operator""_a; + return rclpy_future_("executor"_a = py::cast(this)); +} + bool EventsExecutor::shutdown(std::optional timeout) { // NOTE: The rclpy context can invoke this with a lock on the context held. Therefore we must @@ -881,6 +888,7 @@ void define_events_executor(py::object module) .def(py::init(), py::arg("context")) .def_property_readonly("context", &EventsExecutor::get_context) .def("create_task", &EventsExecutor::create_task, py::arg("callback")) + .def("create_future", &EventsExecutor::create_future) .def("shutdown", &EventsExecutor::shutdown, py::arg("timeout_sec") = py::none()) .def("add_node", &EventsExecutor::add_node, py::arg("node")) .def("remove_node", &EventsExecutor::remove_node, py::arg("node")) diff --git a/rclpy/src/rclpy/events_executor/events_executor.hpp b/rclpy/src/rclpy/events_executor/events_executor.hpp index 82ffe45ed..7d67d20ad 100644 --- a/rclpy/src/rclpy/events_executor/events_executor.hpp +++ b/rclpy/src/rclpy/events_executor/events_executor.hpp @@ -66,6 +66,7 @@ class EventsExecutor pybind11::object get_context() const {return rclpy_context_;} pybind11::object create_task( pybind11::object callback, pybind11::args args = {}, const pybind11::kwargs & kwargs = {}); + pybind11::object create_future(); bool shutdown(std::optional timeout_sec = {}); bool add_node(pybind11::object node); void remove_node(pybind11::handle node); @@ -167,6 +168,7 @@ class EventsExecutor const pybind11::object inspect_iscoroutine_; const pybind11::object inspect_signature_; const pybind11::object rclpy_task_; + const pybind11::object rclpy_future_; EventsQueue events_queue_; ScopedSignalCallback signal_callback_; diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index 5075318c4..1c7afb66c 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -452,7 +452,7 @@ def timer_callback() -> None: timer = self.node.create_timer(0.003, timer_callback) # Timeout - future = Future() + future = executor.create_future() self.assertFalse(future.done()) start = time.perf_counter() executor.spin_until_future_complete(future=future, timeout_sec=0.1) @@ -479,7 +479,7 @@ def set_future_result(future): future.set_result('finished') # Future complete timeout_sec > 0 - future = Future() + future = executor.create_future() self.assertFalse(future.done()) t = threading.Thread(target=lambda: set_future_result(future)) t.start() @@ -488,7 +488,7 @@ def set_future_result(future): self.assertEqual(future.result(), 'finished') # Future complete timeout_sec = None - future = Future() + future = executor.create_future() self.assertFalse(future.done()) t = threading.Thread(target=lambda: set_future_result(future)) t.start() @@ -497,7 +497,7 @@ def set_future_result(future): self.assertEqual(future.result(), 'finished') # Future complete timeout < 0 - future = Future() + future = executor.create_future() self.assertFalse(future.done()) t = threading.Thread(target=lambda: set_future_result(future)) t.start() @@ -519,7 +519,7 @@ def timer_callback() -> None: timer = self.node.create_timer(0.003, timer_callback) # Do not wait timeout_sec = 0 - future = Future() + future = executor.create_future() self.assertFalse(future.done()) executor.spin_until_future_complete(future=future, timeout_sec=0) self.assertFalse(future.done()) @@ -602,7 +602,7 @@ def test_single_threaded_spin_once_until_future(self): with self.subTest(cls=cls): executor = cls(context=self.context) - future = Future(executor=executor) + future = executor.create_future() # Setup a thread to spin_once_until_future_complete, which will spin # for a maximum of 10 seconds. @@ -630,7 +630,7 @@ def test_multi_threaded_spin_once_until_future(self): self.assertIsNotNone(self.node.handle) executor = MultiThreadedExecutor(context=self.context) - future = Future(executor=executor) + future: Future[bool] = executor.create_future() # Setup a thread to spin_once_until_future_complete, which will spin # for a maximum of 10 seconds. @@ -679,7 +679,7 @@ def timer2_callback() -> None: timer2 = self.node.create_timer(1.5, timer2_callback, callback_group) executor.add_node(self.node) - future = Future(executor=executor) + future = executor.create_future() executor.spin_until_future_complete(future, 4) assert count == 2 @@ -689,6 +689,17 @@ def timer2_callback() -> None: self.node.destroy_timer(timer1) self.node.destroy_client(cli) + def test_create_future_returns_future_with_executor_attached(self) -> None: + self.assertIsNotNone(self.node.handle) + for cls in [SingleThreadedExecutor, MultiThreadedExecutor, EventsExecutor]: + with self.subTest(cls=cls): + executor = cls(context=self.context) + try: + fut = executor.create_future() + self.assertEqual(executor, fut._executor()) + finally: + executor.shutdown() + if __name__ == '__main__': unittest.main()