Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 73 additions & 6 deletions rust/system/src/execution/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ use std::collections::VecDeque;
use std::fmt::Debug;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::runtime::Runtime;
use tracing::{trace_span, Instrument, Span};

/// Maximum time the dispatcher's stop path gives the dedicated IO runtime to
/// shut down; this value is passed to `Runtime::shutdown_timeout`.
const IO_RUNTIME_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);

/// The dispatcher is responsible for distributing tasks to worker threads.
/// It is a component that receives tasks and distributes them to worker threads.
/**
Expand Down Expand Up @@ -67,7 +70,7 @@ pub struct Dispatcher {
task_queue: VecDeque<(TaskMessage, Span)>,
waiters: Vec<TaskRequestMessage>,
active_io_tasks: Arc<AtomicU64>,
io_runtime: Arc<Runtime>,
io_runtime: Option<Runtime>,
worker_handles: Arc<Mutex<Vec<ComponentHandle<WorkerThread>>>>,
metrics: DispatcherMetrics,
}
Expand All @@ -80,6 +83,15 @@ impl Debug for Dispatcher {
}
}

impl Drop for Dispatcher {
    /// Shuts down the IO runtime if the dispatcher still owns it.
    ///
    /// The runtime may already have been taken out of `self.io_runtime`
    /// (the stop path does this), in which case there is nothing to do.
    fn drop(&mut self) {
        let Some(runtime) = self.io_runtime.take() else {
            return;
        };
        // This destructor can run inside a Tokio task, so fall back to the
        // nonblocking shutdown instead of letting the runtime drop normally.
        runtime.shutdown_background();
    }
}

#[derive(Debug, Clone)]
struct DispatcherMetrics {
task_queue_depth: opentelemetry::metrics::Histogram<u64>,
Expand Down Expand Up @@ -179,7 +191,7 @@ impl Dispatcher {
task_queue: VecDeque::new(),
waiters: Vec::new(),
active_io_tasks: Arc::new(AtomicU64::new(config.active_io_tasks as u64)),
io_runtime: Arc::new(io_runtime),
io_runtime: Some(io_runtime),
worker_handles: Arc::new(Mutex::new(Vec::new())),
metrics: DispatcherMetrics::new(),
}
Expand Down Expand Up @@ -307,10 +319,22 @@ impl Dispatcher {
.queue_latency_ms
.record(duration_ms(task_created_at.elapsed()), &task_kv);
self.record_depths();
self.io_runtime.spawn(async move {
task.run().instrument(child_span).await;
drop(counter);
});
match &self.io_runtime {
Some(io_runtime) => {
io_runtime.spawn(async move {
task.run().instrument(child_span).await;
drop(counter);
});
}
None => {
task.abort().await;
self.metrics.task_abort_total.add(
1,
&task_attrs_with(task_type, operator, "reason", "dispatcher_stopped"),
);
drop(counter);
}
}
}
OperatorType::Other => {
// If a worker is waiting for a task, send it to the worker in FIFO order.
Expand Down Expand Up @@ -435,12 +459,15 @@ impl TaskRequestMessage {
/// Errors that can be returned while stopping the dispatcher.
enum DispatcherStopError {
/// Consuming a worker thread's join handle failed during stop.
#[error("Failed to stop worker thread: {0}")]
JoinError(ConsumeJoinHandleError),
/// The blocking task that shuts down the IO runtime could not be joined.
#[error("Failed to shut down IO runtime: {0}")]
IoRuntimeShutdownError(tokio::task::JoinError),
}

impl ChromaError for DispatcherStopError {
    /// Maps every dispatcher stop failure to an internal error code.
    fn code(&self) -> chroma_error::ErrorCodes {
        // Variants are listed explicitly (no wildcard) so that adding a new
        // variant forces a decision about its error code.
        match self {
            Self::JoinError(_) | Self::IoRuntimeShutdownError(_) => {
                chroma_error::ErrorCodes::Internal
            }
        }
    }
}
Expand Down Expand Up @@ -475,6 +502,14 @@ impl Component for Dispatcher {
.map_err(|e| DispatcherStopError::JoinError(e).boxed())?;
}

if let Some(io_runtime) = self.io_runtime.take() {
tokio::task::spawn_blocking(move || {
io_runtime.shutdown_timeout(IO_RUNTIME_SHUTDOWN_TIMEOUT);
})
.await
.map_err(|e| DispatcherStopError::IoRuntimeShutdownError(e).boxed())?;
}

Ok(())
}
}
Expand Down Expand Up @@ -889,4 +924,36 @@ mod tests {
let first_result = first_result_rx.await.unwrap();
assert_eq!(first_result.into_inner().unwrap(), "first");
}

#[tokio::test]
async fn test_dispatcher_shutdown_does_not_drop_runtime_in_async_context() {
let system = System::new();
let dispatcher = Dispatcher::new(DispatcherConfig {
num_worker_threads: 0,
task_queue_limit: 0,
dispatcher_queue_size: 10,
worker_queue_size: 1,
active_io_tasks: 1,
cpu_affinity_num_cores: None,
io_affinity_num_cores: None,
});
let mut dispatcher_handle = system.start_component(dispatcher);

tokio::task::yield_now().await;
dispatcher_handle.stop();
dispatcher_handle.join().await.unwrap();
}

#[tokio::test]
async fn test_dispatcher_drop_does_not_drop_runtime_in_async_context() {
let _dispatcher = Dispatcher::new(DispatcherConfig {
num_worker_threads: 0,
task_queue_limit: 0,
dispatcher_queue_size: 10,
worker_queue_size: 1,
active_io_tasks: 1,
cpu_affinity_num_cores: None,
io_affinity_num_cores: None,
});
}
}
Loading