Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 171 additions & 10 deletions src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,14 +1204,19 @@ fn should_track_plan_process(stmt: Option<&Statement>, plan: &LogicalPlan) -> bo
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier};
use std::task::{Context, Poll};
use std::thread;
use std::time::{Duration, Instant};

use api::v1::meta::{ProcedureDetailResponse, ReconcileRequest, ReconcileResponse};
use catalog::process_manager::ProcessManager;
use common_base::Plugins;
use common_error::ext::BoxedError;
use common_error::status_code::StatusCode;
use common_meta::cache::LayeredCacheRegistryBuilder;
use common_meta::kv_backend::memory::MemoryKvBackend;
use common_meta::procedure_executor::{ExecutorContext, ProcedureExecutor};
Expand All @@ -1220,17 +1225,23 @@ mod tests {
MigrateRegionRequest, MigrateRegionResponse, ProcedureStateResponse,
};
use common_query::Output;
use common_recordbatch::{
OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{LogicalPlanBuilder, LogicalTableSource};
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema as GtSchema};
use datatypes::schema::{ColumnSchema, Schema as GtSchema, SchemaRef as GtSchemaRef};
use query::query_engine::options::QueryOptions;
use session::context::{Channel, ConnInfo, QueryContext, QueryContextBuilder};
use sql::dialect::GreptimeDbDialect;
use store_api::data_source::DataSource;
use store_api::storage::ScanRequest;
use strfmt::Format;
use table::metadata::{TableInfoBuilder, TableMetaBuilder};
use table::metadata::{FilterPushDownType, TableInfo, TableInfoBuilder, TableMetaBuilder};
use table::test_util::EmptyTable;
use table::{Table, TableRef};
use tokio::sync::{mpsc, oneshot};

use super::*;
Expand Down Expand Up @@ -1292,6 +1303,64 @@ mod tests {
}
}

struct PendingRecordBatchStream {
schema: GtSchemaRef,
polled_tx: Option<oneshot::Sender<()>>,
_finish_tx: oneshot::Sender<()>,
finish_rx: Pin<Box<oneshot::Receiver<()>>>,
}

impl RecordBatchStream for PendingRecordBatchStream {
fn schema(&self) -> GtSchemaRef {
self.schema.clone()
}

fn output_ordering(&self) -> Option<&[OrderOption]> {
None
}

fn metrics(&self) -> Option<common_recordbatch::adapter::RecordBatchMetrics> {
None
}
}

impl Stream for PendingRecordBatchStream {
type Item = common_recordbatch::error::Result<RecordBatch>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if let Some(polled_tx) = self.polled_tx.take() {
let _ = polled_tx.send(());
}

match self.finish_rx.as_mut().poll(cx) {
Poll::Ready(_) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}

impl Unpin for PendingRecordBatchStream {}

struct PendingDataSource {
schema: GtSchemaRef,
polled_tx: std::sync::Mutex<Option<oneshot::Sender<()>>>,
}

impl DataSource for PendingDataSource {
fn get_stream(
&self,
_request: ScanRequest,
) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
let (finish_tx, finish_rx) = oneshot::channel();
Ok(Box::pin(PendingRecordBatchStream {
schema: self.schema.clone(),
polled_tx: self.polled_tx.lock().unwrap().take(),
_finish_tx: finish_tx,
finish_rx: Box::pin(finish_rx),
}))
}
}

struct NoopProcedureExecutor;

#[async_trait::async_trait]
Expand Down Expand Up @@ -1364,7 +1433,7 @@ mod tests {
)
}

