# Patch notes (commentary only; patch tools skip text that precedes the first
# "diff --git" header).
#
# Target file: rust/system/src/execution/dispatcher.rs
#
# What the hunks below do:
#   * Import std::time::Duration and add IO_RUNTIME_SHUTDOWN_TIMEOUT (5 seconds).
#   * Change the Dispatcher `io_runtime` field from an Arc-wrapped value to an
#     Option-wrapped value, constructed with `Some(io_runtime)`.
#   * Add `impl Drop for Dispatcher`: takes the runtime out of the Option and
#     calls `shutdown_background()` on it.
#   * IO task dispatch: when the runtime is present the task is spawned on it as
#     before; when it is `None` the task is aborted, `task_abort_total` is
#     incremented with reason "dispatcher_stopped", and the counter is dropped.
#   * Add `DispatcherStopError::IoRuntimeShutdownError(tokio::task::JoinError)`,
#     mapped to `ErrorCodes::Internal`.
#   * In the `impl Component for Dispatcher` hunk (method name not visible here):
#     after the worker handles are joined, take the runtime and call
#     `shutdown_timeout(IO_RUNTIME_SHUTDOWN_TIMEOUT)` inside
#     `tokio::task::spawn_blocking`, mapping a join failure to the new variant.
#   * Add two `#[tokio::test]` cases: stop + join of a started dispatcher, and
#     dropping a freshly constructed dispatcher inside an async context.
#
# NOTE(review): this copy of the patch had all line breaks and indentation
# collapsed. Line breaks were restored from the hunk headers (the line counts
# match); indentation is a conventional reconstruction and should be confirmed
# against the real file.
# NOTE(review): angle-bracketed generic arguments appear to have been stripped
# from this copy (e.g. `waiters: Vec,`, `io_runtime: Option,`,
# `worker_handles: Arc>>>`). They are left exactly as found; recover them from
# the upstream commit before applying.
diff --git a/rust/system/src/execution/dispatcher.rs b/rust/system/src/execution/dispatcher.rs
index 764ae1927be..2d3dce4c87d 100644
--- a/rust/system/src/execution/dispatcher.rs
+++ b/rust/system/src/execution/dispatcher.rs
@@ -16,10 +16,13 @@ use std::collections::VecDeque;
 use std::fmt::Debug;
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
+use std::time::Duration;
 use thiserror::Error;
 use tokio::runtime::Runtime;
 use tracing::{trace_span, Instrument, Span};
 
+const IO_RUNTIME_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
+
 /// The dispatcher is responsible for distributing tasks to worker threads.
 /// It is a component that receives tasks and distributes them to worker threads.
 /**
@@ -67,7 +70,7 @@ pub struct Dispatcher {
     task_queue: VecDeque<(TaskMessage, Span)>,
     waiters: Vec,
     active_io_tasks: Arc,
-    io_runtime: Arc,
+    io_runtime: Option,
     worker_handles: Arc>>>,
     metrics: DispatcherMetrics,
 }
@@ -80,6 +83,15 @@ impl Debug for Dispatcher {
     }
 }
 
+impl Drop for Dispatcher {
+    fn drop(&mut self) {
+        if let Some(io_runtime) = self.io_runtime.take() {
+            // Drop can run inside a Tokio task; use the nonblocking shutdown fallback.
+            io_runtime.shutdown_background();
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct DispatcherMetrics {
     task_queue_depth: opentelemetry::metrics::Histogram,
@@ -179,7 +191,7 @@ impl Dispatcher {
             task_queue: VecDeque::new(),
             waiters: Vec::new(),
             active_io_tasks: Arc::new(AtomicU64::new(config.active_io_tasks as u64)),
-            io_runtime: Arc::new(io_runtime),
+            io_runtime: Some(io_runtime),
             worker_handles: Arc::new(Mutex::new(Vec::new())),
             metrics: DispatcherMetrics::new(),
         }
@@ -307,10 +319,22 @@ impl Dispatcher {
                     .queue_latency_ms
                     .record(duration_ms(task_created_at.elapsed()), &task_kv);
                 self.record_depths();
-                self.io_runtime.spawn(async move {
-                    task.run().instrument(child_span).await;
-                    drop(counter);
-                });
+                match &self.io_runtime {
+                    Some(io_runtime) => {
+                        io_runtime.spawn(async move {
+                            task.run().instrument(child_span).await;
+                            drop(counter);
+                        });
+                    }
+                    None => {
+                        task.abort().await;
+                        self.metrics.task_abort_total.add(
+                            1,
+                            &task_attrs_with(task_type, operator, "reason", "dispatcher_stopped"),
+                        );
+                        drop(counter);
+                    }
+                }
             }
             OperatorType::Other => {
                 // If a worker is waiting for a task, send it to the worker in FIFO order.
@@ -435,12 +459,15 @@ impl TaskRequestMessage {
 enum DispatcherStopError {
     #[error("Failed to stop worker thread: {0}")]
     JoinError(ConsumeJoinHandleError),
+    #[error("Failed to shut down IO runtime: {0}")]
+    IoRuntimeShutdownError(tokio::task::JoinError),
 }
 
 impl ChromaError for DispatcherStopError {
     fn code(&self) -> chroma_error::ErrorCodes {
         match self {
             DispatcherStopError::JoinError(_) => chroma_error::ErrorCodes::Internal,
+            DispatcherStopError::IoRuntimeShutdownError(_) => chroma_error::ErrorCodes::Internal,
         }
     }
 }
@@ -475,6 +502,14 @@ impl Component for Dispatcher {
                 .map_err(|e| DispatcherStopError::JoinError(e).boxed())?;
         }
 
+        if let Some(io_runtime) = self.io_runtime.take() {
+            tokio::task::spawn_blocking(move || {
+                io_runtime.shutdown_timeout(IO_RUNTIME_SHUTDOWN_TIMEOUT);
+            })
+            .await
+            .map_err(|e| DispatcherStopError::IoRuntimeShutdownError(e).boxed())?;
+        }
+
         Ok(())
     }
 }
@@ -889,4 +924,36 @@ mod tests {
         let first_result = first_result_rx.await.unwrap();
         assert_eq!(first_result.into_inner().unwrap(), "first");
     }
+
+    #[tokio::test]
+    async fn test_dispatcher_shutdown_does_not_drop_runtime_in_async_context() {
+        let system = System::new();
+        let dispatcher = Dispatcher::new(DispatcherConfig {
+            num_worker_threads: 0,
+            task_queue_limit: 0,
+            dispatcher_queue_size: 10,
+            worker_queue_size: 1,
+            active_io_tasks: 1,
+            cpu_affinity_num_cores: None,
+            io_affinity_num_cores: None,
+        });
+        let mut dispatcher_handle = system.start_component(dispatcher);
+
+        tokio::task::yield_now().await;
+        dispatcher_handle.stop();
+        dispatcher_handle.join().await.unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_dispatcher_drop_does_not_drop_runtime_in_async_context() {
+        let _dispatcher = Dispatcher::new(DispatcherConfig {
+            num_worker_threads: 0,
+            task_queue_limit: 0,
+            dispatcher_queue_size: 10,
+            worker_queue_size: 1,
+            active_io_tasks: 1,
+            cpu_affinity_num_cores: None,
+            io_affinity_num_cores: None,
+        });
+    }
 }