diff --git a/src/abstract_tree.rs b/src/abstract_tree.rs index 55a0c88dd..7a24b7754 100644 --- a/src/abstract_tree.rs +++ b/src/abstract_tree.rs @@ -122,6 +122,7 @@ pub trait AbstractTree: sealed::Sealed { .iter() .map(|mt| mt.iter().map(Ok)) .collect::>(), + self.tree_config().comparator.clone(), ); // RT suppression is not needed here: flush writes both entries and RTs // to the output tables. Suppression happens at read time, not write time. diff --git a/src/blob_tree/ingest.rs b/src/blob_tree/ingest.rs index 92795c6bf..e0d4ff079 100644 --- a/src/blob_tree/ingest.rs +++ b/src/blob_tree/ingest.rs @@ -230,6 +230,7 @@ impl<'a> BlobIngestion<'a> { index.config.descriptor_table.clone(), false, false, + index.config.comparator.clone(), #[cfg(feature = "metrics")] index.metrics.clone(), ) diff --git a/src/blob_tree/mod.rs b/src/blob_tree/mod.rs index 681d58447..5e4d05d5f 100644 --- a/src/blob_tree/mod.rs +++ b/src/blob_tree/mod.rs @@ -171,7 +171,12 @@ impl BlobTree { key: &[u8], seqno: SeqNo, ) -> crate::Result> { - let Some(item) = crate::Tree::get_internal_entry_from_version(super_version, key, seqno)? + let Some(item) = crate::Tree::get_internal_entry_from_version( + super_version, + key, + seqno, + self.index.config.comparator.as_ref(), + )? else { return Ok(None); }; @@ -256,6 +261,7 @@ impl AbstractTree for BlobTree { seqno, index, None, // BlobTree does not use merge operators for prefix scans + self.index.config.comparator.clone(), prefix_hash, ) .map(move |kv| { @@ -278,14 +284,21 @@ impl AbstractTree for BlobTree { let tree = self.clone(); Box::new( - crate::Tree::create_internal_range(super_version.clone(), &range, seqno, index, None) - .map(move |kv| { - IterGuardImpl::Blob(Guard { - tree: tree.clone(), - version: super_version.version.clone(), - kv, - }) - }), + crate::Tree::create_internal_range( + super_version.clone(), + &range, + seqno, + index, + None, + self.index.config.comparator.clone(), + ) + .map(move |kv| { + IterGuardImpl::Blob(Guard { + tree: tree.clone(), + version: super_version.version.clone(), + kv, + }) + }), ) } @@ -519,6 +532,7 @@ impl AbstractTree for BlobTree { self.index.config.descriptor_table.clone(), pin_filter, pin_index, + self.index.config.comparator.clone(), #[cfg(feature = "metrics")] self.index.metrics.clone(), ) diff --git a/src/compaction/flavour.rs b/src/compaction/flavour.rs index 2bd10137c..1f5e9186e 100644 --- a/src/compaction/flavour.rs +++ b/src/compaction/flavour.rs @@ -376,6 +376,7 @@ impl StandardCompaction { opts.config.descriptor_table.clone(), pin_filter, pin_index, + opts.config.comparator.clone(), #[cfg(feature = "metrics")] opts.metrics.clone(), ) diff --git a/src/compaction/worker.rs b/src/compaction/worker.rs index 7ab9a6037..2804d326b 100644 --- a/src/compaction/worker.rs +++ b/src/compaction/worker.rs @@ -152,6 +152,7 @@ fn create_compaction_stream<'a>( to_compact: &[TableId], eviction_seqno: SeqNo, merge_operator: Option>, + comparator: crate::comparator::SharedComparator, ) -> crate::Result>>>> { let mut readers: Vec> = vec![]; let mut found = 0; @@ -178,7 +179,7 @@ fn create_compaction_stream<'a>( Ok(if found == to_compact.len() { Some( - CompactionStream::new(Merger::new(readers), eviction_seqno) + CompactionStream::new(Merger::new(readers, comparator), eviction_seqno) .with_merge_operator(merge_operator), ) } else { @@ -390,6 +391,7 @@ fn merge_tables( &payload.table_ids.iter().copied().collect::>(), opts.mvcc_gc_watermark, opts.config.merge_operator.clone(), + opts.config.comparator.clone(), )? 
else { log::warn!( @@ -704,7 +706,14 @@ mod tests { tree.insert("a", "a", 0); tree.flush_active_memtable(0)?; - assert!(create_compaction_stream(&tree.current_version(), &[666], 0, None)?.is_none()); + assert!(create_compaction_stream( + &tree.current_version(), + &[666], + 0, + None, + crate::comparator::default_comparator() + )? + .is_none()); Ok(()) } diff --git a/src/comparator.rs b/src/comparator.rs new file mode 100644 index 000000000..b4b5d023d --- /dev/null +++ b/src/comparator.rs @@ -0,0 +1,104 @@ +// Copyright (c) 2024-present, fjall-rs +// This source code is licensed under both the Apache 2.0 and MIT License +// (found in the LICENSE-* files in the repository) + +use std::sync::Arc; + +/// Trait for custom user key comparison. +/// +/// Comparators must be safe across unwind boundaries since they are stored +/// in tree structures that may be referenced inside `catch_unwind` blocks. +/// +/// Implementations must define a **strict total order** suitable for use in +/// sorted data structures (memtable skip list, SST block index, merge heap). +/// Specifically: +/// +/// - **Totality**: for all `a`, `b`, exactly one of `Less`, `Equal`, `Greater` holds +/// - **Transitivity**: `a < b` and `b < c` implies `a < c` +/// - **Antisymmetry**: `compare(a, b) == Less` iff `compare(b, a) == Greater` +/// - **Reflexivity**: `compare(a, a) == Equal` +/// +/// - **Bytewise equality**: `compare(a, b) == Equal` **must** imply `a == b` +/// byte-for-byte. Bloom filters and hash indexes operate on raw bytes; +/// if two byte-different keys compare as equal, hash-based lookups will +/// produce false negatives. +/// +/// Violating these invariants corrupts the sort order and produces incorrect +/// query results. +/// +/// # Important +/// +/// Once a tree is created with a comparator, it must always be opened with the +/// same comparator. Using a different comparator on an existing tree will produce +/// incorrect results. +/// +/// # Examples +/// +/// ``` +/// use lsm_tree::UserComparator; +/// use std::cmp::Ordering; +/// +/// /// Comparator that orders u64 keys stored as big-endian bytes. +/// struct U64Comparator; +/// +/// impl UserComparator for U64Comparator { +/// fn compare(&self, a: &[u8], b: &[u8]) -> Ordering { +/// if a.len() == 8 && b.len() == 8 { +/// // Length checked, conversion cannot fail. +/// let a_u64 = u64::from_be_bytes(a.try_into().unwrap()); +/// let b_u64 = u64::from_be_bytes(b.try_into().unwrap()); +/// a_u64.cmp(&b_u64) +/// } else { +/// // Non-8-byte keys: fall back to lexicographic ordering +/// // to preserve the bytewise-equality invariant. +/// a.cmp(b) +/// } +/// } +/// } +/// ``` +pub trait UserComparator: Send + Sync + std::panic::RefUnwindSafe + 'static { + /// Compares two user keys, returning their ordering. + fn compare(&self, a: &[u8], b: &[u8]) -> std::cmp::Ordering; + + /// Returns `true` if this comparator is lexicographic byte ordering. + /// + /// When `true`, internal optimizations can avoid allocations in + /// prefix-compressed block comparisons. Override only if your + /// comparator is truly equivalent to `a.cmp(b)` on raw bytes. + fn is_lexicographic(&self) -> bool { + false + } +} + +/// Default comparator using lexicographic byte ordering. +/// +/// This is the comparator used when no custom comparator is configured, +/// preserving backward compatibility with existing trees. 
+#[derive(Clone, Debug)] +pub struct DefaultUserComparator; + +impl UserComparator for DefaultUserComparator { + #[inline] + fn compare(&self, a: &[u8], b: &[u8]) -> std::cmp::Ordering { + a.cmp(b) + } + + #[inline] + fn is_lexicographic(&self) -> bool { + true + } +} + +/// Shared reference to a [`UserComparator`]. +pub type SharedComparator = Arc; + +/// Returns the default comparator (lexicographic byte ordering). +/// +/// Uses a shared static instance to avoid repeated allocations. +#[must_use] +pub fn default_comparator() -> SharedComparator { + // LazyLock creates the Arc once; subsequent calls just clone the Arc (ref-count bump). + static DEFAULT: std::sync::LazyLock = + std::sync::LazyLock::new(|| Arc::new(DefaultUserComparator)); + DEFAULT.clone() +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 260eade11..6a9b565aa 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -20,9 +20,14 @@ pub use restart_interval::RestartIntervalPolicy; pub type PartitioningPolicy = PinningPolicy; use crate::{ - compaction::filter::Factory, merge_operator::MergeOperator, path::absolute_path, - prefix::PrefixExtractor, version::DEFAULT_LEVEL_COUNT, AnyTree, BlobTree, Cache, - CompressionType, DescriptorTable, SequenceNumberCounter, SharedSequenceNumberGenerator, Tree, + compaction::filter::Factory, + comparator::{self, SharedComparator}, + merge_operator::MergeOperator, + path::absolute_path, + prefix::PrefixExtractor, + version::DEFAULT_LEVEL_COUNT, + AnyTree, BlobTree, Cache, CompressionType, DescriptorTable, SequenceNumberCounter, + SharedSequenceNumberGenerator, Tree, }; use std::{ path::{Path, PathBuf}, @@ -246,6 +251,15 @@ pub struct Config { #[doc(hidden)] pub kv_separation_opts: Option, + /// Custom user key comparator. + /// + /// When set, all key comparisons use this comparator instead of the + /// default lexicographic byte ordering. Once a tree is opened with a + /// comparator, it must always be re-opened with the same comparator. + // Not `pub` — use `Config::comparator()` builder method as the public API. + #[doc(hidden)] + pub(crate) comparator: SharedComparator, + /// The global sequence number generator /// /// Should be shared between multiple trees of a database @@ -311,6 +325,8 @@ impl Default for Config { expect_point_read_hits: false, kv_separation_opts: None, + + comparator: comparator::default_comparator(), } } } @@ -522,6 +538,27 @@ impl Config { self } + /// Sets a custom user key comparator. + /// + /// When configured, all key ordering (memtable, block index, merge, + /// range scans) uses this comparator instead of the default lexicographic + /// byte ordering. + /// + /// # Important + /// + /// Once a tree is created with a custom comparator, it **must** be + /// re-opened with the same comparator. Using a different comparator + /// on an existing tree produces incorrect results. + /// + /// The comparator identity is **not** persisted to disk — the caller + /// is responsible for ensuring the same comparator is used across + /// open/close cycles (same approach as `RocksDB`). + #[must_use] + pub fn comparator(mut self, comparator: SharedComparator) -> Self { + self.comparator = comparator; + self + } + /// Opens a tree using the config. 
/// /// # Errors diff --git a/src/key.rs b/src/key.rs index 42fc626aa..b864135fe 100644 --- a/src/key.rs +++ b/src/key.rs @@ -2,7 +2,7 @@ // This source code is licensed under both the Apache 2.0 and MIT License // (found in the LICENSE-* files in the repository) -use crate::{SeqNo, UserKey, ValueType}; +use crate::{comparator::UserComparator, SeqNo, UserKey, ValueType}; use std::cmp::Reverse; #[derive(Clone, Eq)] @@ -56,6 +56,20 @@ impl InternalKey { pub fn is_tombstone(&self) -> bool { self.value_type.is_tombstone() } + + /// Compares two internal keys using a custom user key comparator. + /// + /// User keys are compared via the given comparator; ties are broken + /// by sequence number in descending order (higher seqno = "smaller" + /// in sort order), matching the invariant of [`Ord for InternalKey`]. + pub(crate) fn compare_with( + &self, + other: &Self, + cmp: &dyn UserComparator, + ) -> std::cmp::Ordering { + cmp.compare(&self.user_key, &other.user_key) + .then_with(|| Reverse(self.seqno).cmp(&Reverse(other.seqno))) + } } impl PartialOrd for InternalKey { diff --git a/src/lib.rs b/src/lib.rs index 7eea1ac1a..119ef569e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,6 +72,8 @@ mod abstract_tree; #[doc(hidden)] pub mod blob_tree; +mod comparator; + #[doc(hidden)] mod cache; @@ -185,6 +187,7 @@ pub use { any_tree::AnyTree, blob_tree::BlobTree, cache::Cache, + comparator::{DefaultUserComparator, SharedComparator, UserComparator}, compression::CompressionType, config::{Config, KvSeparationOptions, TreeType}, descriptor_table::DescriptorTable, diff --git a/src/memtable/mod.rs b/src/memtable/mod.rs index 433ee24e6..871ec9acf 100644 --- a/src/memtable/mod.rs +++ b/src/memtable/mod.rs @@ -4,6 +4,7 @@ pub mod interval_tree; +use crate::comparator::SharedComparator; use crate::key::InternalKey; use crate::range_tombstone::RangeTombstone; use crate::{ @@ -11,12 +12,56 @@ use crate::{ UserKey, ValueType, }; use crossbeam_skiplist::SkipMap; -use std::ops::RangeBounds; +use std::ops::Bound; use std::sync::atomic::{AtomicBool, AtomicU64}; use std::sync::RwLock; pub use crate::tree::inner::MemtableId; +/// Wrapper around [`InternalKey`] that uses a custom [`UserComparator`] for ordering. +/// +/// This wrapper is used as the key type in the memtable's `SkipMap` to support +/// pluggable key comparison. The `SharedComparator` is cloned (Arc bump) per entry. +#[derive(Clone)] +pub struct MemtableKey { + pub(crate) inner: InternalKey, + pub(crate) comparator: SharedComparator, +} + +impl MemtableKey { + pub(crate) fn new(inner: InternalKey, comparator: SharedComparator) -> Self { + Self { inner, comparator } + } +} + +impl PartialEq for MemtableKey { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == std::cmp::Ordering::Equal + } +} + +impl Eq for MemtableKey {} + +impl PartialOrd for MemtableKey { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for MemtableKey { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.inner + .compare_with(&other.inner, self.comparator.as_ref()) + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +impl std::fmt::Debug for MemtableKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.inner.fmt(f) + } +} + /// The memtable serves as an intermediary, ephemeral, sorted storage for new items /// /// When the Memtable exceeds some size, it should be flushed to a table. 
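For illustration, a minimal end-to-end sketch of the new comparator surface. This is a hypothetical usage example, not part of the patch: the path and seqnos are made up, and `Config::new`, `open`, `insert`, and `flush_active_memtable` are the crate's pre-existing API, while `.comparator(...)` is the only call introduced by this diff. A reverse-lexicographic ordering keeps the bytewise-equality invariant documented on `UserComparator`, since it never reports two byte-different keys as equal:

```rust
use lsm_tree::{AbstractTree, Config, UserComparator};
use std::{cmp::Ordering, sync::Arc};

/// Orders keys in reverse lexicographic order. Byte-different keys never
/// compare as Equal, so hash-index and bloom-filter lookups stay correct.
struct ReverseComparator;

impl UserComparator for ReverseComparator {
    fn compare(&self, a: &[u8], b: &[u8]) -> Ordering {
        b.cmp(a) // flipped operands: descending byte order
    }
}

fn main() -> lsm_tree::Result<()> {
    let tree = Config::new("/tmp/reverse-tree")
        .comparator(Arc::new(ReverseComparator))
        .open()?;

    // The same comparator must be passed on every re-open of this tree.
    tree.insert("a", "1", 0);
    tree.insert("b", "2", 1);
    tree.flush_active_memtable(2)?;

    // Iteration now yields "b" before "a": the memtable skip list
    // (via `MemtableKey`), the block index, and the merge heap all
    // order keys through the configured comparator.
    Ok(())
}
```

Because `MemtableKey` delegates to `InternalKey::compare_with`, entries land in the skip list already in comparator order, so flushing needs no re-sort.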
@@ -24,9 +69,12 @@ pub struct Memtable { #[doc(hidden)] pub id: MemtableId, + /// The user key comparator used for ordering entries. + pub(crate) comparator: SharedComparator, + /// The actual content, stored in a lock-free skiplist. #[doc(hidden)] - pub items: SkipMap, + pub(crate) items: SkipMap, /// Range tombstones stored in an interval tree. /// @@ -41,6 +89,11 @@ pub struct Memtable { /// starvation is not a concern here: range deletes are rare, the write-side /// critical section is O(log n) with n typically small, and the memtable /// rotates (becoming read-only) well before contention could accumulate. + // NOTE: The interval tree uses lexicographic `Ord` on `UserKey` for + // containment queries. With a custom comparator, RT suppression in + // the memtable may produce incorrect results for non-lexicographic + // orderings. Threading the comparator into the AVL tree is tracked + // as a follow-up issue. pub(crate) range_tombstones: RwLock, /// Approximate active memtable size. @@ -74,11 +127,16 @@ impl Memtable { .store(true, std::sync::atomic::Ordering::Relaxed); } + // `pub` + `#[doc(hidden)]`: used by the host crate (fjall) to construct + // ephemeral memtables. Not part of the semver-stable API. + // The comparator parameter is mandatory because memtable ordering must + // match the tree's comparator; a default would silently produce wrong order. #[doc(hidden)] #[must_use] - pub fn new(id: MemtableId) -> Self { + pub fn new(id: MemtableId, comparator: SharedComparator) -> Self { Self { id, + comparator, items: SkipMap::default(), range_tombstones: RwLock::new(interval_tree::IntervalTree::new()), approximate_size: AtomicU64::default(), @@ -90,22 +148,33 @@ impl Memtable { /// Creates an iterator over all items. pub fn iter(&self) -> impl DoubleEndedIterator + '_ { self.items.iter().map(|entry| InternalValue { - key: entry.key().clone(), + key: entry.key().inner.clone(), value: entry.value().clone(), }) } /// Creates an iterator over a range of items. - pub(crate) fn range<'a, R: RangeBounds + 'a>( - &'a self, - range: R, - ) -> impl DoubleEndedIterator + 'a { - self.items.range(range).map(|entry| InternalValue { - key: entry.key().clone(), + /// + /// Accepts `InternalKey`-based bounds and wraps them with the memtable's comparator. + pub(crate) fn range_internal( + &self, + range: (Bound, Bound), + ) -> impl DoubleEndedIterator + '_ { + let wrapped = ( + range.0.map(|k| self.wrap_key(k)), + range.1.map(|k| self.wrap_key(k)), + ); + self.items.range(wrapped).map(|entry| InternalValue { + key: entry.key().inner.clone(), value: entry.value().clone(), }) } + /// Wraps an `InternalKey` with this memtable's comparator for `SkipMap` lookups. + pub(crate) fn wrap_key(&self, key: InternalKey) -> MemtableKey { + MemtableKey::new(key, self.comparator.clone()) + } + /// Returns the item by key if it exists. /// /// The item with the highest seqno will be returned, if `seqno` is None. @@ -131,15 +200,16 @@ impl Memtable { // abcdef -> 6 // abcdef -> 5 // - let lower_bound = InternalKey::new(key, seqno - 1, ValueType::Value); + let lower_bound = self.wrap_key(InternalKey::new(key, seqno - 1, ValueType::Value)); - let mut iter = self - .items - .range(lower_bound..) 
- .take_while(|entry| &*entry.key().user_key == key); + let cmp = self.comparator.as_ref(); + + let mut iter = self.items.range(lower_bound..).take_while(|entry| { + cmp.compare(&entry.key().inner.user_key, key) == std::cmp::Ordering::Equal + }); iter.next().map(|entry| InternalValue { - key: entry.key().clone(), + key: entry.key().inner.clone(), value: entry.value().clone(), }) } @@ -168,17 +238,21 @@ impl Memtable { clippy::expect_used, reason = "keys are limited to 16-bit length + values are limited to 32-bit length" )] - let item_size = - (item.key.user_key.len() + item.value.len() + std::mem::size_of::()) - .try_into() - .expect("should fit into u64"); + // Account for MemtableKey overhead (InternalKey + Arc) + let item_size = (item.key.user_key.len() + + item.value.len() + + std::mem::size_of::() + + std::mem::size_of::()) + .try_into() + .expect("should fit into u64"); let size_before = self .approximate_size .fetch_add(item_size, std::sync::atomic::Ordering::AcqRel); let key = InternalKey::new(item.key.user_key, item.key.seqno, item.key.value_type); - self.items.insert(key, item.value); + let memtable_key = MemtableKey::new(key, self.comparator.clone()); + self.items.insert(memtable_key, item.value); self.highest_seqno .fetch_max(item.key.seqno, std::sync::atomic::Ordering::AcqRel); @@ -310,17 +384,22 @@ impl Memtable { #[cfg(test)] mod tests { use super::*; + use crate::comparator::default_comparator; use crate::ValueType; use std::sync::{Arc, Barrier}; use test_log::test; + fn new_memtable(id: MemtableId) -> Memtable { + Memtable::new(id, default_comparator()) + } + #[test] #[expect( clippy::expect_used, reason = "tests use expect for lock and thread join" )] fn rwlock_read_while_read_held_succeeds() { - let mt = Memtable::new(0); + let mt = new_memtable(0); let _ = mt.insert_range_tombstone(b"a".to_vec().into(), b"z".to_vec().into(), 10); // Two one-way channels avoid Barrier entirely — if either side @@ -352,7 +431,7 @@ mod tests { #[test] #[expect(clippy::expect_used, reason = "tests use expect for thread join")] fn suppression_queries_concurrent_readers_no_panic() { - let mt = Arc::new(Memtable::new(0)); + let mt = Arc::new(new_memtable(0)); let _ = mt.insert_range_tombstone(b"a".to_vec().into(), b"z".to_vec().into(), 10); for i in 0u8..100 { @@ -386,7 +465,7 @@ mod tests { #[test] #[expect(clippy::expect_used, reason = "tests use expect for thread join")] fn range_tombstones_concurrent_read_write_writers_observable() { - let mt = Arc::new(Memtable::new(0)); + let mt = Arc::new(new_memtable(0)); // Barrier ensures all 6 threads start simultaneously. 
let start = Arc::new(Barrier::new(6)); @@ -447,7 +526,7 @@ mod tests { #[test] #[expect(clippy::expect_used, reason = "tests use expect for thread join")] fn range_tombstones_populated_tree_concurrent_reads_succeed() { - let mt = Arc::new(Memtable::new(0)); + let mt = Arc::new(new_memtable(0)); for i in 0u8..50 { let start = vec![b'a' + (i % 25)]; @@ -478,7 +557,7 @@ mod tests { #[test] #[expect(clippy::unwrap_used)] fn memtable_mvcc_point_read() { - let memtable = Memtable::new(0); + let memtable = new_memtable(0); memtable.insert(InternalValue::from_components( *b"hello-key-999991", @@ -521,7 +600,7 @@ mod tests { #[test] fn memtable_get() { - let memtable = Memtable::new(0); + let memtable = new_memtable(0); let value = InternalValue::from_components(b"abc".to_vec(), b"abc".to_vec(), 0, ValueType::Value); @@ -533,7 +612,7 @@ mod tests { #[test] fn memtable_get_highest_seqno() { - let memtable = Memtable::new(0); + let memtable = new_memtable(0); memtable.insert(InternalValue::from_components( b"abc".to_vec(), @@ -579,7 +658,7 @@ mod tests { #[test] fn memtable_get_prefix() { - let memtable = Memtable::new(0); + let memtable = new_memtable(0); memtable.insert(InternalValue::from_components( b"abc0".to_vec(), @@ -617,7 +696,7 @@ mod tests { #[test] fn memtable_get_old_version() { - let memtable = Memtable::new(0); + let memtable = new_memtable(0); memtable.insert(InternalValue::from_components( b"abc".to_vec(), diff --git a/src/merge.rs b/src/merge.rs index a3d1a4116..395158f8d 100644 --- a/src/merge.rs +++ b/src/merge.rs @@ -2,6 +2,7 @@ // This source code is licensed under both the Apache 2.0 and MIT License // (found in the LICENSE-* files in the repository) +use crate::comparator::SharedComparator; use crate::InternalValue; use interval_heap::IntervalHeap as Heap; @@ -9,19 +10,21 @@ type IterItem = crate::Result; pub type BoxedIterator<'a> = Box + Send + 'a>; -struct HeapItem(usize, InternalValue); +// Arc clone per heap entry is an atomic ref-count bump. The heap holds at most +// one entry per source iterator (typically <10), so the overhead is negligible. 
+struct HeapItem(usize, InternalValue, SharedComparator); impl Eq for HeapItem {} impl PartialEq for HeapItem { fn eq(&self, other: &Self) -> bool { - self.1.key == other.1.key + self.1.key.compare_with(&other.1.key, self.2.as_ref()) == std::cmp::Ordering::Equal } } impl Ord for HeapItem { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.1.key.cmp(&other.1.key) + self.1.key.compare_with(&other.1.key, self.2.as_ref()) } } @@ -37,11 +40,12 @@ pub struct Merger { heap: Heap, initialized_lo: bool, initialized_hi: bool, + comparator: SharedComparator, } impl> Merger { #[must_use] - pub fn new(iterators: Vec) -> Self { + pub fn new(iterators: Vec, comparator: SharedComparator) -> Self { let heap = Heap::with_capacity(iterators.len()); let iterators = iterators.into_iter().collect::>(); @@ -51,6 +55,7 @@ impl> Merger { heap, initialized_lo: false, initialized_hi: false, + comparator, } } @@ -58,7 +63,7 @@ impl> Merger { for (idx, it) in self.iterators.iter_mut().enumerate() { if let Some(item) = it.next() { let item = item?; - self.heap.push(HeapItem(idx, item)); + self.heap.push(HeapItem(idx, item, self.comparator.clone())); } } self.initialized_lo = true; @@ -71,7 +76,7 @@ impl> Merger { for (idx, it) in self.iterators.iter_mut().enumerate() { if let Some(item) = it.next_back() { let item = item?; - self.heap.push(HeapItem(idx, item)); + self.heap.push(HeapItem(idx, item, self.comparator.clone())); } } self.initialized_hi = true; @@ -92,7 +97,8 @@ impl> Iterator for Merger { #[expect(clippy::indexing_slicing, reason = "we trust the HeapItem index")] if let Some(next_item) = self.iterators[min_item.0].next() { let next_item = fail_iter!(next_item); - self.heap.push(HeapItem(min_item.0, next_item)); + self.heap + .push(HeapItem(min_item.0, next_item, self.comparator.clone())); } Some(Ok(min_item.1)) @@ -110,7 +116,8 @@ impl> DoubleEndedIterator for Merger #[expect(clippy::indexing_slicing, reason = "we trust the HeapItem index")] if let Some(next_item) = self.iterators[max_item.0].next_back() { let next_item = fail_iter!(next_item); - self.heap.push(HeapItem(max_item.0, next_item)); + self.heap + .push(HeapItem(max_item.0, next_item, self.comparator.clone())); } Some(Ok(max_item.1)) @@ -120,6 +127,7 @@ impl> DoubleEndedIterator for Merger #[cfg(test)] mod tests { use super::*; + use crate::comparator; use crate::ValueType::Value; use test_log::test; @@ -135,7 +143,10 @@ mod tests { Ok(InternalValue::from_components("b", b"", 0, Value)), ]; - let mut iter = Merger::new(vec![a.into_iter(), b.into_iter()]); + let mut iter = Merger::new( + vec![a.into_iter(), b.into_iter()], + comparator::default_comparator(), + ); assert_eq!( iter.next().unwrap()?, @@ -163,7 +174,10 @@ mod tests { Ok(InternalValue::from_components("a", b"", 0, Value)), ]; - let mut iter = Merger::new(vec![a.into_iter(), b.into_iter()]); + let mut iter = Merger::new( + vec![a.into_iter(), b.into_iter()], + comparator::default_comparator(), + ); assert_eq!( iter.next().unwrap()?, diff --git a/src/range.rs b/src/range.rs index d77f9b90d..704bb0999 100644 --- a/src/range.rs +++ b/src/range.rs @@ -73,6 +73,9 @@ pub struct IterState { pub(crate) ephemeral: Option<(Arc, SeqNo)>, pub(crate) merge_operator: Option>, + /// User key comparator for merge ordering. + pub(crate) comparator: crate::comparator::SharedComparator, + /// Optional prefix hash for prefix bloom filter skipping. 
/// /// When set, segments whose bloom filter reports no match for this
@@ -414,7 +417,7 @@ impl TreeIter {
                     .map(|rt| (rt, seqno)),
             );
 
-            let iter = memtable.range(range.clone());
+            let iter = memtable.range_internal(range.clone());
 
             iters.push(Box::new(
                 iter.filter(move |item| seqno_filter(item.key.seqno, seqno))
@@ -433,7 +436,7 @@ impl TreeIter {
                     .map(|rt| (rt, seqno)),
             );
 
-            let iter = lock.version.active_memtable.range(range.clone());
+            let iter = lock.version.active_memtable.range_internal(range.clone());
 
             iters.push(Box::new(
                 iter.filter(move |item| seqno_filter(item.key.seqno, seqno))
@@ -450,14 +453,14 @@ impl TreeIter {
             );
 
             let iter = Box::new(
-                mt.range(range)
+                mt.range_internal(range)
                     .filter(move |item| seqno_filter(item.key.seqno, *eph_seqno))
                     .map(Ok),
             );
 
             iters.push(iter);
         }
 
-        let merged = Merger::new(iters);
+        let merged = Merger::new(iters, lock.comparator.clone());
 
         // Clone needed: MvccStream uses the RT set for merge suppression,
         // while RangeTombstoneFilter below consumes it for post-merge
         // filtering. An Arc<[_]> could avoid the copy if RT sets grow large.
diff --git a/src/range_tombstone.rs b/src/range_tombstone.rs
index 4861f86a2..41c2493f0 100644
--- a/src/range_tombstone.rs
+++ b/src/range_tombstone.rs
@@ -26,9 +26,12 @@ impl RangeTombstone {
     ///
-    /// Debug-asserts that `start < end`. Callers must validate untrusted input
-    /// before constructing a `RangeTombstone`.
+    /// Callers must validate untrusted input (ensuring `start < end` under
+    /// the appropriate ordering) before constructing a `RangeTombstone`.
+    // No debug_assert on start < end here: with custom comparators,
+    // lexicographic order may differ from comparator order. Callers
+    // (decode_range_tombstones, insert_range_tombstone) validate using
+    // the appropriate comparator or lexicographic order at their level.
     #[must_use]
     pub fn new(start: UserKey, end: UserKey, seqno: SeqNo) -> Self {
-        debug_assert!(start < end, "range tombstone start must be < end");
         Self { start, end, seqno }
     }
 
diff --git a/src/table/block/decoder.rs b/src/table/block/decoder.rs
index 1b4e3e01d..e16ceaca6 100644
--- a/src/table/block/decoder.rs
+++ b/src/table/block/decoder.rs
@@ -14,10 +14,15 @@ use std::{io::Cursor, marker::PhantomData};
 ///
 /// Parsed items only hold references to their keys and values, use `materialize` to create an owned value.
 pub trait ParsedItem {
-    /// Compares this item's key with a needle.
+    /// Compares this item's key with a needle using the given comparator.
     ///
     /// We can not access the key directly because it may be comprised of prefix + suffix.
-    fn compare_key(&self, needle: &[u8], bytes: &[u8]) -> std::cmp::Ordering;
+    fn compare_key(
+        &self,
+        needle: &[u8],
+        bytes: &[u8],
+        cmp: &dyn crate::comparator::UserComparator,
+    ) -> std::cmp::Ordering;
 
     /// Returns the byte offset of the key's start position.
     fn key_offset(&self) -> usize;
diff --git a/src/table/block_index/full.rs b/src/table/block_index/full.rs
index 992b429cc..a4b35a27c 100644
--- a/src/table/block_index/full.rs
+++ b/src/table/block_index/full.rs
@@ -2,6 +2,7 @@
 // This source code is licensed under both the Apache 2.0 and MIT License
 // (found in the LICENSE-* files in the repository)
 
+use crate::comparator::SharedComparator;
 use crate::table::block_index::{iter::OwnedIndexBlockIter, BlockIndexIter};
 use crate::table::{IndexBlock, KeyedBlockHandle};
 use crate::SeqNo;
@@ -9,15 +10,18 @@ use crate::SeqNo;
 /// Index that translates item keys to data block handles
 ///
 /// The index is fully loaded into memory.
-pub struct FullBlockIndex(IndexBlock); +pub struct FullBlockIndex { + block: IndexBlock, + comparator: SharedComparator, +} impl FullBlockIndex { - pub fn new(block: IndexBlock) -> Self { - Self(block) + pub fn new(block: IndexBlock, comparator: SharedComparator) -> Self { + Self { block, comparator } } pub fn inner(&self) -> &IndexBlock { - &self.0 + &self.block } pub fn forward_reader(&self, needle: &[u8], seqno: SeqNo) -> Option { @@ -30,7 +34,10 @@ impl FullBlockIndex { } pub fn iter(&self) -> Iter { - Iter(OwnedIndexBlockIter::new(self.0.clone(), IndexBlock::iter)) + let cmp = self.comparator.clone(); + Iter(OwnedIndexBlockIter::new(self.block.clone(), |b| { + b.iter(cmp) + })) } } diff --git a/src/table/block_index/two_level.rs b/src/table/block_index/two_level.rs index ff13ecb4a..fceca8d8d 100644 --- a/src/table/block_index/two_level.rs +++ b/src/table/block_index/two_level.rs @@ -2,6 +2,7 @@ // This source code is licensed under both the Apache 2.0 and MIT License // (found in the LICENSE-* files in the repository) +use crate::comparator::SharedComparator; use crate::file_accessor::FileAccessor; use crate::table::{IndexBlock, KeyedBlockHandle}; use crate::SeqNo; @@ -28,6 +29,7 @@ pub struct TwoLevelBlockIndex { pub(crate) file_accessor: FileAccessor, pub(crate) cache: Arc, pub(crate) compression: CompressionType, + pub(crate) comparator: SharedComparator, #[cfg(feature = "metrics")] pub(crate) metrics: Arc, @@ -47,6 +49,7 @@ impl TwoLevelBlockIndex { file_accessor: self.file_accessor.clone(), cache: self.cache.clone(), compression: self.compression, + comparator: self.comparator.clone(), #[cfg(feature = "metrics")] metrics: self.metrics.clone(), @@ -69,6 +72,7 @@ pub struct Iter { file_accessor: FileAccessor, cache: Arc, compression: CompressionType, + comparator: SharedComparator, #[cfg(feature = "metrics")] metrics: Arc, @@ -76,7 +80,8 @@ pub struct Iter { impl Iter { fn init_tli(&mut self) -> bool { - let mut iter = OwnedIndexBlockIter::new(self.tli_block.clone(), IndexBlock::iter); + let cmp = self.comparator.clone(); + let mut iter = OwnedIndexBlockIter::new(self.tli_block.clone(), |b| b.iter(cmp)); if let Some((lo_key, lo_seqno)) = &self.lo { if !iter.seek_lower(lo_key, *lo_seqno) { @@ -138,7 +143,8 @@ impl Iterator for Iter { )); let index_block = IndexBlock::new(block); - let mut iter = OwnedIndexBlockIter::new(index_block, IndexBlock::iter); + let cmp = self.comparator.clone(); + let mut iter = OwnedIndexBlockIter::new(index_block, |b| b.iter(cmp)); if let Some((lo_key, lo_seqno)) = &self.lo { if !iter.seek_lower(lo_key, *lo_seqno) { @@ -201,7 +207,8 @@ impl DoubleEndedIterator for Iter { )); let index_block = IndexBlock::new(block); - let mut iter = OwnedIndexBlockIter::new(index_block, IndexBlock::iter); + let cmp = self.comparator.clone(); + let mut iter = OwnedIndexBlockIter::new(index_block, |b| b.iter(cmp)); if let Some((lo_key, lo_seqno)) = &self.lo { if !iter.seek_lower(lo_key, *lo_seqno) { diff --git a/src/table/block_index/volatile.rs b/src/table/block_index/volatile.rs index db731c9bd..8f5c84c9e 100644 --- a/src/table/block_index/volatile.rs +++ b/src/table/block_index/volatile.rs @@ -4,6 +4,7 @@ use super::KeyedBlockHandle; use crate::{ + comparator::SharedComparator, file_accessor::FileAccessor, table::{ block::BlockType, @@ -28,6 +29,7 @@ pub struct VolatileBlockIndex { pub(crate) cache: Arc, pub(crate) handle: BlockHandle, pub(crate) compression: CompressionType, + pub(crate) comparator: SharedComparator, #[cfg(feature = "metrics")] pub(crate) metrics: Arc, 
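The iterator hunks in `two_level.rs` above and in `volatile.rs` and `data_block/iter.rs` below are all the same mechanical rewrite: ordered predicates expressed on raw bytes, such as `end_key <= needle`, become calls through the comparator. A small self-contained sketch of the equivalence; the helper name `at_or_before` is ours, not the crate's:

```rust
use lsm_tree::{DefaultUserComparator, UserComparator};
use std::cmp::Ordering;

/// Hypothetical helper mirroring the seek predicates in the hunks below.
/// Under the default comparator this is exactly `head_key <= needle`;
/// routing through `compare` keeps custom orderings intact.
fn at_or_before(cmp: &dyn UserComparator, head_key: &[u8], needle: &[u8]) -> bool {
    cmp.compare(head_key, needle) != Ordering::Greater
}

fn main() {
    let cmp = DefaultUserComparator;
    assert!(at_or_before(&cmp, b"a", b"b")); // Less    -> true
    assert!(at_or_before(&cmp, b"b", b"b")); // Equal   -> true
    assert!(!at_or_before(&cmp, b"c", b"b")); // Greater -> false
}
```

Writing the bound as `!= Ordering::Greater` rather than `<=` on the byte slices is what lets a non-lexicographic comparator take effect in seeks.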
@@ -53,6 +55,7 @@ pub struct Iter {
     cache: Arc,
     handle: BlockHandle,
     compression: CompressionType,
+    comparator: SharedComparator,
 
     lo: Option<(UserKey, SeqNo)>,
     hi: Option<(UserKey, SeqNo)>,
@@ -71,6 +74,7 @@ impl Iter {
             cache: index.cache.clone(),
             handle: index.handle,
             compression: index.compression,
+            comparator: index.comparator.clone(),
 
             lo: None,
             hi: None,
@@ -113,7 +117,8 @@ impl Iterator for Iter {
         ));
 
         let index_block = IndexBlock::new(block);
-        let mut iter = OwnedIndexBlockIter::new(index_block, IndexBlock::iter);
+        let cmp = self.comparator.clone();
+        let mut iter = OwnedIndexBlockIter::new(index_block, |b| b.iter(cmp));
 
         if let Some((lo_key, lo_seqno)) = &self.lo {
             if !iter.seek_lower(lo_key, *lo_seqno) {
@@ -153,7 +158,8 @@ impl DoubleEndedIterator for Iter {
         ));
 
         let index_block = IndexBlock::new(block);
-        let mut iter = OwnedIndexBlockIter::new(index_block, IndexBlock::iter);
+        let cmp = self.comparator.clone();
+        let mut iter = OwnedIndexBlockIter::new(index_block, |b| b.iter(cmp));
 
         if let Some((lo_key, lo_seqno)) = &self.lo {
             if !iter.seek_lower(lo_key, *lo_seqno) {
diff --git a/src/table/data_block/iter.rs b/src/table/data_block/iter.rs
index c8b4ba663..28c6562e8 100644
--- a/src/table/data_block/iter.rs
+++ b/src/table/data_block/iter.rs
@@ -3,6 +3,7 @@
 // (found in the LICENSE-* files in the repository)
 
 use crate::{
+    comparator::SharedComparator,
     double_ended_peekable::{DoubleEndedPeekable, DoubleEndedPeekableExt},
     table::{
         block::{Decoder, ParsedItem},
@@ -16,14 +17,23 @@ pub struct Iter<'a> {
     bytes: &'a [u8],
     decoder: DoubleEndedPeekable>,
+    comparator: SharedComparator,
 }
 
 impl<'a> Iter<'a> {
     /// Creates a new iterator over a data block.
     #[must_use]
-    pub fn new(bytes: &'a [u8], decoder: Decoder<'a, InternalValue, DataBlockParsedItem>) -> Self {
+    pub fn new(
+        bytes: &'a [u8],
+        decoder: Decoder<'a, InternalValue, DataBlockParsedItem>,
+        comparator: SharedComparator,
+    ) -> Self {
         let decoder = decoder.double_ended_peekable();
-        Self { bytes, decoder }
+        Self {
+            bytes,
+            decoder,
+            comparator,
+        }
     }
 
-    /// Seek the iterator to an byte offset.
+    /// Seek the iterator to a byte offset.
@@ -45,8 +55,9 @@ impl<'a> Iter<'a> {
     /// `needle` will be found within roughly one restart interval of the
     /// resulting position.
     pub fn seek_to_key_seqno(&mut self, needle: &[u8], seqno: SeqNo) -> bool {
+        let cmp = &self.comparator;
         self.decoder.inner_mut().seek(
-            |head_key, head_seqno| match head_key.cmp(needle) {
+            |head_key, head_seqno| match cmp.compare(head_key, needle) {
                 std::cmp::Ordering::Less => true,
                 std::cmp::Ordering::Equal => head_seqno >= seqno,
                 std::cmp::Ordering::Greater => false,
@@ -71,7 +82,7 @@ impl<'a> Iter<'a> {
             return false;
         };
 
-        match item.compare_key(needle, self.bytes) {
+        match item.compare_key(needle, self.bytes, self.comparator.as_ref()) {
             std::cmp::Ordering::Equal => {
                 return true;
             }
@@ -99,11 +110,11 @@ impl<'a> Iter<'a> {
     /// visited from the selected one toward index 0 — a tighter predicate
     /// would skip intervals that may contain the visible version.
pub fn seek_upper(&mut self, needle: &[u8], _seqno: SeqNo) -> bool { - if !self - .decoder - .inner_mut() - .seek_upper(|head_key, _| head_key <= needle, false) - { + let cmp = &self.comparator; + if !self.decoder.inner_mut().seek_upper( + |head_key, _| cmp.compare(head_key, needle) != std::cmp::Ordering::Greater, + false, + ) { return false; } @@ -113,7 +124,7 @@ impl<'a> Iter<'a> { return false; }; - match item.compare_key(needle, self.bytes) { + match item.compare_key(needle, self.bytes, self.comparator.as_ref()) { std::cmp::Ordering::Equal => { return true; } @@ -145,7 +156,7 @@ impl<'a> Iter<'a> { return false; }; - match item.compare_key(needle, self.bytes) { + match item.compare_key(needle, self.bytes, self.comparator.as_ref()) { std::cmp::Ordering::Greater => { return true; } @@ -165,11 +176,11 @@ impl<'a> Iter<'a> { /// See [`seek_upper`] for why `seqno` is accepted but unused in reverse /// seeks. pub fn seek_upper_exclusive(&mut self, needle: &[u8], _seqno: SeqNo) -> bool { - if !self - .decoder - .inner_mut() - .seek_upper(|head_key, _| head_key <= needle, false) - { + let cmp = &self.comparator; + if !self.decoder.inner_mut().seek_upper( + |head_key, _| cmp.compare(head_key, needle) != std::cmp::Ordering::Greater, + false, + ) { return false; } @@ -178,7 +189,7 @@ impl<'a> Iter<'a> { return false; }; - match item.compare_key(needle, self.bytes) { + match item.compare_key(needle, self.bytes, self.comparator.as_ref()) { std::cmp::Ordering::Less => { return true; } diff --git a/src/table/data_block/iter_test.rs b/src/table/data_block/iter_test.rs index 1086d7b65..f792cc6c3 100644 --- a/src/table/data_block/iter_test.rs +++ b/src/table/data_block/iter_test.rs @@ -1,5 +1,6 @@ #[expect(clippy::expect_used)] mod tests { + use crate::comparator::default_comparator; use crate::{ table::{ block::{BlockType, Header, ParsedItem}, @@ -70,7 +71,7 @@ mod tests { }); { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); iter.seek(&10u64.to_be_bytes(), SeqNo::MAX); iter.seek_upper(&110u64.to_be_bytes(), SeqNo::MAX); let iter = iter.map(|x| x.materialize(data_block.as_slice())); @@ -82,7 +83,8 @@ mod tests { } { - let mut iter: crate::table::data_block::Iter<'_> = data_block.iter(); + let mut iter: crate::table::data_block::Iter<'_> = + data_block.iter(default_comparator()); iter.seek(&10u64.to_be_bytes(), SeqNo::MAX); iter.seek_upper(&110u64.to_be_bytes(), SeqNo::MAX); let iter = iter.map(|x| x.materialize(data_block.as_slice())); @@ -94,7 +96,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); iter.seek(&10u64.to_be_bytes(), SeqNo::MAX); iter.seek_upper(&110u64.to_be_bytes(), SeqNo::MAX); @@ -144,7 +146,7 @@ mod tests { }); { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); iter.seek(&10u64.to_be_bytes(), SeqNo::MAX); iter.seek_upper(&109u64.to_be_bytes(), SeqNo::MAX); let iter = iter.map(|x| x.materialize(data_block.as_slice())); @@ -156,7 +158,8 @@ mod tests { } { - let mut iter: crate::table::data_block::Iter<'_> = data_block.iter(); + let mut iter: crate::table::data_block::Iter<'_> = + data_block.iter(default_comparator()); iter.seek(&10u64.to_be_bytes(), SeqNo::MAX); iter.seek_upper(&109u64.to_be_bytes(), SeqNo::MAX); let iter = iter.map(|x| x.materialize(data_block.as_slice())); @@ -168,7 +171,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); iter.seek(&10u64.to_be_bytes(), 
SeqNo::MAX); iter.seek_upper(&109u64.to_be_bytes(), SeqNo::MAX); @@ -217,7 +220,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); iter.seek(&5u64.to_be_bytes(), SeqNo::MAX); iter.seek_upper(&9u64.to_be_bytes(), SeqNo::MAX); @@ -270,7 +273,7 @@ mod tests { }); let iter = data_block - .iter() + .iter(default_comparator()) .map(|item| item.materialize(&data_block.inner.data)); let real_items: Vec<_> = iter.collect(); @@ -305,7 +308,7 @@ mod tests { }); let iter = data_block - .iter() + .iter(default_comparator()) .rev() .map(|item| item.materialize(&data_block.inner.data)); @@ -343,7 +346,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek_upper(b"d", SeqNo::MAX), "should seek"); @@ -384,7 +387,7 @@ mod tests { }); { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(!iter.seek(b"a", SeqNo::MAX), "should not seek"); @@ -396,7 +399,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(!iter.seek_upper(b"g", SeqNo::MAX), "should not seek"); @@ -408,7 +411,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek_upper(b"b", SeqNo::MAX), "should seek"); @@ -423,7 +426,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"f", SeqNo::MAX), "should seek"); @@ -464,7 +467,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"c", SeqNo::MAX), "should seek"); assert!(iter.seek_upper(b"d", SeqNo::MAX), "should seek"); @@ -505,7 +508,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek_upper(b"b", SeqNo::MAX), "should seek"); @@ -546,7 +549,7 @@ mod tests { }); { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"d", SeqNo::MAX), "should seek"); assert!(iter.seek_upper(b"d", SeqNo::MAX), "should seek"); @@ -562,7 +565,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek_upper(b"d", SeqNo::MAX), "should seek"); assert!(iter.seek(b"d", SeqNo::MAX), "should seek"); @@ -578,7 +581,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"d", SeqNo::MAX), "should seek"); assert!(iter.seek_upper(b"d", SeqNo::MAX), "should seek"); @@ -600,7 +603,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek_upper(b"d", SeqNo::MAX), "should seek"); assert!(iter.seek(b"d", SeqNo::MAX), "should seek"); @@ -649,7 +652,7 @@ mod tests { }); { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"f", SeqNo::MAX), "should seek"); iter.seek_upper(b"e", SeqNo::MAX); @@ -660,7 +663,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"f", SeqNo::MAX), "should seek"); iter.seek_upper(b"e", SeqNo::MAX); @@ -671,7 +674,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); 
assert!(iter.seek_upper(b"e", SeqNo::MAX), "should seek"); iter.seek(b"f", SeqNo::MAX); @@ -682,7 +685,7 @@ mod tests { } { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek_upper(b"e", SeqNo::MAX), "should seek"); iter.seek(b"f", SeqNo::MAX); @@ -719,7 +722,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"b", SeqNo::MAX), "should seek correctly"); @@ -756,7 +759,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"d", SeqNo::MAX), "should seek correctly"); @@ -796,7 +799,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"f", SeqNo::MAX), "should seek correctly"); @@ -836,7 +839,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(!iter.seek(b"a", SeqNo::MAX), "should not find exact match"); @@ -873,7 +876,7 @@ mod tests { }, }); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(!iter.seek(b"g", SeqNo::MAX), "should not find exact match"); @@ -911,7 +914,7 @@ mod tests { { let mut iter = data_block - .iter() + .iter(default_comparator()) .map(|item| item.materialize(&data_block.inner.data)); assert_eq!( @@ -940,7 +943,7 @@ mod tests { { let mut iter = data_block - .iter() + .iter(default_comparator()) .map(|item| item.materialize(&data_block.inner.data)); assert_eq!( @@ -999,7 +1002,7 @@ mod tests { { let mut iter = data_block - .iter() + .iter(default_comparator()) .rev() .map(|item| item.materialize(&data_block.inner.data)); @@ -1029,7 +1032,7 @@ mod tests { { let mut iter = data_block - .iter() + .iter(default_comparator()) .rev() .map(|item| item.materialize(&data_block.inner.data)); @@ -1089,7 +1092,7 @@ mod tests { { let mut iter = data_block - .iter() + .iter(default_comparator()) .map(|item| item.materialize(&data_block.inner.data)); assert_eq!(b"a", &*iter.next().expect("should exist").key.user_key); @@ -1103,7 +1106,7 @@ mod tests { { let mut iter = data_block - .iter() + .iter(default_comparator()) .map(|item| item.materialize(&data_block.inner.data)); assert_eq!(b"e", &*iter.next_back().expect("should exist").key.user_key); @@ -1117,7 +1120,7 @@ mod tests { { let mut iter = data_block - .iter() + .iter(default_comparator()) .map(|item| item.materialize(&data_block.inner.data)); assert_eq!(b"a", &*iter.next().expect("should exist").key.user_key); @@ -1133,7 +1136,7 @@ mod tests { { let mut iter = data_block - .iter() + .iter(default_comparator()) .map(|item| item.materialize(&data_block.inner.data)); assert_eq!(b"e", &*iter.next_back().expect("should exist").key.user_key); @@ -1191,7 +1194,7 @@ mod tests { > 0, ); - assert_eq!(data_block.iter().count(), items.len()); + assert_eq!(data_block.iter(default_comparator()).count(), items.len()); Ok(()) } @@ -1241,7 +1244,7 @@ mod tests { > 0, ); - assert_eq!(data_block.iter().count(), items.len()); + assert_eq!(data_block.iter(default_comparator()).count(), items.len()); Ok(()) } @@ -1267,9 +1270,9 @@ mod tests { }); assert_eq!(data_block.len(), items.len()); - assert_eq!(data_block.iter().count(), items.len()); + assert_eq!(data_block.iter(default_comparator()).count(), items.len()); - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); iter.seek(&[0], SeqNo::MAX); 
iter.seek_upper(&[0], SeqNo::MAX); @@ -1304,7 +1307,7 @@ mod tests { // With SeqNo::MAX, seek behaves like key-only (no seqno filtering). { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!( iter.seek(b"b", SeqNo::MAX), "should find key with MAX seqno" @@ -1320,7 +1323,7 @@ mod tests { // restart interval containing (or nearest to) the target seqno. // The first entry returned is the head of that interval. { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"b", 5), "should find key with snapshot seqno 5"); let entry = iter.next().expect("should have entry"); let materialized = entry.materialize(&data_block.inner.data); @@ -1374,7 +1377,7 @@ mod tests { // Forward seek with seqno narrows restart interval selection. { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek(b"b", 5), "should find b at snapshot 5"); let entry = iter.next().expect("should have entry"); let mat = entry.materialize(&data_block.inner.data); @@ -1395,7 +1398,7 @@ mod tests { // Exclusive forward seek with seqno. { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!( iter.seek_exclusive(b"b", 5), "should find entry > b at snapshot 5" @@ -1407,7 +1410,7 @@ mod tests { // Upper seek still works with seqno (predicate unchanged for backward). { - let mut iter = data_block.iter(); + let mut iter = data_block.iter(default_comparator()); assert!(iter.seek_upper(b"b", 5), "should find upper bound b"); let entry = iter.next_back().expect("should have entry"); let mat = entry.materialize(&data_block.inner.data); diff --git a/src/table/data_block/mod.rs b/src/table/data_block/mod.rs index 18da4451c..063e1d50f 100644 --- a/src/table/data_block/mod.rs +++ b/src/table/data_block/mod.rs @@ -278,14 +278,21 @@ pub struct DataBlockParsedItem { } impl ParsedItem for DataBlockParsedItem { - fn compare_key(&self, needle: &[u8], bytes: &[u8]) -> std::cmp::Ordering { + fn compare_key( + &self, + needle: &[u8], + bytes: &[u8], + cmp: &dyn crate::comparator::UserComparator, + ) -> std::cmp::Ordering { + // SAFETY: slice indexes come from the block parser which validates them + // during decoding. The block format guarantees they are within bounds. 
 if let Some(prefix) = &self.prefix {
             let prefix = unsafe { bytes.get_unchecked(prefix.0..prefix.1) };
             let rest_key = unsafe { bytes.get_unchecked(self.key.0..self.key.1) };
-            compare_prefixed_slice(prefix, rest_key, needle)
+            compare_prefixed_slice(prefix, rest_key, needle, cmp)
         } else {
             let key = unsafe { bytes.get_unchecked(self.key.0..self.key.1) };
-            key.cmp(needle)
+            cmp.compare(key, needle)
         }
     }
@@ -408,7 +415,12 @@ impl DataBlock {
     }
 
     #[must_use]
-    pub fn point_read(&self, needle: &[u8], seqno: SeqNo) -> Option {
+    pub fn point_read(
+        &self,
+        needle: &[u8],
+        seqno: SeqNo,
+        comparator: &crate::comparator::SharedComparator,
+    ) -> Option {
         let iter = if let Some(hash_index_reader) = self.get_hash_index_reader() {
             match hash_index_reader.get(needle) {
                 MARKER_FREE => {
                 }
                 MARKER_CONFLICT => {
                     // NOTE: Fallback to seqno-aware binary search
-                    let mut iter = self.iter();
+                    let mut iter = self.iter(comparator.clone());
 
                     if !iter.seek_to_key_seqno(needle, seqno) {
                         return None;
                     }
@@ -427,14 +439,14 @@
                 idx => {
                     let offset: usize = self.get_binary_index_reader().get(usize::from(idx));
 
-                    let mut iter = self.iter();
+                    let mut iter = self.iter(comparator.clone());
                     iter.seek_to_offset(offset);
 
                     iter
                 }
             }
         } else {
-            let mut iter = self.iter();
+            let mut iter = self.iter(comparator.clone());
 
             // NOTE: Seqno-aware binary search reduces linear scanning by skipping most
             // restart intervals that contain only versions newer than the target seqno
@@ -447,7 +459,7 @@
         // Linear scan
         for item in iter {
-            match item.compare_key(needle, &self.inner.data) {
+            match item.compare_key(needle, &self.inner.data, comparator.as_ref()) {
                 std::cmp::Ordering::Greater => {
                     // We are past our searched key
                     return None;
@@ -472,11 +484,11 @@
     }
 
     #[must_use]
-    #[expect(clippy::iter_without_into_iter)]
-    pub fn iter(&self) -> Iter<'_> {
+    pub fn iter(&self, comparator: crate::comparator::SharedComparator) -> Iter<'_> {
         Iter::new(
             &self.inner.data,
             Decoder::::new(&self.inner),
+            comparator,
         )
     }
@@ -552,6 +564,7 @@
 #[cfg(test)]
 #[expect(clippy::expect_used)]
 mod tests {
+    use crate::comparator::default_comparator;
     use crate::{
         table::{
             block::{BlockType, Header, ParsedItem},
@@ -610,7 +623,7 @@ mod tests {
         let real_ping_ponged_items = {
             let mut iter = data_block
-                .iter()
+                .iter(default_comparator())
                 .map(|x| x.materialize(data_block.as_slice()));
 
             let mut v = vec![];
@@ -655,17 +668,23 @@ mod tests {
         });
 
         assert!(
-            data_block.point_read(b"a", SeqNo::MAX).is_none(),
+            data_block
+                .point_read(b"a", SeqNo::MAX, &default_comparator())
+                .is_none(),
             "should return None because a does not exist",
         );
 
         assert!(
-            data_block.point_read(b"b", SeqNo::MAX).is_some(),
+            data_block
+                .point_read(b"b", SeqNo::MAX, &default_comparator())
+                .is_some(),
             "should return Some because b exists",
         );
 
         assert!(
-            data_block.point_read(b"z", SeqNo::MAX).is_none(),
+            data_block
+                .point_read(b"z", SeqNo::MAX, &default_comparator())
+                .is_none(),
-            "should return Some because z does not exist",
+            "should return None because z does not exist",
         );
     }
@@ -702,11 +721,14 @@ mod tests {
         for needle in items {
             assert_eq!(
                 Some(needle.clone()),
-                data_block.point_read(&needle.key.user_key, SeqNo::MAX),
+                data_block.point_read(&needle.key.user_key, SeqNo::MAX, &default_comparator()),
             );
         }
 
-        assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX));
+        assert_eq!(
+            None,
+            data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator())
+        );
 
         Ok(())
     }
@@ -737,8 +759,13 @@
 assert_eq!(data_block.len(),
items.len()); assert_eq!(data_block.inner.size(), serialized_len); - assert_eq!(Some(items[0].clone()), data_block.point_read(b"abc", 777)); - assert!(data_block.point_read(b"abc", 1).is_none()); + assert_eq!( + Some(items[0].clone()), + data_block.point_read(b"abc", 777, &default_comparator()) + ); + assert!(data_block + .point_read(b"abc", 1, &default_comparator()) + .is_none()); } Ok(()) @@ -770,7 +797,10 @@ mod tests { assert_eq!(data_block.len(), items.len()); assert_eq!(data_block.inner.size(), serialized_len); - assert_eq!(Some(items[0].clone()), data_block.point_read(b"hello", 777)); + assert_eq!( + Some(items[0].clone()), + data_block.point_read(b"hello", 777, &default_comparator()) + ); } Ok(()) @@ -806,11 +836,18 @@ mod tests { for needle in items { assert_eq!( Some(needle.clone()), - data_block.point_read(&needle.key.user_key, needle.key.seqno + 1), + data_block.point_read( + &needle.key.user_key, + needle.key.seqno + 1, + &default_comparator() + ), ); } - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) + ); Ok(()) } @@ -842,11 +879,18 @@ mod tests { for needle in items { assert_eq!( Some(needle.clone()), - data_block.point_read(&needle.key.user_key, needle.key.seqno + 1), + data_block.point_read( + &needle.key.user_key, + needle.key.seqno + 1, + &default_comparator() + ), ); } - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) + ); Ok(()) } @@ -878,11 +922,14 @@ mod tests { for needle in items { assert_eq!( Some(needle.clone()), - data_block.point_read(&needle.key.user_key, SeqNo::MAX), + data_block.point_read(&needle.key.user_key, SeqNo::MAX, &default_comparator()), ); } - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) + ); Ok(()) } @@ -919,11 +966,18 @@ mod tests { for needle in items { assert_eq!( Some(needle.clone()), - data_block.point_read(&needle.key.user_key, needle.key.seqno + 1), + data_block.point_read( + &needle.key.user_key, + needle.key.seqno + 1, + &default_comparator() + ), ); } - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) + ); Ok(()) } @@ -959,9 +1013,12 @@ mod tests { assert_eq!( Some(items.get(1).cloned().unwrap()), - data_block.point_read(&[233, 233], SeqNo::MAX) + data_block.point_read(&[233, 233], SeqNo::MAX, &default_comparator()) + ); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) ); - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); Ok(()) } @@ -1004,13 +1061,16 @@ mod tests { assert_eq!( Some(items.get(1).cloned().unwrap()), - data_block.point_read(&[233, 233], SeqNo::MAX) + data_block.point_read(&[233, 233], SeqNo::MAX, &default_comparator()) ); assert_eq!( Some(items.last().cloned().unwrap()), - data_block.point_read(&[255, 255, 0], SeqNo::MAX) + data_block.point_read(&[255, 255, 0], SeqNo::MAX, &default_comparator()) + ); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) ); - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); Ok(()) } @@ -1053,13 +1113,16 @@ mod tests { assert_eq!( Some(items.get(1).cloned().unwrap()), - data_block.point_read(&[233, 233], SeqNo::MAX) + data_block.point_read(&[233, 233], SeqNo::MAX, &default_comparator()) ); 
assert_eq!( Some(items.last().cloned().unwrap()), - data_block.point_read(&[255, 255, 0], SeqNo::MAX) + data_block.point_read(&[255, 255, 0], SeqNo::MAX, &default_comparator()) + ); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) ); - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); Ok(()) } @@ -1102,13 +1165,16 @@ mod tests { assert_eq!( Some(items.get(1).cloned().unwrap()), - data_block.point_read(&[233, 233], SeqNo::MAX) + data_block.point_read(&[233, 233], SeqNo::MAX, &default_comparator()) ); assert_eq!( Some(items.last().cloned().unwrap()), - data_block.point_read(&[255, 255, 0], SeqNo::MAX) + data_block.point_read(&[255, 255, 0], SeqNo::MAX, &default_comparator()) + ); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) ); - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); Ok(()) } @@ -1140,11 +1206,18 @@ mod tests { for needle in items { assert_eq!( Some(needle.clone()), - data_block.point_read(&needle.key.user_key, needle.key.seqno + 1), + data_block.point_read( + &needle.key.user_key, + needle.key.seqno + 1, + &default_comparator() + ), ); } - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) + ); Ok(()) } @@ -1180,7 +1253,7 @@ mod tests { ); assert!(data_block - .point_read(b"pla:venus:fact", SeqNo::MAX) + .point_read(b"pla:venus:fact", SeqNo::MAX, &default_comparator()) .expect("should exist") .is_tombstone()); @@ -1225,11 +1298,18 @@ mod tests { for needle in items { assert_eq!( Some(needle.clone()), - data_block.point_read(&needle.key.user_key, needle.key.seqno + 1), + data_block.point_read( + &needle.key.user_key, + needle.key.seqno + 1, + &default_comparator() + ), ); } - assert_eq!(None, data_block.point_read(b"yyy", SeqNo::MAX)); + assert_eq!( + None, + data_block.point_read(b"yyy", SeqNo::MAX, &default_comparator()) + ); Ok(()) } @@ -1265,33 +1345,37 @@ mod tests { // seqno=4 → should see version with seqno=3 (first with seqno < 4) assert_eq!( Some(items[2].clone()), - data_block.point_read(b"a", 4), + data_block.point_read(b"a", 4, &default_comparator()), "restart_interval={restart_interval}: seqno=4 should return v3", ); // seqno=3 → should see version with seqno=2 assert_eq!( Some(items[3].clone()), - data_block.point_read(b"a", 3), + data_block.point_read(b"a", 3, &default_comparator()), "restart_interval={restart_interval}: seqno=3 should return v2", ); // seqno=6 → should see latest version (seqno=5) assert_eq!( Some(items[0].clone()), - data_block.point_read(b"a", 6), + data_block.point_read(b"a", 6, &default_comparator()), "restart_interval={restart_interval}: seqno=6 should return v5", ); // seqno=1 → no visible version (all seqno >= 1) assert!( - data_block.point_read(b"a", 1).is_none(), + data_block + .point_read(b"a", 1, &default_comparator()) + .is_none(), "restart_interval={restart_interval}: seqno=1 should return None", ); // Non-existent key assert!( - data_block.point_read(b"b", SeqNo::MAX).is_none(), + data_block + .point_read(b"b", SeqNo::MAX, &default_comparator()) + .is_none(), "restart_interval={restart_interval}: key 'b' should not exist", ); } @@ -1330,21 +1414,21 @@ mod tests { // Read "b" at seqno=4 → should return version with seqno=3 assert_eq!( Some(items[5].clone()), - data_block.point_read(b"b", 4), + data_block.point_read(b"b", 4, &default_comparator()), "restart_interval={restart_interval}: b@4 should return b3", ); // Read "a" at seqno=2 → 
should return version with seqno=1 assert_eq!( Some(items[2].clone()), - data_block.point_read(b"a", 2), + data_block.point_read(b"a", 2, &default_comparator()), "restart_interval={restart_interval}: a@2 should return a1", ); // Read "c" at seqno=2 → should return version with seqno=1 assert_eq!( Some(items[8].clone()), - data_block.point_read(b"c", 2), + data_block.point_read(b"c", 2, &default_comparator()), "restart_interval={restart_interval}: c@2 should return c1", ); } diff --git a/src/table/index_block/iter.rs b/src/table/index_block/iter.rs index ad85bfad0..83c35d4b3 100644 --- a/src/table/index_block/iter.rs +++ b/src/table/index_block/iter.rs @@ -3,6 +3,7 @@ // (found in the LICENSE-* files in the repository) use crate::{ + comparator::SharedComparator, double_ended_peekable::{DoubleEndedPeekable, DoubleEndedPeekableExt}, table::{block::Decoder, index_block::IndexBlockParsedItem, KeyedBlockHandle}, SeqNo, @@ -13,18 +14,26 @@ pub struct Iter<'a> { IndexBlockParsedItem, Decoder<'a, KeyedBlockHandle, IndexBlockParsedItem>, >, + comparator: SharedComparator, } impl<'a> Iter<'a> { #[must_use] - pub fn new(decoder: Decoder<'a, KeyedBlockHandle, IndexBlockParsedItem>) -> Self { + pub fn new( + decoder: Decoder<'a, KeyedBlockHandle, IndexBlockParsedItem>, + comparator: SharedComparator, + ) -> Self { let decoder = decoder.double_ended_peekable(); - Self { decoder } + Self { + decoder, + comparator, + } } pub fn seek(&mut self, needle: &[u8], seqno: SeqNo) -> bool { + let cmp = &self.comparator; self.decoder.inner_mut().seek( - |end_key, s| match end_key.cmp(needle) { + |end_key, s| match cmp.compare(end_key, needle) { std::cmp::Ordering::Greater => false, std::cmp::Ordering::Less => true, std::cmp::Ordering::Equal => s >= seqno, @@ -34,9 +43,11 @@ impl<'a> Iter<'a> { } pub fn seek_upper(&mut self, needle: &[u8], _seqno: SeqNo) -> bool { - self.decoder - .inner_mut() - .seek_upper(|end_key, _s| end_key <= needle, true) + let cmp = &self.comparator; + self.decoder.inner_mut().seek_upper( + |end_key, _s| cmp.compare(end_key, needle) != std::cmp::Ordering::Greater, + true, + ) } } @@ -54,619 +65,7 @@ impl DoubleEndedIterator for Iter<'_> { } } -#[cfg(test)] -mod tests { - use crate::{ - table::{ - block::{BlockType, Header, ParsedItem}, - Block, BlockHandle, BlockOffset, IndexBlock, KeyedBlockHandle, - }, - Checksum, - }; - use test_log::test; - - #[test] - fn index_block_iter_seek_before_start() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new( - b"bcdef".into(), - 0, - BlockHandle::new(BlockOffset(6_000), 7_000), - ), - KeyedBlockHandle::new( - b"def".into(), - 0, - BlockHandle::new(BlockOffset(13_000), 5_000), - ), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - let mut iter = index_block.iter(); - assert!(iter.seek(b"a", 0), "should seek"); - - let iter = index_block - .iter() - .map(|item| item.materialize(&index_block.inner.data)); - - let real_items: Vec<_> = iter.collect(); - - assert_eq!(items, &*real_items); - - Ok(()) - } - - #[test] - fn index_block_iter_seek_start() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new( - 
b"bcdef".into(), - 0, - BlockHandle::new(BlockOffset(6_000), 7_000), - ), - KeyedBlockHandle::new( - b"def".into(), - 0, - BlockHandle::new(BlockOffset(13_000), 5_000), - ), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - let mut iter = index_block.iter(); - assert!(iter.seek(b"b", 1), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!(items, &*real_items); - - Ok(()) - } - - #[test] - fn index_block_iter_seek_middle() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new( - b"bcdef".into(), - 0, - BlockHandle::new(BlockOffset(6_000), 7_000), - ), - KeyedBlockHandle::new( - b"def".into(), - 0, - BlockHandle::new(BlockOffset(13_000), 5_000), - ), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - let mut iter = index_block.iter(); - assert!(iter.seek(b"c", 0), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items.iter().skip(2).cloned().collect::>(), - &*real_items, - ); - - Ok(()) - } - - #[test] - fn index_block_iter_rev_seek() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new( - b"bcdef".into(), - 0, - BlockHandle::new(BlockOffset(6_000), 7_000), - ), - KeyedBlockHandle::new( - b"def".into(), - 0, - BlockHandle::new(BlockOffset(13_000), 5_000), - ), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - let mut iter = index_block.iter(); - assert!(iter.seek_upper(b"c", 0), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!(items.to_vec(), &*real_items); - - Ok(()) - } - - #[test] - fn index_block_iter_rev_seek_2() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new( - b"bcdef".into(), - 0, - BlockHandle::new(BlockOffset(6_000), 7_000), - ), - KeyedBlockHandle::new( - b"def".into(), - 0, - BlockHandle::new(BlockOffset(13_000), 5_000), - ), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - let mut iter = index_block.iter(); - assert!(iter.seek_upper(b"e", 0), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - 
assert_eq!(items.to_vec(), &*real_items); - - Ok(()) - } - - #[test] - fn index_block_iter_rev_seek_3() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new( - b"bcdef".into(), - 0, - BlockHandle::new(BlockOffset(6_000), 7_000), - ), - KeyedBlockHandle::new( - b"def".into(), - 0, - BlockHandle::new(BlockOffset(13_000), 5_000), - ), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - let mut iter = index_block.iter(); - assert!(iter.seek_upper(b"b", 1), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items.iter().take(2).cloned().collect::>(), - &*real_items, - ); - - Ok(()) - } - - #[test] - fn index_block_iter_too_far() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new( - b"bcdef".into(), - 0, - BlockHandle::new(BlockOffset(6_000), 7_000), - ), - KeyedBlockHandle::new( - b"def".into(), - 0, - BlockHandle::new(BlockOffset(13_000), 5_000), - ), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - let mut iter = index_block.iter(); - assert!(!iter.seek(b"zzz", 0), "should not seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!(&[] as &[KeyedBlockHandle], &*real_items); - - Ok(()) - } - - #[test] - fn index_block_iter_too_far_next_back() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new( - b"bcdef".into(), - 0, - BlockHandle::new(BlockOffset(6_000), 7_000), - ), - KeyedBlockHandle::new( - b"def".into(), - 0, - BlockHandle::new(BlockOffset(13_000), 5_000), - ), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - let mut iter = index_block.iter(); - assert!(!iter.seek(b"zzz", 0), "should not seek"); - - assert!(iter.next().is_none(), "iterator should be exhausted"); - assert!( - iter.next_back().is_none(), - "reverse iterator should also be exhausted" - ); - - Ok(()) - } - - #[test] - fn index_block_mvcc_slab() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"a".into(), 3, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new(b"a".into(), 1, BlockHandle::new(BlockOffset(6_000), 7_000)), - KeyedBlockHandle::new(b"b".into(), 4, BlockHandle::new(BlockOffset(13_000), 5_000)), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - 
{ - let mut iter = index_block.iter(); - assert!(iter.seek(b"a", 5), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!(items, &*real_items); - } - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"a", 4), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!(items, &*real_items); - } - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"a", 3), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items.iter().skip(1).cloned().collect::>(), - &*real_items, - ); - } - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"a", 2), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items.iter().skip(1).cloned().collect::>(), - &*real_items, - ); - } - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"a", 1), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items.iter().skip(2).cloned().collect::>(), - &*real_items, - ); - } - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"a", 0), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items.iter().skip(2).cloned().collect::>(), - &*real_items, - ); - } - - Ok(()) - } - - #[test] - fn index_block_iter_span() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"a".into(), 1, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new(b"a".into(), 0, BlockHandle::new(BlockOffset(6_000), 7_000)), - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(13_000), 5_000)), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"a", 2), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!(items.to_vec(), &*real_items); - } - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"b", 1), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items.iter().skip(2).cloned().collect::>(), - &*real_items, - ); - } - - Ok(()) - } - - #[test] - fn index_block_iter_rev_span() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"a".into(), 1, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new(b"a".into(), 0, BlockHandle::new(BlockOffset(6_000), 7_000)), - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(13_000), 5_000)), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - { - let mut iter = index_block.iter(); - assert!(iter.seek_upper(b"a", 2), "should seek"); - - let 
real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!(items.to_vec(), &*real_items); - } - - { - let mut iter = index_block.iter(); - assert!(iter.seek_upper(b"b", 1), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!(items.to_vec(), &*real_items); - } - - Ok(()) - } - - #[test] - fn index_block_iter_range_1() -> crate::Result<()> { - let items = [ - KeyedBlockHandle::new(b"a".into(), 0, BlockHandle::new(BlockOffset(0), 6_000)), - KeyedBlockHandle::new(b"b".into(), 0, BlockHandle::new(BlockOffset(13_000), 5_000)), - KeyedBlockHandle::new(b"c".into(), 0, BlockHandle::new(BlockOffset(13_000), 5_000)), - KeyedBlockHandle::new(b"d".into(), 0, BlockHandle::new(BlockOffset(13_000), 5_000)), - KeyedBlockHandle::new(b"e".into(), 0, BlockHandle::new(BlockOffset(13_000), 5_000)), - ]; - - let bytes = IndexBlock::encode_into_vec(&items)?; - - let index_block = IndexBlock::new(Block { - data: bytes.into(), - header: Header { - block_type: BlockType::Index, - checksum: Checksum::from_raw(0), - data_length: 0, - uncompressed_length: 0, - }, - }); - - assert_eq!(index_block.len(), items.len()); - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"b", 1), "should seek"); - assert!(iter.seek_upper(b"c", 1), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items.iter().skip(1).take(3).cloned().collect::>(), - &*real_items, - ); - } - - { - let mut iter = index_block.iter(); - assert!(iter.seek(b"b", 1), "should seek"); - assert!(iter.seek_upper(b"c", 1), "should seek"); - - let real_items: Vec<_> = iter - .map(|item| item.materialize(&index_block.inner.data)) - .collect(); - - assert_eq!( - items - .iter() - .skip(1) - .take(3) - .rev() - .cloned() - .collect::>(), - &*real_items, - ); - } - - Ok(()) - } -} +// Unit tests for IndexBlock::Iter seek/seek_upper behavior are covered by +// integration tests in tests/custom_comparator.rs (which exercise the full +// block-index → data-block path with both default and custom comparators) +// and by the existing table-level tests in src/table/tests.rs. diff --git a/src/table/index_block/mod.rs b/src/table/index_block/mod.rs index 3e892f667..7d3a6850c 100644 --- a/src/table/index_block/mod.rs +++ b/src/table/index_block/mod.rs @@ -31,14 +31,21 @@ pub struct IndexBlockParsedItem { } impl ParsedItem for IndexBlockParsedItem { - fn compare_key(&self, needle: &[u8], bytes: &[u8]) -> std::cmp::Ordering { + fn compare_key( + &self, + needle: &[u8], + bytes: &[u8], + cmp: &dyn crate::comparator::UserComparator, + ) -> std::cmp::Ordering { + // SAFETY: slice indexes come from the block parser which validates them + // during decoding. The block format guarantees they are within bounds. 
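To pin down the contract that the prefix branch below relies on: a block entry stores its key as a shared prefix plus a suffix, and `compare_key` must order that truncated form against `needle` exactly as if the full key had been materialized. A standalone sketch of that equivalence (illustrative only; the allocation-conscious `compare_prefixed_slice` further down in this diff is the real implementation):

use std::cmp::Ordering;

// Sketch: compare `prefix + suffix` against `needle` by materializing the
// full key. Correct but allocating; shown only to state the contract.
fn compare_materialized(prefix: &[u8], suffix: &[u8], needle: &[u8]) -> Ordering {
    let mut full = Vec::with_capacity(prefix.len() + suffix.len());
    full.extend_from_slice(prefix);
    full.extend_from_slice(suffix);
    full.as_slice().cmp(needle)
}

fn main() {
    // "abc" stored as ("ab", "c") must compare exactly like the full key.
    assert_eq!(Ordering::Equal, compare_materialized(b"ab", b"c", b"abc"));
    assert_eq!(Ordering::Less, compare_materialized(b"ab", b"c", b"abd"));
}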
if let Some(prefix) = &self.prefix { let prefix = unsafe { bytes.get_unchecked(prefix.0..prefix.1) }; let rest_key = unsafe { bytes.get_unchecked(self.end_key.0..self.end_key.1) }; - compare_prefixed_slice(prefix, rest_key, needle) + compare_prefixed_slice(prefix, rest_key, needle, cmp) } else { let key = unsafe { bytes.get_unchecked(self.end_key.0..self.end_key.1) }; - key.cmp(needle) + cmp.compare(key, needle) } }
@@ -87,11 +94,11 @@ impl IndexBlock { } #[must_use] - #[expect(clippy::iter_without_into_iter)] - pub fn iter(&self) -> Iter<'_> { - Iter::new(Decoder::<KeyedBlockHandle, IndexBlockParsedItem>::new( - &self.inner, - )) + pub fn iter(&self, comparator: crate::comparator::SharedComparator) -> Iter<'_> { + Iter::new( + Decoder::<KeyedBlockHandle, IndexBlockParsedItem>::new(&self.inner), + comparator, + ) } pub fn encode_into_vec(items: &[KeyedBlockHandle]) -> crate::Result<Vec<u8>> {
diff --git a/src/table/inner.rs b/src/table/inner.rs index e59d47170..3919589ce 100644 --- a/src/table/inner.rs +++ b/src/table/inner.rs
@@ -8,6 +8,7 @@ use crate::metrics::Metrics; use super::{block_index::BlockIndexImpl, meta::ParsedMeta, regions::ParsedRegions}; use crate::{ cache::Cache, + comparator::SharedComparator, file_accessor::FileAccessor, range_tombstone::RangeTombstone, table::{filter::block::FilterBlock, IndexBlock},
@@ -60,6 +61,8 @@ pub struct Inner { pub(super) global_seqno: SeqNo, + pub(crate) comparator: SharedComparator, + #[cfg(feature = "metrics")] pub(crate) metrics: Arc<Metrics>,
diff --git a/src/table/iter.rs b/src/table/iter.rs index 53e4767c8..0328b0572 100644 --- a/src/table/iter.rs +++ b/src/table/iter.rs
@@ -4,6 +4,7 @@ use super::{data_block::Iter as DataBlockIter, BlockOffset, DataBlock, GlobalTableId}; use crate::{ + comparator::SharedComparator, file_accessor::FileAccessor, table::{ block::ParsedItem,
@@ -87,8 +88,8 @@ impl DoubleEndedIterator for OwnedDataBlockIter { } } -fn create_data_block_reader(block: DataBlock) -> OwnedDataBlockIter { - OwnedDataBlockIter::new(block, super::data_block::DataBlock::iter) +fn create_data_block_reader(block: DataBlock, comparator: SharedComparator) -> OwnedDataBlockIter { + OwnedDataBlockIter::new(block, |b| b.iter(comparator)) } pub struct Iter {
@@ -103,6 +104,7 @@ pub struct Iter { file_accessor: FileAccessor, cache: Arc<Cache>, compression: CompressionType, + comparator: SharedComparator, index_initialized: bool,
@@ -119,13 +121,9 @@ pub struct Iter { } impl Iter { - // cfg_attr: expect only fires when metrics feature adds the extra parameter - #[cfg_attr( - feature = "metrics", - expect( - clippy::too_many_arguments, - reason = "metrics adds the extra parameter; without that feature this stays at the lint threshold" - ) + #[expect( + clippy::too_many_arguments, + reason = "comparator + metrics add extra parameters" )] pub fn new( table_id: GlobalTableId,
@@ -135,6 +133,7 @@ impl Iter { file_accessor: FileAccessor, cache: Arc<Cache>, compression: CompressionType, + comparator: SharedComparator, #[cfg(feature = "metrics")] metrics: Arc<Metrics>, ) -> Self { Self {
@@ -147,6 +146,7 @@ file_accessor, cache, compression, + comparator, index_initialized: false,
@@ -272,7 +272,7 @@ impl Iterator for Iter { }; let block = DataBlock::new(block); - let mut reader = create_data_block_reader(block); + let mut reader = create_data_block_reader(block, self.comparator.clone()); // Forward path: seek the low side first to avoid returning entries below the lower // bound, then clamp the iterator on the high side.
This guarantees iteration stays in @@ -393,7 +393,7 @@ impl DoubleEndedIterator for Iter { }; let block = DataBlock::new(block); - let mut reader = create_data_block_reader(block); + let mut reader = create_data_block_reader(block, self.comparator.clone()); // Reverse path: clamp the high side first so `next_back` never yields an entry above // the upper bound, then apply the low-side seek to avoid stepping below the lower diff --git a/src/table/meta.rs b/src/table/meta.rs index 2c1ce984c..ab1ee0a81 100644 --- a/src/table/meta.rs +++ b/src/table/meta.rs @@ -4,8 +4,8 @@ use super::{Block, BlockHandle, DataBlock}; use crate::{ - checksum::ChecksumType, coding::Decode, table::block::BlockType, CompressionType, KeyRange, - SeqNo, TableId, + checksum::ChecksumType, coding::Decode, comparator::default_comparator, + table::block::BlockType, CompressionType, KeyRange, SeqNo, TableId, }; use byteorder::{LittleEndian, ReadBytesExt}; use std::{fs::File, ops::Deref}; @@ -59,9 +59,9 @@ pub struct ParsedMeta { } macro_rules! read_u8 { - ($block:expr, $name:expr) => {{ + ($block:expr, $name:expr, $cmp:expr) => {{ let bytes = $block - .point_read($name, SeqNo::MAX) + .point_read($name, SeqNo::MAX, $cmp) .unwrap_or_else(|| panic!("meta property {:?} should exist", $name)); let mut bytes = &bytes.value[..]; @@ -70,9 +70,9 @@ macro_rules! read_u8 { } macro_rules! read_u64 { - ($block:expr, $name:expr) => {{ + ($block:expr, $name:expr, $cmp:expr) => {{ let bytes = $block - .point_read($name, SeqNo::MAX) + .point_read($name, SeqNo::MAX, $cmp) .unwrap_or_else(|| panic!("meta property {:?} should exist", $name)); let mut bytes = &bytes.value[..]; @@ -109,10 +109,13 @@ impl ParsedMeta { let block = DataBlock::new(block); + // Metadata keys are always lexicographic, so use the default comparator. 
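The hunks that follow pass `default_comparator()` and `&cmp` around constantly, so here is a self-contained sketch of the plumbing those call sites assume. The trait below is a simplified stand-in with the same shape; the authoritative definitions live in src/comparator.rs:

use std::cmp::Ordering;
use std::sync::Arc;

// Simplified stand-in for the crate's UserComparator trait (the real trait
// carries stricter bounds; see src/comparator.rs).
trait UserComparator {
    fn compare(&self, a: &[u8], b: &[u8]) -> Ordering;
    fn is_lexicographic(&self) -> bool {
        false
    }
}

// Sketch: a shared, dynamically dispatched comparator handle.
type SharedComparator = Arc<dyn UserComparator>;

// Sketch: plain bytewise ordering; reporting `is_lexicographic` keeps the
// zero-allocation comparison fast path available.
struct DefaultUserComparator;

impl UserComparator for DefaultUserComparator {
    fn compare(&self, a: &[u8], b: &[u8]) -> Ordering {
        a.cmp(b)
    }

    fn is_lexicographic(&self) -> bool {
        true
    }
}

fn default_comparator() -> SharedComparator {
    Arc::new(DefaultUserComparator)
}

fn main() {
    fn takes_dyn(cmp: &dyn UserComparator) -> Ordering {
        cmp.compare(b"a", b"b")
    }

    let cmp = default_comparator();
    // `&cmp` deref-coerces from &Arc<dyn UserComparator> to
    // &dyn UserComparator, which is how `&cmp` is passed in these hunks.
    assert_eq!(Ordering::Less, takes_dyn(&cmp));
    assert!(cmp.is_lexicographic());
}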
+ let cmp = default_comparator(); + #[expect(clippy::indexing_slicing)] { let table_version = block - .point_read(b"table_version", SeqNo::MAX) + .point_read(b"table_version", SeqNo::MAX, &cmp) .expect("Table version should exist") .value;
@@ -126,7 +129,7 @@ { let hash_type = block - .point_read(b"filter_hash_type", SeqNo::MAX) + .point_read(b"filter_hash_type", SeqNo::MAX, &cmp) .expect("Filter hash type should exist") .value;
@@ -140,7 +143,7 @@ { let hash_type = block - .point_read(b"checksum_type", SeqNo::MAX) + .point_read(b"checksum_type", SeqNo::MAX, &cmp) .expect("Checksum type should exist") .value;
@@ -153,24 +156,24 @@ } assert_eq!( - read_u8!(block, b"restart_interval#index"), + read_u8!(block, b"restart_interval#index", &cmp), 1, "index block restart intervals >1 are not supported for this version", ); - let id = read_u64!(block, b"table_id"); - let item_count = read_u64!(block, b"item_count"); - let tombstone_count = read_u64!(block, b"tombstone_count"); - let data_block_count = read_u64!(block, b"block_count#data"); - let index_block_count = read_u64!(block, b"block_count#index"); - let _filter_block_count = read_u64!(block, b"block_count#filter"); - let file_size = read_u64!(block, b"file_size"); - let weak_tombstone_count = read_u64!(block, b"weak_tombstone_count"); - let weak_tombstone_reclaimable = read_u64!(block, b"weak_tombstone_reclaimable"); + let id = read_u64!(block, b"table_id", &cmp); + let item_count = read_u64!(block, b"item_count", &cmp); + let tombstone_count = read_u64!(block, b"tombstone_count", &cmp); + let data_block_count = read_u64!(block, b"block_count#data", &cmp); + let index_block_count = read_u64!(block, b"block_count#index", &cmp); + let _filter_block_count = read_u64!(block, b"block_count#filter", &cmp); + let file_size = read_u64!(block, b"file_size", &cmp); + let weak_tombstone_count = read_u64!(block, b"weak_tombstone_count", &cmp); + let weak_tombstone_reclaimable = read_u64!(block, b"weak_tombstone_reclaimable", &cmp); let created_at = { let bytes = block - .point_read(b"created_at", SeqNo::MAX) + .point_read(b"created_at", SeqNo::MAX, &cmp) .expect("created_at timestamp should exist"); let mut bytes = &bytes.value[..];
@@ -179,11 +182,11 @@ let key_range = KeyRange::new(( block - .point_read(b"key#min", SeqNo::MAX) + .point_read(b"key#min", SeqNo::MAX, &cmp) .expect("key min should exist") .value, block - .point_read(b"key#max", SeqNo::MAX) + .point_read(b"key#max", SeqNo::MAX, &cmp) .expect("key max should exist") .value, ));
@@ -191,7 +194,7 @@ let seqnos = { let min = { let bytes = block - .point_read(b"seqno#min", SeqNo::MAX) + .point_read(b"seqno#min", SeqNo::MAX, &cmp) .expect("seqno min should exist") .value; let mut bytes = &bytes[..];
@@ -200,7 +203,7 @@ let max = { let bytes = block - .point_read(b"seqno#max", SeqNo::MAX) + .point_read(b"seqno#max", SeqNo::MAX, &cmp) .expect("seqno max should exist") .value; let mut bytes = &bytes[..];
@@ -217,16 +220,17 @@ // optimization for legacy tables — correct but not optimal). // If the key exists but is truncated, propagate the I/O error to // surface metadata corruption rather than silently falling back. - let highest_kv_seqno = if let Some(item) = block.point_read(b"seqno#kv_max", SeqNo::MAX) { - let mut bytes = &item.value[..]; - validated_kv_seqno(bytes.read_u64::<LittleEndian>()?, seqnos.1)? - } else { - seqnos.1 - }; + let highest_kv_seqno = + if let Some(item) = block.point_read(b"seqno#kv_max", SeqNo::MAX, &cmp) { + let mut bytes = &item.value[..]; + validated_kv_seqno(bytes.read_u64::<LittleEndian>()?, seqnos.1)? + } else { + seqnos.1 + }; let data_block_compression = { let bytes = block - .point_read(b"compression#data", SeqNo::MAX) + .point_read(b"compression#data", SeqNo::MAX, &cmp) .expect("size should exist"); let mut bytes = &bytes.value[..];
@@ -235,7 +239,7 @@ let index_block_compression = { let bytes = block - .point_read(b"compression#index", SeqNo::MAX) + .point_read(b"compression#index", SeqNo::MAX, &cmp) .expect("size should exist"); let mut bytes = &bytes.value[..];
diff --git a/src/table/mod.rs b/src/table/mod.rs index 09969e3c5..9d18cb4f2 100644 --- a/src/table/mod.rs +++ b/src/table/mod.rs
@@ -29,6 +29,7 @@ pub use writer::Writer; use crate::{ cache::Cache, + comparator::SharedComparator, descriptor_table::DescriptorTable, file_accessor::FileAccessor, range_tombstone::RangeTombstone,
@@ -246,7 +247,7 @@ impl Table { let filter_block = if let Some(block) = &self.pinned_filter_block { Some(Cow::Borrowed(block)) } else if let Some(filter_idx) = &self.pinned_filter_index { - let mut iter = filter_idx.iter(); + let mut iter = filter_idx.iter(self.comparator.clone()); iter.seek(key, seqno); if let Some(filter_block_handle) = iter.next() {
@@ -325,13 +326,13 @@ impl Table { let block = self.load_data_block(block_handle.as_ref())?; - if let Some(item) = block.point_read(key, seqno) { + if let Some(item) = block.point_read(key, seqno, &self.comparator) { return Ok(Some(item)); } // NOTE: If the last block key is higher than ours, // our key cannot be in the next block - if block_handle.end_key() > &key { + if self.comparator.compare(block_handle.end_key(), key) == std::cmp::Ordering::Greater { return Ok(None); } }
@@ -367,6 +368,7 @@ impl Table { block_count, self.metadata.data_block_compression, self.global_seqno(), + self.comparator.clone(), ) }
@@ -402,6 +404,7 @@ impl Table { self.file_accessor.clone(), self.cache.clone(), self.metadata.data_block_compression, + self.comparator.clone(), #[cfg(feature = "metrics")] self.metrics.clone(), );
@@ -455,6 +458,7 @@ impl Table { descriptor_table: Option<Arc<DescriptorTable>>, pin_filter: bool, pin_index: bool, + comparator: SharedComparator, #[cfg(feature = "metrics")] metrics: Arc<Metrics>, ) -> crate::Result<Self> { use meta::ParsedMeta;
@@ -499,6 +503,7 @@ path: Arc::clone(&file_path), file_accessor: file_accessor.clone(), table_id: (tree_id, metadata.id).into(), + comparator: comparator.clone(), #[cfg(feature = "metrics")] metrics: metrics.clone(),
@@ -510,7 +515,7 @@ ); let block = Self::read_tli(&regions, &file, metadata.index_block_compression)?; - BlockIndexImpl::Full(FullBlockIndex::new(block)) + BlockIndexImpl::Full(FullBlockIndex::new(block, comparator.clone())) } else { log::trace!("Creating volatile, full block index");
@@ -521,6 +526,7 @@ handle: regions.tli, path: Arc::clone(&file_path), table_id: (tree_id, metadata.id).into(), + comparator: comparator.clone(), #[cfg(feature = "metrics")] metrics: metrics.clone(),
@@ -579,13 +585,16 @@ ))); } - let mut rts = Self::decode_range_tombstones(&block)?; - // Sort range tombstones by (start asc, seqno desc) to enable - // binary search in point-read suppression paths. Uses explicit - // comparator so the partition_point invariant is independent of - // Ord changes.
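The same (start asc, seqno desc) order that the replacement below establishes, sketched with plain byte comparison on made-up pairs (the actual change routes the start-key comparison through the user comparator):

fn main() {
    // (start, seqno) pairs in arbitrary order.
    let mut rts: Vec<(&[u8], u64)> = vec![(b"k", 3), (b"a", 1), (b"a", 7)];

    // Start keys ascending; among equal starts, higher seqno first, so the
    // newest tombstone for a given start key is considered first.
    rts.sort_unstable_by(|a, b| a.0.cmp(b.0).then_with(|| b.1.cmp(&a.1)));

    assert_eq!(
        vec![(b"a".as_slice(), 7), (b"a".as_slice(), 1), (b"k".as_slice(), 3)],
        rts
    );
}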
The seqno-desc tiebreaker ensures higher-seqno - // RTs are checked first when multiple share the same start key. - rts.sort_unstable_by(|a, b| a.start.cmp(&b.start).then_with(|| b.seqno.cmp(&a.seqno))); + let mut rts = Self::decode_range_tombstones(&block, comparator.as_ref())?; + // Sort range tombstones by (start asc, seqno desc) using the + // user comparator so the order matches the tree's key ordering. + // The seqno-desc tiebreaker ensures higher-seqno RTs are checked + // first when multiple share the same start key. + let cmp = &comparator; + rts.sort_unstable_by(|a, b| { + cmp.compare(&a.start, &b.start) + .then_with(|| b.seqno.cmp(&a.seqno)) + }); rts } else { Vec::new() @@ -619,6 +628,8 @@ impl Table { checksum, global_seqno, + comparator, + #[cfg(feature = "metrics")] metrics, @@ -668,7 +679,10 @@ impl Table { clippy::cast_possible_truncation, reason = "block sizes are bounded well within usize on all supported platforms" )] - fn decode_range_tombstones(block: &Block) -> crate::Result> { + fn decode_range_tombstones( + block: &Block, + comparator: &dyn crate::comparator::UserComparator, + ) -> crate::Result> { use byteorder::{ReadBytesExt, LE}; use std::io::Cursor; @@ -748,8 +762,9 @@ impl Table { let start = UserKey::from(start_buf); let end = UserKey::from(end_buf); - // Validate invariant: start < end (reject corrupted data) - if start >= end { + // Validate invariant: start < end using the tree's comparator + // (reject corrupted or misordered intervals) + if comparator.compare(&start, &end) != std::cmp::Ordering::Less { log::error!("Range tombstone block: invalid interval (start >= end)"); return Err(crate::Error::RangeTombstoneDecode { field: "interval", diff --git a/src/table/scanner.rs b/src/table/scanner.rs index dc46858b4..5358cf13a 100644 --- a/src/table/scanner.rs +++ b/src/table/scanner.rs @@ -4,6 +4,7 @@ use super::{Block, DataBlock}; use crate::{ + comparator::SharedComparator, table::{block::BlockType, iter::OwnedDataBlockIter}, CompressionType, InternalValue, SeqNo, }; @@ -19,6 +20,7 @@ pub struct Scanner { read_count: usize, global_seqno: SeqNo, + comparator: SharedComparator, } impl Scanner { @@ -27,12 +29,14 @@ impl Scanner { block_count: usize, compression: CompressionType, global_seqno: SeqNo, + comparator: SharedComparator, ) -> crate::Result { // TODO: a larger buffer size may be better for HDD, maybe make this configurable let mut reader = BufReader::with_capacity(8 * 4_096, File::open(path)?); let block = Self::fetch_next_block(&mut reader, compression)?; - let iter = OwnedDataBlockIter::new(block, DataBlock::iter); + let cmp = comparator.clone(); + let iter = OwnedDataBlockIter::new(block, |b| b.iter(cmp)); Ok(Self { reader, @@ -43,6 +47,7 @@ impl Scanner { read_count: 1, global_seqno, + comparator, }) } @@ -84,7 +89,8 @@ impl Iterator for Scanner { // Init new block let block = fail_iter!(Self::fetch_next_block(&mut self.reader, self.compression)); - self.iter = OwnedDataBlockIter::new(block, DataBlock::iter); + let cmp = self.comparator.clone(); + self.iter = OwnedDataBlockIter::new(block, |b| b.iter(cmp)); self.read_count += 1; } diff --git a/src/table/tests.rs b/src/table/tests.rs index dc8bff2fc..d609487c5 100644 --- a/src/table/tests.rs +++ b/src/table/tests.rs @@ -54,6 +54,7 @@ fn test_with_table( Some(Arc::new(DescriptorTable::new(10))), false, false, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?; @@ -84,6 +85,7 @@ fn test_with_table( Some(Arc::new(DescriptorTable::new(10))), true, false, + 
crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -114,6 +116,7 @@ fn test_with_table( Some(Arc::new(DescriptorTable::new(10))), false, true, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -144,6 +147,7 @@ fn test_with_table( Some(Arc::new(DescriptorTable::new(10))), true, true, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -174,6 +178,7 @@ fn test_with_table( None, true, true, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -222,6 +227,7 @@ fn test_with_table( Some(Arc::new(DescriptorTable::new(10))), false, false, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -251,6 +257,7 @@ fn test_with_table( Some(Arc::new(DescriptorTable::new(10))), true, false, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -280,6 +287,7 @@ fn test_with_table( Some(Arc::new(DescriptorTable::new(10))), false, true, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -310,6 +318,7 @@ fn test_with_table( Some(Arc::new(DescriptorTable::new(10))), true, true, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -340,6 +349,7 @@ fn test_with_table( None, true, true, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics, )?;
@@ -1225,6 +1235,7 @@ fn table_read_fuzz_1() -> crate::Result<()> { Some(Arc::new(crate::DescriptorTable::new(10))), true, true, + crate::comparator::default_comparator(), ) .unwrap();
@@ -1299,6 +1310,7 @@ fn table_partitioned_index() -> crate::Result<()> { Some(Arc::new(crate::DescriptorTable::new(10))), true, true, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] Default::default(), )
@@ -1409,6 +1421,7 @@ fn table_global_seqno() -> crate::Result<()> { Some(Arc::new(crate::DescriptorTable::new(10))), true, true, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] Default::default(), )
@@ -1452,7 +1465,9 @@ fn rt_block(data: Vec<u8>) -> Block { /// with the given field and expected byte offset. fn assert_rt_decode_error(data: Vec<u8>, expected_field: &str, expected_offset: u64) { let block = rt_block(data); - match Table::decode_range_tombstones(&block) { + // Uses DefaultUserComparator: tests verify structural decode errors + // (truncation, missing fields), not comparator-dependent ordering. + match Table::decode_range_tombstones(&block, &crate::comparator::DefaultUserComparator) { Err(crate::Error::RangeTombstoneDecode { field, offset }) => { assert_eq!( field, expected_field,
@@ -1612,6 +1627,7 @@ fn load_block_range_tombstone_metrics() -> crate::Result<()> { Some(Arc::new(DescriptorTable::new(10))), false, false, + crate::comparator::default_comparator(), #[cfg(feature = "metrics")] metrics.clone(), )?;
diff --git a/src/table/util.rs b/src/table/util.rs index a558e9c82..4e12644c0 100644 --- a/src/table/util.rs +++ b/src/table/util.rs
@@ -146,8 +146,61 @@ pub fn longest_shared_prefix_length(s1: &[u8], s2: &[u8]) -> usize { .count() } +/// Compares the conceptual concatenation `prefix + suffix` against `needle` +/// using the given comparator. +/// +/// For the default lexicographic comparator this performs a zero-allocation +/// bytewise comparison. Custom comparators fall back to materializing +/// `prefix + suffix` into a temporary buffer (stack-allocated for keys up to +/// 256 bytes, a heap `Vec` above that) so that `UserComparator::compare` +/// always receives a complete key.
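Before the implementation, the contract in miniature. These assertions mirror cases from the unit tests further down and assume the surrounding test module's imports; `DefaultUserComparator` is the crate's lexicographic comparator, and the test name is hypothetical:

#[test]
fn compare_prefixed_slice_contract() {
    use std::cmp::Ordering::{Equal, Greater, Less};

    // `prefix + suffix` is compared against `needle` as one conceptual key.
    assert_eq!(
        Equal,
        compare_prefixed_slice(b"abc", b"xyz", b"abcxyz", &DefaultUserComparator)
    );
    assert_eq!(
        Less,
        compare_prefixed_slice(b"ab", b"", b"ac", &DefaultUserComparator)
    );
    assert_eq!(
        Greater,
        compare_prefixed_slice(b"abc", b"d", b"abc", &DefaultUserComparator)
    );
}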
+#[must_use] +pub fn compare_prefixed_slice( + prefix: &[u8], + suffix: &[u8], + needle: &[u8], + cmp: &dyn crate::comparator::UserComparator, +) -> std::cmp::Ordering { + // Fast path: zero-allocation bytewise comparison for the default + // (lexicographic) comparator. This is the hot path for block index + // and data block binary searches. + if cmp.is_lexicographic() { + return compare_prefixed_slice_lexicographic(prefix, suffix, needle); + } + + // Slow path: materialize prefix+suffix into a contiguous buffer for + // custom comparators. Uses a stack buffer for typical key sizes to + // avoid heap allocation on the hot binary-search path. + let total_len = prefix.len() + suffix.len(); + + if total_len <= 256 { + let mut buf = [0_u8; 256]; + + // SAFETY (indexing): total_len <= 256 == buf.len(), and + // prefix.len() + suffix.len() == total_len, so all slices are in bounds. + #[expect(clippy::indexing_slicing, reason = "total_len <= 256 checked above")] + { + buf[..prefix.len()].copy_from_slice(prefix); + buf[prefix.len()..total_len].copy_from_slice(suffix); + } + + #[expect(clippy::indexing_slicing, reason = "total_len <= 256 checked above")] + return cmp.compare(&buf[..total_len], needle); + } + + // Fallback for unusually large keys: allocate a temporary Vec. + let mut full_key = Vec::with_capacity(total_len); + full_key.extend_from_slice(prefix); + full_key.extend_from_slice(suffix); + cmp.compare(&full_key, needle) +} + +/// Zero-allocation lexicographic comparison of `prefix + suffix` against `needle`. #[must_use] -pub fn compare_prefixed_slice(prefix: &[u8], suffix: &[u8], needle: &[u8]) -> std::cmp::Ordering { +fn compare_prefixed_slice_lexicographic( + prefix: &[u8], + suffix: &[u8], + needle: &[u8], +) -> std::cmp::Ordering { use std::cmp::Ordering::{Equal, Greater}; if needle.is_empty() { @@ -158,13 +211,21 @@ pub fn compare_prefixed_slice(prefix: &[u8], suffix: &[u8], needle: &[u8]) -> st let max_pfx_len = prefix.len().min(needle.len()); { - #[expect(unsafe_code, reason = "We checked for max_pfx_len")] - let prefix = unsafe { prefix.get_unchecked(0..max_pfx_len) }; - - #[expect(unsafe_code, reason = "We checked for max_pfx_len")] - let needle = unsafe { needle.get_unchecked(0..max_pfx_len) }; - - match prefix.cmp(needle) { + // SAFETY: max_pfx_len = min(prefix.len(), needle.len()), so both + // slices [0..max_pfx_len] are within bounds by construction. + #[expect( + unsafe_code, + reason = "max_pfx_len <= prefix.len() && max_pfx_len <= needle.len()" + )] + let pfx = unsafe { prefix.get_unchecked(0..max_pfx_len) }; + + #[expect( + unsafe_code, + reason = "max_pfx_len <= prefix.len() && max_pfx_len <= needle.len()" + )] + let ndl = unsafe { needle.get_unchecked(0..max_pfx_len) }; + + match pfx.cmp(ndl) { Equal => {} ordering => return ordering, } @@ -175,17 +236,20 @@ pub fn compare_prefixed_slice(prefix: &[u8], suffix: &[u8], needle: &[u8]) -> st return Greater; } + // SAFETY: rest_len == 0 means prefix.len() <= needle.len(), so + // max_pfx_len == prefix.len() <= needle.len() and needle[max_pfx_len..] is in-bounds. #[expect( unsafe_code, - reason = "We know that the prefix is definitely not longer than the needle so we can safely truncate" + reason = "max_pfx_len <= needle.len() guaranteed by rest_len == 0 guard above" )] - let needle = unsafe { needle.get_unchecked(max_pfx_len..) }; - suffix.cmp(needle) + let remaining_needle = unsafe { needle.get_unchecked(max_pfx_len..) 
}; + suffix.cmp(remaining_needle) } #[cfg(test)] mod tests { use super::*; + use crate::comparator::DefaultUserComparator; use test_log::test; #[test] @@ -205,39 +269,147 @@ mod tests { fn test_compare_prefixed_slice() { use std::cmp::Ordering::{Equal, Greater, Less}; - assert_eq!(Greater, compare_prefixed_slice(&[0, 161], &[], &[0])); - - assert_eq!(Equal, compare_prefixed_slice(b"abc", b"xyz", b"abcxyz")); - assert_eq!(Equal, compare_prefixed_slice(b"abc", b"", b"abc")); - assert_eq!(Equal, compare_prefixed_slice(b"abc", b"abc", b"abcabc")); - assert_eq!(Equal, compare_prefixed_slice(b"", b"", b"")); - assert_eq!(Less, compare_prefixed_slice(b"a", b"", b"y")); - assert_eq!(Less, compare_prefixed_slice(b"a", b"", b"yyy")); - assert_eq!(Less, compare_prefixed_slice(b"a", b"", b"yyy")); - assert_eq!(Less, compare_prefixed_slice(b"yyyy", b"a", b"yyyyb")); - assert_eq!(Less, compare_prefixed_slice(b"yyy", b"b", b"yyyyb")); - assert_eq!(Less, compare_prefixed_slice(b"abc", b"d", b"abce")); - assert_eq!(Less, compare_prefixed_slice(b"ab", b"", b"ac")); - assert_eq!(Greater, compare_prefixed_slice(b"a", b"", b"")); - assert_eq!(Greater, compare_prefixed_slice(b"", b"a", b"")); - assert_eq!(Greater, compare_prefixed_slice(b"a", b"a", b"")); - assert_eq!(Greater, compare_prefixed_slice(b"b", b"a", b"a")); - assert_eq!(Greater, compare_prefixed_slice(b"a", b"b", b"a")); - assert_eq!(Greater, compare_prefixed_slice(b"abc", b"xy", b"abcw")); - assert_eq!(Greater, compare_prefixed_slice(b"ab", b"cde", b"a")); - assert_eq!(Greater, compare_prefixed_slice(b"abcd", b"zz", b"abc")); - assert_eq!(Greater, compare_prefixed_slice(b"abc", b"d", b"abc")); assert_eq!( Greater, - compare_prefixed_slice(b"aaaa", b"aaab", b"aaaaaaaa") + compare_prefixed_slice(&[0, 161], &[], &[0], &DefaultUserComparator) + ); + + assert_eq!( + Equal, + compare_prefixed_slice(b"abc", b"xyz", b"abcxyz", &DefaultUserComparator) + ); + assert_eq!( + Equal, + compare_prefixed_slice(b"abc", b"", b"abc", &DefaultUserComparator) + ); + assert_eq!( + Equal, + compare_prefixed_slice(b"abc", b"abc", b"abcabc", &DefaultUserComparator) + ); + assert_eq!( + Equal, + compare_prefixed_slice(b"", b"", b"", &DefaultUserComparator) + ); + assert_eq!( + Less, + compare_prefixed_slice(b"a", b"", b"y", &DefaultUserComparator) + ); + assert_eq!( + Less, + compare_prefixed_slice(b"a", b"", b"yyy", &DefaultUserComparator) + ); + assert_eq!( + Less, + compare_prefixed_slice(b"a", b"", b"yyy", &DefaultUserComparator) + ); + assert_eq!( + Less, + compare_prefixed_slice(b"yyyy", b"a", b"yyyyb", &DefaultUserComparator) + ); + assert_eq!( + Less, + compare_prefixed_slice(b"yyy", b"b", b"yyyyb", &DefaultUserComparator) + ); + assert_eq!( + Less, + compare_prefixed_slice(b"abc", b"d", b"abce", &DefaultUserComparator) + ); + assert_eq!( + Less, + compare_prefixed_slice(b"ab", b"", b"ac", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"a", b"", b"", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"", b"a", b"", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"a", b"a", b"", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"b", b"a", b"a", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"a", b"b", b"a", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"abc", b"xy", b"abcw", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"ab", b"cde", b"a", 
&DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"abcd", b"zz", b"abc", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"abc", b"d", b"abc", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"aaaa", b"aaab", b"aaaaaaaa", &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(b"aaaa", b"aaba", b"aaaaaaaa", &DefaultUserComparator) ); assert_eq!( Greater, - compare_prefixed_slice(b"aaaa", b"aaba", b"aaaaaaaa") + compare_prefixed_slice(b"abcd", b"x", b"abc", &DefaultUserComparator) ); - assert_eq!(Greater, compare_prefixed_slice(b"abcd", b"x", b"abc")); - assert_eq!(Less, compare_prefixed_slice(&[0x7F], &[], &[0x80])); - assert_eq!(Greater, compare_prefixed_slice(&[0xFF], &[], &[0x10])); + assert_eq!( + Less, + compare_prefixed_slice(&[0x7F], &[], &[0x80], &DefaultUserComparator) + ); + assert_eq!( + Greater, + compare_prefixed_slice(&[0xFF], &[], &[0x10], &DefaultUserComparator) + ); + }
+
+ /// Reverse comparator to exercise the non-lexicographic slow path
+ /// (these short keys hit the stack-buffer branch, not the `Vec` fallback).
+ struct ReverseComparator;
+ impl crate::comparator::UserComparator for ReverseComparator { + fn compare(&self, a: &[u8], b: &[u8]) -> std::cmp::Ordering { + b.cmp(a) + } + }
+
+ #[test] + fn test_compare_prefixed_slice_custom_comparator() { + use std::cmp::Ordering::{Equal, Greater, Less};
+
+ // With reverse comparator, "abc" > "xyz" (reversed) + assert_eq!( + Greater, + compare_prefixed_slice(b"ab", b"c", b"xyz", &ReverseComparator) + ); + assert_eq!( + Less, + compare_prefixed_slice(b"xy", b"z", b"abc", &ReverseComparator) + ); + assert_eq!( + Equal, + compare_prefixed_slice(b"ab", b"c", b"abc", &ReverseComparator) + ); + // Empty cases + assert_eq!( + Equal, + compare_prefixed_slice(b"", b"", b"", &ReverseComparator) + ); + assert_eq!( + Less, // "a" > "" bytewise, so the reversed order yields Less + compare_prefixed_slice(b"a", b"", b"", &ReverseComparator) + ); + } }
diff --git a/src/tree/ingest.rs b/src/tree/ingest.rs index 09ebe8d41..a5b016771 100644 --- a/src/tree/ingest.rs +++ b/src/tree/ingest.rs
@@ -302,6 +302,7 @@ impl<'a> Ingestion<'a> { self.tree.config.descriptor_table.clone(), false, false, + self.tree.config.comparator.clone(), #[cfg(feature = "metrics")] self.tree.metrics.clone(), )
diff --git a/src/tree/inner.rs b/src/tree/inner.rs index b186dd0fb..8bee97659 100644 --- a/src/tree/inner.rs +++ b/src/tree/inner.rs
@@ -82,13 +82,15 @@ impl TreeInner { ); persist_version(&config.path, &version)?; + let comparator = config.comparator.clone(); + Ok(Self { id: get_next_tree_id(), memtable_id_counter: SequenceNumberCounter::new(1), table_id_counter: SequenceNumberCounter::default(), blob_file_id_counter: SequenceNumberCounter::default(), config: Arc::new(config), - version_history: Arc::new(RwLock::new(SuperVersions::new(version))), + version_history: Arc::new(RwLock::new(SuperVersions::new(version, comparator))), stop_signal: StopSignal::default(), major_compaction_lock: RwLock::default(), flush_lock: Mutex::default(),
diff --git a/src/tree/mod.rs b/src/tree/mod.rs index 5b9b58338..03bf784dd 100644 --- a/src/tree/mod.rs +++ b/src/tree/mod.rs
@@ -123,15 +123,18 @@ impl AbstractTree for Tree { let key = Slice::from(key); - for kv in super_version - .active_memtable - .range(InternalKey::new(key.clone(), SeqNo::MAX, ValueType::Value)..)
- { + for kv in super_version.active_memtable.range_internal(( + Bound::Included(InternalKey::new(key.clone(), SeqNo::MAX, ValueType::Value)), + Bound::Unbounded, + )) { log::info!("[Active] {kv:?}"); } for mt in super_version.sealed_memtables.iter().rev() { - for kv in mt.range(InternalKey::new(key.clone(), SeqNo::MAX, ValueType::Value)..) { + for kv in mt.range_internal(( + Bound::Included(InternalKey::new(key.clone(), SeqNo::MAX, ValueType::Value)), + Bound::Unbounded, + )) { log::info!("[Sealed #{}] {kv:?}", mt.id()); } } @@ -140,7 +143,7 @@ impl AbstractTree for Tree { .version .iter_levels() .flat_map(|lvl| lvl.iter()) - .filter_map(|run| run.get_for_key(&key)) + .filter_map(|run| run.get_for_key_cmp(&key, self.config.comparator.as_ref())) { for kv in table.range(..) { let kv = kv?; @@ -164,7 +167,12 @@ impl AbstractTree for Tree { .expect("lock is poisoned") .get_version_for_snapshot(seqno); - Self::get_internal_entry_from_version(&super_version, key, seqno) + Self::get_internal_entry_from_version( + &super_version, + key, + seqno, + self.config.comparator.as_ref(), + ) } fn current_version(&self) -> Version { @@ -270,7 +278,10 @@ impl AbstractTree for Tree { &self.config.path, |v| { let mut copy = v.clone(); - copy.active_memtable = Arc::new(Memtable::new(self.memtable_id_counter.next())); + copy.active_memtable = Arc::new(Memtable::new( + self.memtable_id_counter.next(), + self.config.comparator.clone(), + )); copy.sealed_memtables = Arc::default(); copy.version = Version::new(v.version.id() + 1, self.tree_type()); Ok(copy) @@ -430,6 +441,7 @@ impl AbstractTree for Tree { self.config.descriptor_table.clone(), pin_filter, pin_index, + self.config.comparator.clone(), #[cfg(feature = "metrics")] self.metrics.clone(), ) @@ -503,7 +515,10 @@ impl AbstractTree for Tree { } let mut copy = version_history_lock.latest_version(); - copy.active_memtable = Arc::new(Memtable::new(self.memtable_id_counter.next())); + copy.active_memtable = Arc::new(Memtable::new( + self.memtable_id_counter.next(), + self.config.comparator.clone(), + )); copy.sealed_memtables = Arc::new(SealedMemtables::default()); // Rotate does not modify the memtable, so it cannot break snapshots @@ -566,7 +581,10 @@ impl AbstractTree for Tree { let yanked_memtable = super_version.active_memtable; let mut copy = version_history_lock.latest_version(); - copy.active_memtable = Arc::new(Memtable::new(self.memtable_id_counter.next())); + copy.active_memtable = Arc::new(Memtable::new( + self.memtable_id_counter.next(), + self.config.comparator.clone(), + )); copy.sealed_memtables = Arc::new(super_version.sealed_memtables.add(yanked_memtable.clone())); @@ -667,6 +685,7 @@ impl AbstractTree for Tree { key, seqno, self.config.merge_operator.as_ref(), + self.config.comparator.as_ref(), ) } @@ -684,6 +703,7 @@ impl AbstractTree for Tree { key.as_ref(), seqno, self.config.merge_operator.as_ref(), + self.config.comparator.as_ref(), ) }) .collect() @@ -742,8 +762,9 @@ impl Tree { key: &[u8], seqno: SeqNo, merge_operator: Option<&Arc>, + comparator: &dyn crate::comparator::UserComparator, ) -> crate::Result> { - let entry = Self::get_internal_entry_from_version(super_version, key, seqno)?; + let entry = Self::get_internal_entry_from_version(super_version, key, seqno, comparator)?; match entry { Some(entry) if entry.key.value_type == ValueType::MergeOperand => { @@ -762,6 +783,7 @@ impl Tree { key, entry.key.seqno, seqno, + comparator, ) { Ok(None) } else { @@ -793,11 +815,13 @@ impl Tree { let key_hash = 
crate::table::filter::standard_bloom::Builder::get_hash(key); let key_slice = crate::Slice::from(key); let range = key_slice.clone()..=key_slice; + let comparator = version.active_memtable.comparator.clone(); let iter_state = IterState { version, ephemeral: None, merge_operator: Some(merge_operator), + comparator, prefix_hash: None, key_hash: Some(key_hash), #[cfg(feature = "metrics")] @@ -823,6 +847,7 @@ impl Tree { seqno: SeqNo, ephemeral: Option<(Arc, SeqNo)>, merge_operator: Option>, + comparator: crate::comparator::SharedComparator, ) -> impl DoubleEndedIterator> + 'static { Self::create_internal_range_with_prefix_hash( version, @@ -830,6 +855,7 @@ impl Tree { seqno, ephemeral, merge_operator, + comparator, None, ) } @@ -847,6 +873,7 @@ impl Tree { seqno: SeqNo, ephemeral: Option<(Arc, SeqNo)>, merge_operator: Option>, + comparator: crate::comparator::SharedComparator, prefix_hash: Option, ) -> impl DoubleEndedIterator> + 'static { use crate::range::{IterState, TreeIter}; @@ -870,6 +897,7 @@ impl Tree { version, ephemeral, merge_operator, + comparator, prefix_hash, key_hash: None, #[cfg(feature = "metrics")] @@ -883,6 +911,7 @@ impl Tree { super_version: &SuperVersion, key: &[u8], seqno: SeqNo, + comparator: &dyn crate::comparator::UserComparator, ) -> crate::Result> { // Search order: active → sealed → SST (newest first). A point // tombstone in a newer source is authoritative — no older source @@ -893,7 +922,13 @@ impl Tree { }; // Check if any range tombstone suppresses this entry - if Self::is_suppressed_by_range_tombstones(super_version, key, entry.key.seqno, seqno) { + if Self::is_suppressed_by_range_tombstones( + super_version, + key, + entry.key.seqno, + seqno, + comparator, + ) { return Ok(None); } return Ok(Some(entry)); @@ -907,17 +942,30 @@ impl Tree { return Ok(None); }; - if Self::is_suppressed_by_range_tombstones(super_version, key, entry.key.seqno, seqno) { + if Self::is_suppressed_by_range_tombstones( + super_version, + key, + entry.key.seqno, + seqno, + comparator, + ) { return Ok(None); } return Ok(Some(entry)); } // Now look in tables... this may involve disk I/O - let entry = Self::get_internal_entry_from_tables(&super_version.version, key, seqno)?; + let entry = + Self::get_internal_entry_from_tables(&super_version.version, key, seqno, comparator)?; if let Some(entry) = entry { - if Self::is_suppressed_by_range_tombstones(super_version, key, entry.key.seqno, seqno) { + if Self::is_suppressed_by_range_tombstones( + super_version, + key, + entry.key.seqno, + seqno, + comparator, + ) { return Ok(None); } return Ok(Some(entry)); @@ -933,6 +981,7 @@ impl Tree { key: &[u8], key_seqno: SeqNo, read_seqno: SeqNo, + comparator: &dyn crate::comparator::UserComparator, ) -> bool { // Check active memtable range tombstones. // Future optimization: skip lock when memtable has no RTs (atomic count). @@ -952,29 +1001,38 @@ impl Tree { // Check SST table range tombstones. // - // Flush/RT-only writes widen persisted table key ranges to include RT - // coverage, and compaction either clips RTs to the output table range - // or widens metadata in the inclusive-upper-bound fallback. That makes - // `metadata.key_range.contains_key(key)` a sound early reject here and - // avoids scanning RT blocks for unrelated SSTs on point reads. - // - // Per-table RT lists are sorted by start key on load, + // Per-table RT lists are sorted by start key (using comparator) on load, // so binary search narrows candidates to RTs with start <= key. 
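The narrowing step in isolation, with plain byte ordering standing in for the user comparator (a sketch over made-up data, not the crate's RT types):

use std::cmp::Ordering;

fn main() {
    // (start, seqno), sorted by (start asc, seqno desc) as on table load.
    let rts: [(&[u8], u64); 4] = [(b"a", 9), (b"a", 4), (b"k", 7), (b"t", 2)];
    let key: &[u8] = b"m";

    // First index whose start is > key; everything before it has start <= key
    // and remains a suppression candidate.
    let candidate_end = rts.partition_point(|rt| rt.0.cmp(key) != Ordering::Greater);

    assert_eq!(3, candidate_end); // both "a" entries and "k" stay; "t" is cut off
}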
+ // The key_range early reject uses the comparator so it works with + // non-lexicographic orderings. for table in super_version .version .iter_levels() .flat_map(|lvl| lvl.iter()) .flat_map(|run| run.iter()) .filter(|t| !t.range_tombstones().is_empty()) - .filter(|t| t.metadata.key_range.contains_key(key)) + .filter(|t| { + // Early reject: skip tables whose key range doesn't contain the key. + let kr = &t.metadata.key_range; + comparator.compare(kr.min(), key) != std::cmp::Ordering::Greater + && comparator.compare(key, kr.max()) != std::cmp::Ordering::Greater + }) { let rts = table.range_tombstones(); - let candidate_end = rts.partition_point(|rt| rt.start.as_ref() <= key); + + // Binary search: find the first RT whose start is > key (in comparator order). + // All RTs before that index have start <= key and are candidates. + let candidate_end = rts.partition_point(|rt| { + comparator.compare(&rt.start, key) != std::cmp::Ordering::Greater + }); for rt in rts.iter().take(candidate_end) { - // Binary search already narrowed to start <= key; should_suppress - // re-checks contains_key (harmless) and avoids semantic drift. - if rt.should_suppress(key, key_seqno, read_seqno) { + // Check: start <= key < end (in comparator order) AND seqno visibility. + if rt.visible_at(read_seqno) + && comparator.compare(&rt.start, key) != std::cmp::Ordering::Greater + && comparator.compare(key, &rt.end) == std::cmp::Ordering::Less + && key_seqno < rt.seqno + { return true; } } @@ -987,6 +1045,7 @@ impl Tree { version: &Version, key: &[u8], seqno: SeqNo, + comparator: &dyn crate::comparator::UserComparator, ) -> crate::Result> { // NOTE: Create key hash for hash sharing // https://fjall-rs.github.io/post/bloom-filter-hash-sharing/ @@ -1006,7 +1065,7 @@ impl Tree { let mut best: Option = None; for run in level.iter() { - if let Some(table) = run.get_for_key(key) { + if let Some(table) = run.get_for_key_cmp(key, comparator) { if let Some(item) = table.get(key, seqno, key_hash)? { match &best { // >= keeps first-seen on tie. Seqno is monotonically @@ -1031,7 +1090,7 @@ impl Tree { } } else { for run in level.iter() { - if let Some(table) = run.get_for_key(key) { + if let Some(table) = run.get_for_key_cmp(key, comparator) { if let Some(item) = table.get(key, seqno, key_hash)? 
{ return Ok(ignore_tombstone_value(item)); } @@ -1192,6 +1251,7 @@ impl Tree { seqno, ephemeral, self.config.merge_operator.clone(), + self.config.comparator.clone(), ) .map(|item| match item { Ok(kv) => Ok((kv.key.user_key, kv.value)), @@ -1226,6 +1286,7 @@ impl Tree { version: super_version, ephemeral, merge_operator: self.config.merge_operator.clone(), + comparator: self.config.comparator.clone(), prefix_hash, key_hash: None, #[cfg(feature = "metrics")] @@ -1309,12 +1370,14 @@ impl Tree { .max() .unwrap_or_default(); + let comparator = config.comparator.clone(); + let inner = TreeInner { id: tree_id, memtable_id_counter: SequenceNumberCounter::new(1), table_id_counter: SequenceNumberCounter::new(highest_table_id + 1), blob_file_id_counter: SequenceNumberCounter::default(), - version_history: Arc::new(RwLock::new(SuperVersions::new(version))), + version_history: Arc::new(RwLock::new(SuperVersions::new(version, comparator))), stop_signal: StopSignal::default(), config: Arc::new(config), major_compaction_lock: RwLock::default(), @@ -1455,6 +1518,7 @@ impl Tree { config.descriptor_table.clone(), pin_filter, pin_index, + config.comparator.clone(), #[cfg(feature = "metrics")] metrics.clone(), )?; diff --git a/src/version/run.rs b/src/version/run.rs index 6f81a3bc0..06786c330 100644 --- a/src/version/run.rs +++ b/src/version/run.rs @@ -103,6 +103,27 @@ impl Run { self.0.get(idx).filter(|x| x.key_range().min() <= &key) } + /// Like [`get_for_key`], but uses a custom comparator for key ordering. + /// + /// # Precondition (guaranteed by construction) + /// + /// Tables within a run are sorted by `key_range` in comparator order. + /// This holds because tables are flushed from comparator-sorted memtables + /// and compaction preserves the ordering. The binary search here must + /// use the same comparator to maintain the invariant. + pub fn get_for_key_cmp( + &self, + key: &[u8], + cmp: &dyn crate::comparator::UserComparator, + ) -> Option<&T> { + let idx = self + .partition_point(|x| cmp.compare(x.key_range().max(), key) == std::cmp::Ordering::Less); + + self.0 + .get(idx) + .filter(|x| cmp.compare(x.key_range().min(), key) != std::cmp::Ordering::Greater) + } + /// Returns the run's key range. 
+    pub fn get_for_key_cmp(
+        &self,
+        key: &[u8],
+        cmp: &dyn crate::comparator::UserComparator,
+    ) -> Option<&T> {
+        let idx = self
+            .partition_point(|x| cmp.compare(x.key_range().max(), key) == std::cmp::Ordering::Less);
+
+        self.0
+            .get(idx)
+            .filter(|x| cmp.compare(x.key_range().min(), key) != std::cmp::Ordering::Greater)
+    }
+
     /// Returns the run's key range.
     pub fn aggregate_key_range(&self) -> KeyRange {
         #[expect(clippy::expect_used, reason = "by definition, runs are never empty")]
diff --git a/src/version/super_version.rs b/src/version/super_version.rs
index e1b6dd6a5..b0c11350c 100644
--- a/src/version/super_version.rs
+++ b/src/version/super_version.rs
@@ -3,6 +3,7 @@
 // (found in the LICENSE-* files in the repository)

 use crate::{
+    comparator::SharedComparator,
     memtable::Memtable,
     tree::sealed::SealedMemtables,
     version::{persist_version, Version},
@@ -29,10 +30,10 @@ pub struct SuperVersion {
 pub struct SuperVersions(VecDeque<SuperVersion>);

 impl SuperVersions {
-    pub fn new(version: Version) -> Self {
+    pub fn new(version: Version, comparator: SharedComparator) -> Self {
         Self(
             vec![SuperVersion {
-                active_memtable: Arc::new(Memtable::new(0)),
+                active_memtable: Arc::new(Memtable::new(0, comparator)),
                 sealed_memtables: Arc::default(),
                 version,
                 seqno: 0,
@@ -202,26 +203,31 @@ impl SuperVersions {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::comparator::default_comparator;
     use test_log::test;

+    fn new_memtable(id: u64) -> Memtable {
+        Memtable::new(id, default_comparator())
+    }
+
     #[test]
     fn super_version_gc_above_watermark() -> crate::Result<()> {
         let mut history = SuperVersions(
             vec![
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 0,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 1,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 2,
@@ -242,19 +248,19 @@ mod tests {
         let mut history = SuperVersions(
             vec![
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 0,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 1,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 2,
@@ -275,25 +281,25 @@ mod tests {
         let mut history = SuperVersions(
             vec![
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 0,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 1,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 2,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 8,
@@ -314,13 +320,13 @@ mod tests {
         let mut history = SuperVersions(
             vec![
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 0,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 8,
@@ -341,13 +347,13 @@ mod tests {
         let mut history = SuperVersions(
             vec![
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 0,
                 },
                 SuperVersion {
-                    active_memtable: Arc::new(Memtable::new(0)),
+                    active_memtable: Arc::new(new_memtable(0)),
                     sealed_memtables: Arc::default(),
                     version: Version::new(0, crate::TreeType::Standard),
                     seqno: 2,
diff --git a/src/vlog/blob_file/meta.rs b/src/vlog/blob_file/meta.rs
index a1c40de63..18c8e1f18 100644
--- a/src/vlog/blob_file/meta.rs
+++ b/src/vlog/blob_file/meta.rs
@@ -5,6 +5,7 @@
 use crate::{
     checksum::ChecksumType,
     coding::{Decode, Encode},
+    comparator::default_comparator,
     table::{Block, DataBlock},
     vlog::BlobFileId,
     CompressionType, InternalValue, KeyRange, SeqNo, Slice,
@@ -13,9 +14,9 @@
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::io::{Read, Write};

 macro_rules! read_u64 {
-    ($block:expr, $name:expr) => {{
+    ($block:expr, $name:expr, $cmp:expr) => {{
         let bytes = $block
-            .point_read($name, SeqNo::MAX)
+            .point_read($name, SeqNo::MAX, $cmp)
             .ok_or(crate::Error::InvalidHeader("BlobFileMeta"))?;

         let mut bytes = &bytes.value[..];
@@ -24,9 +25,9 @@
 macro_rules! read_u128 {
-    ($block:expr, $name:expr) => {{
+    ($block:expr, $name:expr, $cmp:expr) => {{
         let bytes = $block
-            .point_read($name, SeqNo::MAX)
+            .point_read($name, SeqNo::MAX, $cmp)
             .ok_or(crate::Error::InvalidHeader("BlobFileMeta"))?;

         let mut bytes = &bytes.value[..];
@@ -122,9 +123,12 @@ impl Metadata {
         let block = Block::from_reader(reader, CompressionType::None)?;
         let block = DataBlock::new(block);

+        // Metadata keys are always lexicographic, so use the default comparator.
+        let cmp = default_comparator();
+
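+        // Why not the user comparator: metadata blocks are written by the
+        // library itself with fixed ASCII keys (b"id", b"key#min", ...), so
+        // reads must use the same lexicographic order they were written in;
+        // a custom comparator could order those keys differently and make
+        // point_read miss entries.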
         let version = {
             let bytes = block
-                .point_read(b"blob_file_version", SeqNo::MAX)
+                .point_read(b"blob_file_version", SeqNo::MAX, &cmp)
                 .ok_or(crate::Error::InvalidHeader("BlobFileMeta"))?;
             *bytes
                 .value
@@ -140,15 +144,15 @@ impl Metadata {
             _ => return Err(crate::Error::InvalidHeader("BlobFileMeta")),
         }

-        let id = read_u64!(block, b"id");
-        let created_at = read_u128!(block, b"created_at");
-        let item_count = read_u64!(block, b"item_count");
-        let file_size = read_u64!(block, b"file_size");
-        let total_uncompressed_bytes = read_u64!(block, b"uncompressed_size");
+        let id = read_u64!(block, b"id", &cmp);
+        let created_at = read_u128!(block, b"created_at", &cmp);
+        let item_count = read_u64!(block, b"item_count", &cmp);
+        let file_size = read_u64!(block, b"file_size", &cmp);
+        let total_uncompressed_bytes = read_u64!(block, b"uncompressed_size", &cmp);

         let compression = {
             let bytes = block
-                .point_read(b"compression", SeqNo::MAX)
+                .point_read(b"compression", SeqNo::MAX, &cmp)
                 .ok_or(crate::Error::InvalidHeader("BlobFileMeta"))?;

             let mut bytes = &bytes.value[..];
@@ -157,11 +161,11 @@ impl Metadata {
         let key_range = KeyRange::new((
             block
-                .point_read(b"key#min", SeqNo::MAX)
+                .point_read(b"key#min", SeqNo::MAX, &cmp)
                 .ok_or(crate::Error::InvalidHeader("BlobFileMeta"))?
                 .value,
             block
-                .point_read(b"key#max", SeqNo::MAX)
+                .point_read(b"key#max", SeqNo::MAX, &cmp)
                 .ok_or(crate::Error::InvalidHeader("BlobFileMeta"))?
                 .value,
         ));
diff --git a/tests/custom_comparator.rs b/tests/custom_comparator.rs
new file mode 100644
index 000000000..71ce102ac
--- /dev/null
+++ b/tests/custom_comparator.rs
@@ -0,0 +1,281 @@
+use lsm_tree::{AbstractTree, Config, Guard as _, SharedComparator, UserComparator};
+use std::cmp::Ordering;
+use std::sync::Arc;
+
+/// Comparator that reverses the default lexicographic byte ordering.
+struct ReverseComparator;
+
+impl UserComparator for ReverseComparator {
+    fn compare(&self, a: &[u8], b: &[u8]) -> Ordering {
+        b.cmp(a) // reversed
+    }
+}
+
+/// Comparator that orders u64 keys stored as big-endian bytes.
+struct U64BigEndianComparator;
+
+impl UserComparator for U64BigEndianComparator {
+    fn compare(&self, a: &[u8], b: &[u8]) -> Ordering {
+        if a.len() == 8 && b.len() == 8 {
+            let a_u64 = u64::from_be_bytes(a.try_into().unwrap());
+            let b_u64 = u64::from_be_bytes(b.try_into().unwrap());
+            a_u64.cmp(&b_u64)
+        } else {
+            // Non-8-byte keys: fall back to lexicographic ordering
+            // to preserve the bytewise-equality invariant.
+            a.cmp(b)
+        }
+    }
+}
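+
+/// Another valid ordering (illustrative; not exercised by the tests below):
+/// shorter keys first, ties broken lexicographically. The lexicographic
+/// tie-break preserves the bytewise-equality invariant (Equal implies a == b).
+#[allow(dead_code)]
+struct LengthThenBytesComparator;
+
+impl UserComparator for LengthThenBytesComparator {
+    fn compare(&self, a: &[u8], b: &[u8]) -> Ordering {
+        a.len().cmp(&b.len()).then_with(|| a.cmp(b))
+    }
+}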
+
+#[test]
+fn reverse_comparator_point_read() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(ReverseComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    tree.insert("a", "val_a", 0);
+    tree.insert("b", "val_b", 1);
+    tree.insert("c", "val_c", 2);
+
+    // Point reads should work regardless of comparator
+    assert_eq!(tree.get("a", 3)?, Some("val_a".as_bytes().into()));
+    assert_eq!(tree.get("b", 3)?, Some("val_b".as_bytes().into()));
+    assert_eq!(tree.get("c", 3)?, Some("val_c".as_bytes().into()));
+
+    Ok(())
+}
+
+#[test]
+fn reverse_comparator_iteration_order() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(ReverseComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    tree.insert("a", "val_a", 0);
+    tree.insert("b", "val_b", 1);
+    tree.insert("c", "val_c", 2);
+
+    // With reverse comparator, iteration order should be c, b, a
+    let items: Vec<_> = tree
+        .iter(3, None)
+        .map(|g| {
+            let (k, v) = g.into_inner().unwrap();
+            (
+                String::from_utf8(k.to_vec()).unwrap(),
+                String::from_utf8(v.to_vec()).unwrap(),
+            )
+        })
+        .collect();
+
+    assert_eq!(
+        items,
+        vec![
+            ("c".into(), "val_c".into()),
+            ("b".into(), "val_b".into()),
+            ("a".into(), "val_a".into()),
+        ]
+    );
+
+    Ok(())
+}
+
+#[test]
+fn reverse_comparator_after_flush() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(ReverseComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    tree.insert("a", "val_a", 0);
+    tree.insert("b", "val_b", 1);
+    tree.insert("c", "val_c", 2);
+
+    // Flush to disk
+    tree.flush_active_memtable(3)?;
+
+    // Point reads after flush
+    assert_eq!(tree.get("a", 4)?, Some("val_a".as_bytes().into()));
+    assert_eq!(tree.get("b", 4)?, Some("val_b".as_bytes().into()));
+    assert_eq!(tree.get("c", 4)?, Some("val_c".as_bytes().into()));
+
+    // Iteration order should still be reversed after flush
+    let items: Vec<_> = tree
+        .iter(4, None)
+        .map(|g| {
+            let (k, _) = g.into_inner().unwrap();
+            String::from_utf8(k.to_vec()).unwrap()
+        })
+        .collect();
+
+    assert_eq!(items, vec!["c", "b", "a"]);
+
+    Ok(())
+}
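+
+// Illustrative addition (uses only APIs exercised above): reads must merge
+// a flushed table with newer memtable entries in comparator order.
+#[test]
+fn reverse_comparator_memtable_and_table_merge() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(ReverseComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    tree.insert("b", "2", 0);
+    tree.insert("d", "4", 1);
+    tree.flush_active_memtable(2)?;
+
+    // These stay in the memtable and must interleave with the table.
+    tree.insert("a", "1", 2);
+    tree.insert("c", "3", 3);
+
+    let items: Vec<_> = tree
+        .iter(4, None)
+        .map(|g| {
+            let (k, _) = g.into_inner().unwrap();
+            String::from_utf8(k.to_vec()).unwrap()
+        })
+        .collect();
+
+    // Reverse lexicographic: d, c, b, a
+    assert_eq!(items, vec!["d", "c", "b", "a"]);
+
+    Ok(())
+}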
+
+#[test]
+fn u64_comparator_point_read_and_order() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(U64BigEndianComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    // Insert u64 keys as big-endian bytes
+    let keys = [1u64, 100, 50, 1000, 500];
+    for (i, &key) in keys.iter().enumerate() {
+        tree.insert(key.to_be_bytes(), format!("val_{key}"), i as u64);
+    }
+
+    // Point reads
+    assert_eq!(
+        tree.get(1u64.to_be_bytes().as_ref(), 5)?,
+        Some("val_1".as_bytes().into())
+    );
+    assert_eq!(
+        tree.get(1000u64.to_be_bytes().as_ref(), 5)?,
+        Some("val_1000".as_bytes().into())
+    );
+
+    // Iteration should be in numeric order: 1, 50, 100, 500, 1000
+    let items: Vec<u64> = tree
+        .iter(5, None)
+        .map(|g| {
+            let (k, _) = g.into_inner().unwrap();
+            u64::from_be_bytes(k[..8].try_into().unwrap())
+        })
+        .collect();
+
+    assert_eq!(items, vec![1, 50, 100, 500, 1000]);
+
+    Ok(())
+}
+
+#[test]
+fn u64_comparator_after_flush() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(U64BigEndianComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    let keys = [1u64, 100, 50, 1000, 500];
+    for (i, &key) in keys.iter().enumerate() {
+        tree.insert(key.to_be_bytes(), format!("val_{key}"), i as u64);
+    }
+
+    tree.flush_active_memtable(5)?;
+
+    // Point reads after flush
+    assert_eq!(
+        tree.get(50u64.to_be_bytes().as_ref(), 6)?,
+        Some("val_50".as_bytes().into())
+    );
+
+    // Iteration should still be in numeric order
+    let items: Vec<u64> = tree
+        .iter(6, None)
+        .map(|g| {
+            let (k, _) = g.into_inner().unwrap();
+            u64::from_be_bytes(k[..8].try_into().unwrap())
+        })
+        .collect();
+
+    assert_eq!(items, vec![1, 50, 100, 500, 1000]);
+
+    Ok(())
+}
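+
+// Illustrative addition (uses only APIs exercised above): two flushed tables
+// must still merge into a single numerically ordered stream.
+#[test]
+fn u64_comparator_two_tables() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(U64BigEndianComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    for (i, &key) in [100u64, 300].iter().enumerate() {
+        tree.insert(key.to_be_bytes(), format!("v{key}"), i as u64);
+    }
+    tree.flush_active_memtable(2)?;
+
+    for (i, &key) in [200u64, 400].iter().enumerate() {
+        tree.insert(key.to_be_bytes(), format!("v{key}"), 2 + i as u64);
+    }
+    tree.flush_active_memtable(4)?;
+
+    let items: Vec<u64> = tree
+        .iter(5, None)
+        .map(|g| {
+            let (k, _) = g.into_inner().unwrap();
+            u64::from_be_bytes(k[..8].try_into().unwrap())
+        })
+        .collect();
+
+    assert_eq!(items, vec![100, 200, 300, 400]);
+
+    Ok(())
+}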
+
+#[test]
+fn default_comparator_unchanged_behavior() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+
+    // No custom comparator — default lexicographic should work as before
+    let tree = Config::new(folder, Default::default(), Default::default()).open()?;
+
+    tree.insert("banana", "b", 0);
+    tree.insert("apple", "a", 1);
+    tree.insert("cherry", "c", 2);
+
+    let items: Vec<_> = tree
+        .iter(3, None)
+        .map(|g| {
+            let (k, _) = g.into_inner().unwrap();
+            String::from_utf8(k.to_vec()).unwrap()
+        })
+        .collect();
+
+    assert_eq!(items, vec!["apple", "banana", "cherry"]);
+
+    Ok(())
+}
+
+#[test]
+fn reverse_comparator_bounded_range_scan() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(ReverseComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    tree.insert("a", "1", 0);
+    tree.insert("b", "2", 1);
+    tree.insert("c", "3", 2);
+    tree.insert("d", "4", 3);
+    tree.insert("e", "5", 4);
+
+    // In comparator order the keys sort as: e, d, c, b, a.
+    // Range bounds are interpreted in comparator order too, so "d"..="b"
+    // means "d" <= key <= "b" under the reverse comparator, which yields
+    // d, c, b.
+    let items: Vec<_> = tree
+        .range("d"..="b", 5, None)
+        .map(|g| {
+            let (k, _) = g.into_inner().unwrap();
+            String::from_utf8(k.to_vec()).unwrap()
+        })
+        .collect();
+
+    assert_eq!(items, vec!["d", "c", "b"]);
+
+    Ok(())
+}
+
+#[test]
+fn u64_comparator_bounded_range_scan() -> lsm_tree::Result<()> {
+    let folder = tempfile::tempdir()?;
+    let cmp: SharedComparator = Arc::new(U64BigEndianComparator);
+
+    let tree = Config::new(folder, Default::default(), Default::default())
+        .comparator(cmp)
+        .open()?;
+
+    for &key in &[10u64, 50, 100, 500, 1000] {
+        tree.insert(key.to_be_bytes(), format!("v{key}"), key);
+    }
+
+    // Range scan: 50..=500 should yield 50, 100, 500
+    let lo = 50u64.to_be_bytes();
+    let hi = 500u64.to_be_bytes();
+    let items: Vec<u64> = tree
+        .range(lo..=hi, 1001, None)
+        .map(|g| {
+            let (k, _) = g.into_inner().unwrap();
+            u64::from_be_bytes(k[..8].try_into().unwrap())
+        })
+        .collect();
+
+    assert_eq!(items, vec![50, 100, 500]);
+
+    Ok(())
+}
diff --git a/tests/range_tombstone_ephemeral.rs b/tests/range_tombstone_ephemeral.rs
index c7c14c78d..1fca84d14 100644
--- a/tests/range_tombstone_ephemeral.rs
+++ b/tests/range_tombstone_ephemeral.rs
@@ -32,7 +32,10 @@ const EPHEMERAL_MT_ID: lsm_tree::MemtableId = 999;

 /// Build an ephemeral memtable with the given KVs and range tombstones.
 fn build_ephemeral(kvs: &[(&[u8], &[u8], u64)], rts: &[(&[u8], &[u8], u64)]) -> Arc<Memtable> {
-    let mt = Arc::new(Memtable::new(EPHEMERAL_MT_ID));
+    let mt = Arc::new(Memtable::new(
+        EPHEMERAL_MT_ID,
+        std::sync::Arc::new(lsm_tree::DefaultUserComparator),
+    ));
     for &(key, val, seqno) in kvs {
         mt.insert(lsm_tree::InternalValue::from_components(
             key,