fn test_table(table_id: u32, table_name: &str) -> table::TableRef {
fn test_table_info(table_id: u32, table_name: &str) -> TableInfo {
let schema = Arc::new(GtSchema::new(vec![
ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false),
ColumnSchema::new(
Expand All @@ -1381,35 +1450,74 @@ mod tests {
.next_column_id(1024)
.build()
.unwrap();
let table_info = TableInfoBuilder::new(table_name, table_meta)

TableInfoBuilder::new(table_name, table_meta)
.table_id(table_id)
.build()
.unwrap();
.unwrap()
Comment thread
killme2008 marked this conversation as resolved.
Outdated
}

fn test_table(table_id: u32, table_name: &str) -> table::TableRef {
let table_info = test_table_info(table_id, table_name);
EmptyTable::from_table_info(&table_info)
}

fn pending_table(
table_id: u32,
table_name: &str,
polled_tx: oneshot::Sender<()>,
) -> table::TableRef {
let table_info = test_table_info(table_id, table_name);
let data_source = Arc::new(PendingDataSource {
schema: table_info.meta.schema.clone(),
polled_tx: std::sync::Mutex::new(Some(polled_tx)),
});

Arc::new(Table::new(
Arc::new(table_info),
FilterPushDownType::Unsupported,
data_source,
))
}

async fn test_instance_with_tables(source_table: TableRef, target_table: TableRef) -> Instance {
test_instance_with_plugins(source_table, target_table, Plugins::new()).await
}

async fn test_instance_with_insert_select_interceptor(
interceptor: SqlQueryInterceptorRef<Error>,
) -> Instance {
let plugins = Plugins::new();
plugins.insert::<SqlQueryInterceptorRef<Error>>(interceptor);

test_instance_with_plugins(
test_table(1024, "source"),
test_table(1025, "target"),
plugins,
)
.await
}

async fn test_instance_with_plugins(
source_table: TableRef,
target_table: TableRef,
plugins: Plugins,
) -> Instance {
let kv_backend = Arc::new(MemoryKvBackend::new());
let process_manager = Arc::new(ProcessManager::new("test-frontend".to_string(), None));
let catalog_manager =
catalog::memory::MemoryCatalogManager::new_with_table(test_table(1024, "source"));
let catalog_manager = catalog::memory::MemoryCatalogManager::new_with_table(source_table);
catalog_manager
.register_table_sync(catalog::RegisterTableRequest {
catalog: "greptime".to_string(),
schema: "public".to_string(),
table_name: "target".to_string(),
table_id: 1025,
table: test_table(1025, "target"),
table: target_table,
})
.unwrap();
catalog_manager.register_process_list_table(process_manager.clone());

let cache_registry = test_cache_registry(kv_backend.clone());
let plugins = Plugins::new();
plugins.insert::<SqlQueryInterceptorRef<Error>>(interceptor);

FrontendBuilder::new(
FrontendOptions::default(),
Expand Down Expand Up @@ -1618,6 +1726,59 @@ mod tests {
insert_task.await.unwrap().unwrap();
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_kill_query_cancels_insert_select() {
assert_kill_cancels_insert_select("KILL QUERY 4242").await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_kill_process_id_cancels_insert_select() {
assert_kill_cancels_insert_select("KILL 'test-frontend/4242'").await;
}

async fn assert_kill_cancels_insert_select(kill_sql: &str) {
let insert_sql = "INSERT INTO target SELECT * FROM source";
let (source_polled_tx, source_polled_rx) = oneshot::channel();
let instance = Arc::new(
test_instance_with_tables(
pending_table(1024, "source", source_polled_tx),
test_table(1025, "target"),
)
.await,
);

let insert_task = tokio::spawn({
let instance = instance.clone();
async move { execute_one_sql(&instance, insert_sql, test_query_ctx(4242)).await }
});

tokio::time::timeout(Duration::from_secs(5), source_polled_rx)
.await
.unwrap()
.unwrap();
Comment thread
killme2008 marked this conversation as resolved.
Outdated

let output = execute_one_sql(&instance, kill_sql, test_query_ctx(43))
.await
.unwrap();
assert!(matches!(output.data, OutputData::AffectedRows(1)));

let err = tokio::time::timeout(Duration::from_secs(5), insert_task)
.await
.unwrap()
.unwrap()
.unwrap_err();
Comment thread
killme2008 marked this conversation as resolved.
Outdated
assert_eq!(StatusCode::Cancelled, err.status_code());

let output = execute_one_sql(&instance, "SHOW PROCESSLIST", test_query_ctx(43))
.await
.unwrap();
let process_list = output.data.pretty_print().await;
assert!(
!process_list.contains(insert_sql),
"process list still contains killed insert:\n{process_list}"
);
}

fn insert_dml_plan() -> LogicalPlan {
let schema = SchemaRef::new(Schema::new(vec![Field::new(
"value",
Expand Down
Loading