Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions rust/blockstore/src/arrow/blockfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
pub(crate) fn count_blocks_for_prefix(&self, prefix: &str) -> usize {
self.root
.sparse_index
.get_block_ids_range(prefix..=prefix)
.get_block_ids_range::<_, K, _>(prefix..=prefix, ..)
.len()
}

Expand Down Expand Up @@ -615,7 +615,10 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
prefix: &str,
) -> Result<impl Iterator<Item = (K, V)>, Box<dyn ChromaError>> {
// Get all block IDs that might contain this prefix
let block_ids = self.root.sparse_index.get_block_ids_range(prefix..=prefix);
let block_ids = self
.root
.sparse_index
.get_block_ids_range::<_, K, _>(prefix..=prefix, ..);

if block_ids.is_empty() {
return Ok(Vec::new().into_iter().flatten());
Expand Down Expand Up @@ -662,7 +665,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
futures::stream::iter(
self.root
.sparse_index
.get_block_ids_range(prefix_range.clone())
.get_block_ids_range(prefix_range.clone(), key_range.clone())
.into_iter()
.map(Ok),
)
Expand Down Expand Up @@ -697,7 +700,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
let block_ids = self
.root
.sparse_index
.get_block_ids_range(prefix_range.clone());
.get_block_ids_range(prefix_range.clone(), key_range.clone());

let block_futures_is_empty = block_ids.is_empty();
let block_futures = block_ids.into_iter().map(|block_id| {
Expand Down Expand Up @@ -809,7 +812,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
let block_ids = self
.root
.sparse_index
.get_block_ids_range(..=prefix)
.get_block_ids_range::<_, K, _>(..=prefix, ..)
.into_iter()
.take_while(|id| id != &last_block_id)
.collect::<Vec<_>>();
Expand Down
169 changes: 161 additions & 8 deletions rust/blockstore/src/arrow/sparse_index.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::key::CompositeKey;
use crate::Key;
use chroma_error::ChromaError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -439,16 +440,66 @@ impl SparseIndexReader {
result_uuids
}

pub(super) fn get_block_ids_range<'prefix, PrefixRange>(
pub(super) fn get_block_ids_range<'prefix, PrefixRange, K: Key, KeyRange>(
&self,
prefix_range: PrefixRange,
key_range: KeyRange,
) -> Vec<Uuid>
where
PrefixRange: RangeBounds<&'prefix str>,
KeyRange: RangeBounds<K>,
{
let forward = &self.data.forward;

// We do not materialize the last key of each block, so we must check the next block's start key to determine if the current block's end key is within the query range.
// Convert key range bounds to KeyWrapper once.
let key_start: Bound<crate::key::KeyWrapper> = match key_range.start_bound() {
Bound::Included(k) => Bound::Included(k.clone().into()),
Bound::Excluded(k) => Bound::Excluded(k.clone().into()),
Bound::Unbounded => Bound::Unbounded,
};
let key_end: Bound<crate::key::KeyWrapper> = match key_range.end_bound() {
Bound::Included(k) => Bound::Included(k.clone().into()),
Bound::Excluded(k) => Bound::Excluded(k.clone().into()),
Bound::Unbounded => Bound::Unbounded,
};

// Fast path: when the prefix range is a single inclusive value and the
// key start is an inclusive bound, use BTreeMap range lookups for
// O(log B + R) instead of the O(B) linear scan below.
if let (Bound::Included(ps), Bound::Included(pe)) =
    (prefix_range.start_bound(), prefix_range.end_bound())
{
    if ps == pe {
        if let Bound::Included(ref ks) = key_start {
            // Smallest composite key the query can match.
            let start = SparseIndexDelimiter::Key(CompositeKey {
                prefix: ps.to_string(),
                key: ks.clone(),
            });
            // The block that may hold `start` is the one whose start
            // delimiter is the greatest delimiter <= `start`.
            if let Some((first_delim, _)) = forward.range(..=&start).next_back() {
                return forward
                    .range(first_delim.clone()..)
                    .take_while(|(delim, _)| match delim {
                        SparseIndexDelimiter::Start => true,
                        SparseIndexDelimiter::Key(k) => {
                            if k.prefix.as_str() != *ps {
                                // A delimiter under a different prefix normally ends the
                                // scan. The first candidate is the exception: it sorts at
                                // or before `start`, so a differing prefix means it begins
                                // under an earlier prefix, and since last keys are not
                                // materialized its block extends to the next delimiter and
                                // may still contain keys for `ps`. Rejecting it here would
                                // make the whole lookup return no blocks.
                                return *delim == first_delim;
                            }
                            // Same prefix: keep blocks whose start key is still within
                            // the upper key bound of the query.
                            match &key_end {
                                Bound::Included(ke) => k.key <= *ke,
                                Bound::Excluded(ke) => k.key < *ke,
                                Bound::Unbounded => true,
                            }
                        }
                    })
                    .map(|(_, v)| v.id)
                    .collect();
            }
            // No delimiter sorts at or before `start`, so no block can contain it.
            return vec![];
        }
    }
}

// Slow path: linear scan with prefix + key overlap filtering.
let start_keys_offset_by_1_iter = forward
.iter()
.skip(1)
Expand Down Expand Up @@ -502,14 +553,50 @@ impl SparseIndexReader {
};

// Check whether max_start_prefix <= min_end_prefix
match (max_start_prefix, min_end_prefix) {
let prefix_overlap = match (max_start_prefix, min_end_prefix) {
(Bound::Included(start), Bound::Included(end)) => start <= end,
(Bound::Included(start), Bound::Excluded(end))
| (Bound::Excluded(start), Bound::Included(end))
| (Bound::Excluded(start), Bound::Excluded(end)) => start < end,
// At least one of these is unbounded.
_ => true,
};
if !prefix_overlap {
return false;
}

// When block start and end share the same prefix, further
// check whether the block's key range [bs.key, be.key)
// overlaps the query key range.
if let (SparseIndexDelimiter::Key(bs), Some(be)) =
(block_start_delimiter, block_end_delimiter)
{
if bs.prefix == be.prefix {
// Block key range is [bs.key, be.key) (end exclusive).
// Compute overlap with query key range.
let max_start = match &key_start {
Bound::Included(ks) | Bound::Excluded(ks) if bs.key <= *ks => {
key_start.as_ref()
}
_ => Bound::Included(&bs.key),
};
let min_end = match &key_end {
Bound::Included(ke) | Bound::Excluded(ke) if be.key > *ke => {
key_end.as_ref()
}
_ => Bound::Excluded(&be.key),
};
return match (max_start, min_end) {
(Bound::Included(s), Bound::Included(e)) => s <= e,
(Bound::Included(s), Bound::Excluded(e))
| (Bound::Excluded(s), Bound::Included(e))
| (Bound::Excluded(s), Bound::Excluded(e)) => s < e,
_ => true,
};
}
}

true
})
.map(|(sparse_index_value, _, _)| sparse_index_value.id)
.collect()
Expand Down Expand Up @@ -856,27 +943,93 @@ mod tests {
.expect("Set count should succeed");

let reader = writer.to_reader().expect("Conversion should succeed");
let blocks = reader.get_block_ids_range(..);
let blocks = reader.get_block_ids_range::<_, &str, _>(.., ..);
assert_eq!(
blocks,
vec![
block_id_0, block_id_1, block_id_2, block_id_3, block_id_4, block_id_5, block_id_6
]
);

let blocks = reader.get_block_ids_range(.."a");
let blocks = reader.get_block_ids_range::<_, &str, _>(.."a", ..);
assert_eq!(blocks, vec![block_id_0]);

let blocks = reader.get_block_ids_range(..="a");
let blocks = reader.get_block_ids_range::<_, &str, _>(..="a", ..);
assert_eq!(blocks, vec![block_id_0, block_id_1, block_id_2]);

let blocks = reader.get_block_ids_range("b"..="c");
let blocks = reader.get_block_ids_range::<_, &str, _>("b"..="c", ..);
assert_eq!(blocks, vec![block_id_2, block_id_3, block_id_4, block_id_5]);

let blocks = reader.get_block_ids_range("c"..);
let blocks = reader.get_block_ids_range::<_, &str, _>("c".., ..);
assert_eq!(blocks, vec![block_id_4, block_id_5, block_id_6]);
}

#[test]
fn test_get_block_ids_range_with_key_filter() {
    // Sparse index layout used throughout (every delimiter has prefix ""):
    //   first  -> [Start, ("", 100))
    //   second -> [("", 100), ("", 200))
    //   third  -> [("", 200), ∞)
    let first = uuid::Uuid::new_v4();
    let second = uuid::Uuid::new_v4();
    let third = uuid::Uuid::new_v4();

    let writer = SparseIndexWriter::new(first);
    writer.set_count(first, 100).expect("set count");
    for (start_key, id) in [(100u32, second), (200u32, third)] {
        writer
            .add_block(CompositeKey::new("".to_string(), start_key), id)
            .expect("add block");
        writer.set_count(id, 100).expect("set count");
    }

    let reader = writer.to_reader().expect("to reader");

    // With no key bounds at all, every block under the prefix is returned.
    assert_eq!(
        reader.get_block_ids_range::<_, u32, _>(""..="", ..),
        vec![first, second, third]
    );

    // Inclusive key ranges paired with the blocks they must resolve to.
    let inclusive_cases = [
        // 150 lies inside [100, 200).
        (150u32..=150u32, vec![second]),
        // 50 lies inside [Start, 100).
        (50u32..=50u32, vec![first]),
        // 250 lies inside [200, ∞).
        (250u32..=250u32, vec![third]),
        // 99..=100 straddles the boundary between the first two blocks.
        (99u32..=100u32, vec![first, second]),
        // 0..=999 covers every block.
        (0u32..=999u32, vec![first, second, third]),
        // 100 is exactly the start key of the second block.
        (100u32..=100u32, vec![second]),
        // 200 is exactly the start key of the third block.
        (200u32..=200u32, vec![third]),
        // 0 is the minimum key and belongs to the block behind `Start`.
        (0u32..=0u32, vec![first]),
    ];
    for (keys, expected) in inclusive_cases {
        let found = reader.get_block_ids_range(""..="", keys.clone());
        assert_eq!(found, expected, "key range {:?}", keys);
    }

    // An excluded key start takes the slow path, which does not prune the
    // first or the last block, so all three come back.
    let found =
        reader.get_block_ids_range(""..="", (Bound::Excluded(100u32), Bound::Included(150u32)));
    assert_eq!(found, vec![first, second, third]);
}

#[test]
fn test_serde() {
let ids = [uuid::Uuid::new_v4(), uuid::Uuid::new_v4()];
Expand Down
Loading