diff --git a/src/query/src/optimizer/type_conversion.rs b/src/query/src/optimizer/type_conversion.rs index 21f95e51566f..19b8eccf5620 100644 --- a/src/query/src/optimizer/type_conversion.rs +++ b/src/query/src/optimizer/type_conversion.rs @@ -309,9 +309,12 @@ mod tests { use std::collections::HashMap; use std::sync::Arc; - use datafusion_common::arrow::datatypes::Field; + use datafusion::execution::SessionStateBuilder; + use datafusion_common::arrow::datatypes::{Field, TimeUnit as ArrowTimeUnit}; use datafusion_common::{Column, DFSchema}; + use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{Literal, LogicalPlanBuilder}; + use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_sql::TableReference; use session::context::QueryContext; @@ -384,6 +387,102 @@ mod tests { ); } + /// TODO(discord9): update this once datafusion update and fixes lossy downcast problem + #[test] + fn test_datafusion_simplifier_unwraps_timestamp_precision_cast_comparisons() { + let schema = Arc::new( + DFSchema::from_unqualified_fields( + vec![Arc::new(Field::new( + "ts", + DataType::Timestamp(ArrowTimeUnit::Nanosecond, None), + false, + ))] + .into(), + HashMap::new(), + ) + .unwrap(), + ); + let simplifier = ExprSimplifier::new(SimplifyContext::default().with_schema(schema)); + + let ts = Expr::Column(Column::from_name("ts")); + let cast_ts = Expr::Cast(datafusion_expr::Cast { + expr: Box::new(ts.clone()), + data_type: DataType::Timestamp(ArrowTimeUnit::Millisecond, None), + }); + let ms_lit = ScalarValue::TimestampMillisecond(Some(1000), None).lit(); + let ns_lit = ScalarValue::TimestampNanosecond(Some(1_000_000_000), None).lit(); + + let simplify = |expr| simplifier.simplify(expr).unwrap(); + + assert_eq!( + simplify(cast_ts.clone().eq(ms_lit.clone())), + ts.clone().eq(ns_lit.clone()), + ); + assert_eq!( + simplify(cast_ts.clone().not_eq(ms_lit.clone())), + ts.clone().not_eq(ns_lit.clone()), + ); + assert_eq!( + simplify(cast_ts.clone().lt(ms_lit.clone())), + ts.clone().lt(ns_lit.clone()), + ); + assert_eq!( + simplify(cast_ts.clone().lt_eq(ms_lit.clone())), + ts.clone().lt_eq(ns_lit.clone()), + ); + assert_eq!( + simplify(cast_ts.clone().gt(ms_lit.clone())), + ts.clone().gt(ns_lit.clone()), + ); + assert_eq!( + simplify(cast_ts.clone().gt_eq(ms_lit.clone())), + ts.clone().gt_eq(ns_lit.clone()), + ); + assert_eq!(simplify(ms_lit.lt(cast_ts)), ts.gt(ns_lit),); + } + + #[test] + fn test_datafusion_optimizer_pushes_filter_through_timestamp_cast_projection() { + let cast_ts = Expr::Cast(datafusion_expr::Cast { + expr: Box::new(Expr::Column(Column::from_name("column1"))), + data_type: DataType::Timestamp(ArrowTimeUnit::Millisecond, None), + }); + let plan = LogicalPlanBuilder::values(vec![vec![ + ScalarValue::TimestampNanosecond(Some(1_000_000_123), None).lit(), + 1_i64.lit(), + ]]) + .unwrap() + .project(vec![ + cast_ts.alias("ts_ms"), + Expr::Column(Column::from_name("column2")).alias("val"), + ]) + .unwrap() + .filter( + Expr::Column(Column::from_name("ts_ms")).eq(ScalarValue::TimestampMillisecond( + Some(1000), + None, + ) + .lit()), + ) + .unwrap() + .build() + .unwrap(); + + let session_state = SessionStateBuilder::new().with_default_features().build(); + let optimized_plan = session_state.optimize(&plan).unwrap(); + let optimized = optimized_plan.display_indent().to_string(); + + assert!(optimized.contains("Projection:"), "{optimized}"); + assert!( + optimized.contains("Filter: column1 = TimestampNanosecond(1000000000, None)"), + "{optimized}" + ); + assert!( + optimized.find("Projection:") < optimized.find("Filter:"), + "{optimized}" + ); + } + #[test] fn test_convert_timestamp_str() { use datatypes::arrow::datatypes::TimeUnit as ArrowTimeUnit; diff --git a/src/query/src/tests/time_range_filter_test.rs b/src/query/src/tests/time_range_filter_test.rs index 0a3df116cb1d..0cfc112a56cc 100644 --- a/src/query/src/tests/time_range_filter_test.rs +++ b/src/query/src/tests/time_range_filter_test.rs @@ -22,7 +22,10 @@ use common_recordbatch::{RecordBatch, SendableRecordBatchStream}; use common_time::Timestamp; use common_time::range::TimestampRange; use common_time::timestamp::TimeUnit; +use datafusion_common::ScalarValue; use datafusion_expr::expr::Expr; +use datafusion_expr::{col, lit}; +use datatypes::arrow::datatypes::{DataType, TimeUnit as ArrowTimeUnit}; use datatypes::data_type::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{Int64Vector, TimestampMillisecondVector}; @@ -135,6 +138,77 @@ impl TimeRangeTester { } } +fn cast_to_ms_col(name: &str) -> Expr { + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(col(name)), + data_type: DataType::Timestamp(ArrowTimeUnit::Millisecond, None), + }) +} + +fn ms_lit(value: i64) -> Expr { + lit(ScalarValue::TimestampMillisecond(Some(value), None)) +} + +#[test] +fn test_casted_time_index_precision_boundaries() { + let cast_ts = cast_to_ms_col("ts"); + + let us_bucket = TimestampRange::with_unit(1_000_000, 1_001_000, TimeUnit::Microsecond).unwrap(); + assert_eq!( + us_bucket, + build_time_range_predicate( + "ts", + TimeUnit::Microsecond, + &[cast_ts.clone().eq(ms_lit(1000))], + ) + ); + assert!(us_bucket.contains(&Timestamp::new(1_000_999, TimeUnit::Microsecond))); + assert!(!us_bucket.contains(&Timestamp::new(1_001_000, TimeUnit::Microsecond))); + + assert_eq!( + TimestampRange::until_end(Timestamp::new(1_001_000, TimeUnit::Microsecond), false), + build_time_range_predicate( + "ts", + TimeUnit::Microsecond, + &[cast_ts.clone().lt_eq(ms_lit(1000))], + ) + ); + + let ns_bucket = + TimestampRange::with_unit(1_000_000_000, 1_001_000_000, TimeUnit::Nanosecond).unwrap(); + assert_eq!( + ns_bucket, + build_time_range_predicate( + "ts", + TimeUnit::Nanosecond, + &[cast_ts.clone().eq(ms_lit(1000))], + ) + ); + assert!(ns_bucket.contains(&Timestamp::new(1_000_999_999, TimeUnit::Nanosecond))); + assert!(!ns_bucket.contains(&Timestamp::new(1_001_000_000, TimeUnit::Nanosecond))); + + assert_eq!( + TimestampRange::from_start(Timestamp::new(1_000_000_000, TimeUnit::Nanosecond)), + build_time_range_predicate( + "ts", + TimeUnit::Nanosecond, + &[cast_ts.clone().gt_eq(ms_lit(1000))], + ) + ); + assert_eq!( + TimestampRange::from_start(Timestamp::new(1_001_000_000, TimeUnit::Nanosecond)), + build_time_range_predicate( + "ts", + TimeUnit::Nanosecond, + &[cast_ts.clone().gt(ms_lit(1000))], + ) + ); + assert_eq!( + TimestampRange::until_end(Timestamp::new(1_000_000_000, TimeUnit::Nanosecond), false), + build_time_range_predicate("ts", TimeUnit::Nanosecond, &[cast_ts.lt(ms_lit(1000))],) + ); +} + #[tokio::test] async fn test_range_filter() { let tester = create_test_engine(); diff --git a/src/table/src/predicate.rs b/src/table/src/predicate.rs index 2c9ac41560be..962b0b70f49c 100644 --- a/src/table/src/predicate.rs +++ b/src/table/src/predicate.rs @@ -234,6 +234,10 @@ fn extract_from_binary_expr( op: &Operator, right: &Expr, ) -> Option { + if let Some(range) = get_casted_timestamp_filter(ts_col_name, ts_col_unit, left, op, right) { + return Some(range); + } + match op { Operator::Eq => get_timestamp_filter(ts_col_name, left, right) .and_then(|(ts, _)| ts.convert_to(ts_col_unit)) @@ -310,7 +314,105 @@ fn extract_from_binary_expr( } } +fn get_casted_timestamp_filter( + ts_col_name: &str, + ts_col_unit: TimeUnit, + left: &Expr, + op: &Operator, + right: &Expr, +) -> Option { + let (lit, op) = match (left, right) { + (expr, Expr::Literal(scalar, _)) if is_casted_time_index(expr, ts_col_name) => { + (scalar, *op) + } + (Expr::Literal(scalar, _), expr) if is_casted_time_index(expr, ts_col_name) => { + (scalar, reverse_operator(op)?) + } + _ => return None, + }; + + return_none_if_utf8!(lit); + let ScalarValue::TimestampMillisecond(Some(lit_ms), None) = lit else { + return None; + }; + + // Avoid epoch-boundary cases until negative timestamp cast semantics are tested. + // With truncation-toward-zero, predicates such as `CAST(ts AS ms) >= 0` may include + // negative sub-millisecond raw timestamps, so `ts >= 0` would be too narrow. + if *lit_ms <= 0 { + return None; + } + + let lit_ts = Timestamp::new_millisecond(*lit_ms); + let next_ms = lit_ms.checked_add(1).map(Timestamp::new_millisecond)?; + + match &op { + Operator::Eq => TimestampRange::new( + lit_ts.convert_to(ts_col_unit)?, + next_ms.convert_to(ts_col_unit)?, + ), + Operator::Lt => lit_ts + .convert_to(ts_col_unit) + .map(|ts| TimestampRange::until_end(ts, false)), + Operator::LtEq => next_ms + .convert_to(ts_col_unit) + .map(|ts| TimestampRange::until_end(ts, false)), + Operator::Gt => next_ms + .convert_to(ts_col_unit) + .map(TimestampRange::from_start), + Operator::GtEq => lit_ts + .convert_to(ts_col_unit) + .map(TimestampRange::from_start), + _ => None, + } +} + +fn is_casted_time_index(expr: &Expr, ts_col_name: &str) -> bool { + let Expr::Cast(cast) = expr else { + return false; + }; + + if !matches!( + &cast.data_type, + datatypes::arrow::datatypes::DataType::Timestamp( + datatypes::arrow::datatypes::TimeUnit::Millisecond, + None + ) + ) { + return false; + } + + let Expr::Column(col) = cast.expr.as_ref() else { + return false; + }; + + col.name == ts_col_name +} + +fn reverse_operator(op: &Operator) -> Option { + match op { + Operator::Eq => Some(Operator::Eq), + Operator::Lt => Some(Operator::Gt), + Operator::LtEq => Some(Operator::GtEq), + Operator::Gt => Some(Operator::Lt), + Operator::GtEq => Some(Operator::LtEq), + _ => None, + } +} + fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> { + // Design note for extracting `CAST(time_index AS Timestamp(ms)) literal` (#7913): + // this helper currently accepts only raw `time_index literal` filters. If casted + // time-index filters are added, derive a pruning range from the cast bucket, not by + // blindly replacing `CAST(ts)` with `ts`. For positive millisecond literals and finer + // raw units whose cast-to-ms semantics drop the sub-ms remainder, the safe mappings are: + // `CAST(ts)=L` => `[L, L+1ms)`, `< L` => `(-inf, L)`, `<= L` => `(-inf, L+1ms)`, + // `> L` => `[L+1ms, +inf)`, and `>= L` => `[L, +inf)`, converted to the raw unit. + // Literal-left comparisons can be normalized by reversing the operator. Keep the original + // filter applied when the extracted range is only a pruning approximation, and return + // `None` instead of risking a narrower-than-true range for unsupported operators (`!=`), + // timezone/overflow ambiguity, or negative literals until Arrow/DataFusion cast semantics + // for negative epochs are covered by tests. let (col, lit, reverse) = match (left, right) { (Expr::Column(column), Expr::Literal(scalar, _)) => (column, scalar, false), (Expr::Literal(scalar, _), Expr::Column(column)) => (column, scalar, true), @@ -421,6 +523,24 @@ mod tests { ); } + fn check_build_predicate_with_unit(expr: Expr, unit: TimeUnit, expect: TimestampRange) { + assert_eq!(expect, build_time_range_predicate("ts", unit, &[expr])); + } + + fn cast_to_ms_col(name: &str) -> Expr { + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(col(name)), + data_type: DataType::Timestamp( + datatypes::arrow::datatypes::TimeUnit::Millisecond, + None, + ), + }) + } + + fn ms_lit(value: i64) -> Expr { + lit(ScalarValue::TimestampMillisecond(Some(value), None)) + } + #[test] fn test_gt() { // ts > 1ms @@ -577,6 +697,108 @@ mod tests { ); } + #[test] + fn test_casted_time_index_filter() { + let ts = cast_to_ms_col("ts"); + let unit = TimeUnit::Microsecond; + + check_build_predicate_with_unit( + ts.clone().eq(ms_lit(1000)), + unit, + TimestampRange::new( + Timestamp::new(1_000_000, unit), + Timestamp::new(1_001_000, unit), + ) + .unwrap(), + ); + check_build_predicate_with_unit( + ts.clone().lt(ms_lit(1000)), + unit, + TimestampRange::until_end(Timestamp::new(1_000_000, unit), false), + ); + check_build_predicate_with_unit( + ts.clone().lt_eq(ms_lit(1000)), + unit, + TimestampRange::until_end(Timestamp::new(1_001_000, unit), false), + ); + check_build_predicate_with_unit( + ts.clone().gt(ms_lit(1000)), + unit, + TimestampRange::from_start(Timestamp::new(1_001_000, unit)), + ); + check_build_predicate_with_unit( + ts.gt_eq(ms_lit(1000)), + unit, + TimestampRange::from_start(Timestamp::new(1_000_000, unit)), + ); + } + + #[test] + fn test_casted_time_index_filter_literal_left() { + let ts = cast_to_ms_col("ts"); + let unit = TimeUnit::Nanosecond; + + check_build_predicate_with_unit( + ms_lit(1000).lt(ts.clone()), + unit, + TimestampRange::from_start(Timestamp::new(1_001_000_000, unit)), + ); + check_build_predicate_with_unit( + ms_lit(1000).lt_eq(ts.clone()), + unit, + TimestampRange::from_start(Timestamp::new(1_000_000_000, unit)), + ); + check_build_predicate_with_unit( + ms_lit(1000).gt(ts.clone()), + unit, + TimestampRange::until_end(Timestamp::new(1_000_000_000, unit), false), + ); + check_build_predicate_with_unit( + ms_lit(1000).gt_eq(ts.clone()), + unit, + TimestampRange::until_end(Timestamp::new(1_001_000_000, unit), false), + ); + check_build_predicate_with_unit( + ms_lit(1000).eq(ts), + unit, + TimestampRange::new( + Timestamp::new(1_000_000_000, unit), + Timestamp::new(1_001_000_000, unit), + ) + .unwrap(), + ); + } + + #[test] + fn test_casted_time_index_filter_unsupported_cases() { + check_build_predicate_with_unit( + cast_to_ms_col("other").eq(ms_lit(1000)), + TimeUnit::Microsecond, + TimestampRange::min_to_max(), + ); + + let cast_to_second = Expr::Cast(datafusion_expr::Cast { + expr: Box::new(col("ts")), + data_type: DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Second, None), + }); + check_build_predicate_with_unit( + cast_to_second.eq(ms_lit(1000)), + TimeUnit::Microsecond, + TimestampRange::min_to_max(), + ); + + check_build_predicate_with_unit( + cast_to_ms_col("ts").gt_eq(ms_lit(0)), + TimeUnit::Microsecond, + TimestampRange::min_to_max(), + ); + check_build_predicate_with_unit( + cast_to_ms_col("ts").eq(lit(ScalarValue::TimestampMicrosecond(Some(1000), None))), + TimeUnit::Microsecond, + TimestampRange::min_to_max(), + ); + } + async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc) { let path = dir .path() diff --git a/tests/cases/standalone/optimizer/cast_time_index_filter_pushdown.result b/tests/cases/standalone/optimizer/cast_time_index_filter_pushdown.result new file mode 100644 index 000000000000..11c7be683bb3 --- /dev/null +++ b/tests/cases/standalone/optimizer/cast_time_index_filter_pushdown.result @@ -0,0 +1,76 @@ +-- Corresponding to issue #7913. +-- Verify a filter over a projected millisecond cast of a non-ms time index +-- is passed down to scan as a casted time-index predicate for pruning. +CREATE TABLE cast_time_index_filter_pushdown ( + ts TIMESTAMP_NS NOT NULL TIME INDEX, + val BIGINT, +) ENGINE = mito +WITH + (append_mode = 'true', sst_format = 'flat'); + +Affected Rows: 0 + +INSERT INTO cast_time_index_filter_pushdown VALUES + ('2023-06-12 01:04:49.999999999'::TIMESTAMP_NS, 1), + ('2023-06-12 01:04:50.000000123'::TIMESTAMP_NS, 2), + ('2023-06-12 01:04:50.999999999'::TIMESTAMP_NS, 3), + ('2023-06-12 01:04:51.000000000'::TIMESTAMP_NS, 4); + +Affected Rows: 4 + +ADMIN FLUSH_TABLE ('cast_time_index_filter_pushdown'); + ++------------------------------------------------------+ +| ADMIN FLUSH_TABLE('cast_time_index_filter_pushdown') | ++------------------------------------------------------+ +| 0 | ++------------------------------------------------------+ + +-- SQLNESS REPLACE (-+) - +-- SQLNESS REPLACE (\s\s+) _ +-- SQLNESS REPLACE (peers.*) REDACTED +-- SQLNESS REPLACE (metrics.*) REDACTED +-- SQLNESS REPLACE region=\d+\(\d+,\s+\d+\) region=REDACTED +-- SQLNESS REPLACE num_ranges=\d+ num_ranges=REDACTED +-- SQLNESS REPLACE (RepartitionExec:.*) RepartitionExec: REDACTED +-- SQLNESS REPLACE "flat_format":\s\w+, "flat_format": REDACTED, +-- SQLNESS REPLACE (files.*) REDACTED +EXPLAIN ANALYZE VERBOSE +SELECT ts_ms, val +FROM ( + SELECT ts::TIMESTAMP_MS AS ts_ms, val + FROM cast_time_index_filter_pushdown +) projected +WHERE ts_ms = '2023-06-12 01:04:50'::TIMESTAMP_MS +ORDER BY val; + ++-+-+-+ +| stage | node | plan_| ++-+-+-+ +| 0_| 0_|_SortPreservingMergeExec: [val@1 ASC NULLS LAST] REDACTED +|_|_|_SortExec: expr=[val@1 ASC NULLS LAST], preserve_partitioning=[true] REDACTED +|_|_|_ProjectionExec: expr=[CAST(ts@0 AS Timestamp(ms)) as ts_ms, val@1 as val] REDACTED +|_|_|_FilterExec: ts@0 = 1686531890000000000 REDACTED +|_|_|_MergeScanExec: REDACTED +|_|_|_| +| 1_| 0_|_CooperativeExec REDACTED +|_|_|_UnorderedScan: region=REDACTED, {"partition_count":{"count":1, "mem_ranges":0, "REDACTED +|_|_|_| +|_|_| Total rows: 0_| ++-+-+-+ + +SELECT ts_ms, val +FROM ( + SELECT ts::TIMESTAMP_MS AS ts_ms, val + FROM cast_time_index_filter_pushdown +) projected +WHERE ts_ms = '2023-06-12 01:04:50'::TIMESTAMP_MS +ORDER BY val; + +++ +++ + +DROP TABLE cast_time_index_filter_pushdown; + +Affected Rows: 0 + diff --git a/tests/cases/standalone/optimizer/cast_time_index_filter_pushdown.sql b/tests/cases/standalone/optimizer/cast_time_index_filter_pushdown.sql new file mode 100644 index 000000000000..7408668501c5 --- /dev/null +++ b/tests/cases/standalone/optimizer/cast_time_index_filter_pushdown.sql @@ -0,0 +1,46 @@ +-- Corresponding to issue #7913. +-- Verify a filter over a projected millisecond cast of a non-ms time index +-- is passed down to scan as a casted time-index predicate for pruning. + +CREATE TABLE cast_time_index_filter_pushdown ( + ts TIMESTAMP_NS NOT NULL TIME INDEX, + val BIGINT, +) ENGINE = mito +WITH + (append_mode = 'true', sst_format = 'flat'); + +INSERT INTO cast_time_index_filter_pushdown VALUES + ('2023-06-12 01:04:49.999999999'::TIMESTAMP_NS, 1), + ('2023-06-12 01:04:50.000000123'::TIMESTAMP_NS, 2), + ('2023-06-12 01:04:50.999999999'::TIMESTAMP_NS, 3), + ('2023-06-12 01:04:51.000000000'::TIMESTAMP_NS, 4); + +ADMIN FLUSH_TABLE ('cast_time_index_filter_pushdown'); + +-- SQLNESS REPLACE (-+) - +-- SQLNESS REPLACE (\s\s+) _ +-- SQLNESS REPLACE (peers.*) REDACTED +-- SQLNESS REPLACE (metrics.*) REDACTED +-- SQLNESS REPLACE region=\d+\(\d+,\s+\d+\) region=REDACTED +-- SQLNESS REPLACE num_ranges=\d+ num_ranges=REDACTED +-- SQLNESS REPLACE (RepartitionExec:.*) RepartitionExec: REDACTED +-- SQLNESS REPLACE "flat_format":\s\w+, "flat_format": REDACTED, +-- SQLNESS REPLACE (files.*) REDACTED +EXPLAIN ANALYZE VERBOSE +SELECT ts_ms, val +FROM ( + SELECT ts::TIMESTAMP_MS AS ts_ms, val + FROM cast_time_index_filter_pushdown +) projected +WHERE ts_ms = '2023-06-12 01:04:50'::TIMESTAMP_MS +ORDER BY val; + +SELECT ts_ms, val +FROM ( + SELECT ts::TIMESTAMP_MS AS ts_ms, val + FROM cast_time_index_filter_pushdown +) projected +WHERE ts_ms = '2023-06-12 01:04:50'::TIMESTAMP_MS +ORDER BY val; + +DROP TABLE cast_time_index_filter_pushdown;