From b542b3617a2bbbeeae6c4f9e10a24295814328ca Mon Sep 17 00:00:00 2001 From: Max Isom Date: Thu, 10 Oct 2024 16:45:46 -0700 Subject: [PATCH 1/7] [ENH]: replace `get_*` methods on Arrow blocks with `get_range()` --- rust/blockstore/src/arrow/block/types.rs | 155 +++++++++-------------- rust/blockstore/src/arrow/blockfile.rs | 14 +- 2 files changed, 70 insertions(+), 99 deletions(-) diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index 83d380449be..01ed7341b57 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -1,6 +1,7 @@ use std::cmp::Ordering::{Equal, Greater, Less}; use std::collections::HashMap; use std::io::SeekFrom; +use std::ops::{Bound, RangeBounds}; use crate::arrow::types::{ArrowReadableKey, ArrowReadableValue}; use arrow::array::ArrayData; @@ -214,32 +215,6 @@ impl Block { ) } - #[inline] - fn scan_prefix<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( - &'me self, - prefix: &str, - range: impl Iterator, - ) -> Vec<(K, V)> { - let prefix_array = self - .data - .column(0) - .as_any() - .downcast_ref::() - .expect("The prefix array should be a string arrary."); - let mut result = Vec::new(); - for index in range { - if prefix_array.value(index) == prefix { - result.push(( - K::get(self.data.column(1), index), - V::get(self.data.column(2), index), - )); - } else { - break; - } - } - result - } - /* ===== Block Queries ===== */ @@ -260,80 +235,72 @@ impl Block { } } - /// Get all the values for a given prefix in the block + /// Get all the values for a given prefix & key range in the block /// ### Panics /// - If the underlying data types are not the same as the types specified in the function signature - pub fn get_prefix<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( + /// - If at least one end of the prefix range is excluded (currently unsupported) + pub fn get_range< + 'prefix, + 'me, + K: ArrowReadableKey<'me>, + V: ArrowReadableValue<'me>, + PrefixRange, + KeyRange, + >( &'me self, - prefix: &str, - ) -> Vec<(K, V)> { - self.scan_prefix( - prefix, - self.binary_search_index(prefix, Option::<&K>::None)..self.len(), - ) - } - - /// Get all the values for a given prefix in the block where the key is greater than the given key - /// ### Panics - /// - If the underlying data types are not the same as the types specified in the function signature - pub fn get_gt<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( - &'me self, - prefix: &str, - key: K, - ) -> Vec<(K, V)> { - let index = self.binary_search_index(prefix, Some(&key)); - if self.match_prefix_key_at_index(prefix, &key, index) { - self.scan_prefix(prefix, index + 1..self.len()) - } else { - self.scan_prefix(prefix, index..self.len()) - } - } + prefix_range: PrefixRange, + key_range: KeyRange, + ) -> Vec<(K, V)> + where + PrefixRange: RangeBounds<&'prefix str>, + KeyRange: RangeBounds, + { + let start_index = match prefix_range.start_bound() { + Bound::Included(prefix) => match key_range.start_bound() { + Bound::Included(key) => self.binary_search_index(prefix, Some(key)), + Bound::Excluded(key) => { + let index = self.binary_search_index(prefix, Some(key)); + if self.match_prefix_key_at_index(prefix, key, index) { + index + 1 + } else { + index + } + } + Bound::Unbounded => self.binary_search_index::(prefix, None), + }, + Bound::Excluded(_) => { + unimplemented!("Excluded prefix range is not currently supported") + } + Bound::Unbounded => 0, + }; - /// Get all the values for a given prefix in the block where the key is greater than or equal to the given key - /// ### Panics - /// - If the underlying data types are not the same as the types specified in the function signature - pub fn get_gte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( - &'me self, - prefix: &str, - key: K, - ) -> Vec<(K, V)> { - self.scan_prefix( - prefix, - self.binary_search_index(prefix, Some(&key))..self.len(), - ) - } + let end_index = match prefix_range.end_bound() { + Bound::Included(prefix) => match key_range.end_bound() { + Bound::Included(key) => { + let index = self.binary_search_index(prefix, Some(key)); + if self.match_prefix_key_at_index(prefix, key, index) { + index + 1 + } else { + index + } + } + Bound::Excluded(key) => self.binary_search_index(prefix, Some(key)), + Bound::Unbounded => self.len(), + }, + Bound::Excluded(_) => { + unimplemented!("Excluded prefix range is not currently supported") + } + Bound::Unbounded => self.len(), + }; - /// Get all the values for a given prefix in the block where the key is less than the given key - /// ### Panics - /// - If the underlying data types are not the same as the types specified in the function signature - pub fn get_lt<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( - &'me self, - prefix: &str, - key: K, - ) -> Vec<(K, V)> { - let mut result = self.scan_prefix( - prefix, - (0..self.binary_search_index(prefix, Some(&key))).rev(), - ); - result.reverse(); - result - } + let mut result = Vec::new(); - /// Get all the values for a given prefix in the block where the key is less than or equal to the given key - /// ### Panics - /// - If the underlying data types are not the same as the types specified in the function signature - pub fn get_lte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( - &'me self, - prefix: &str, - key: K, - ) -> Vec<(K, V)> { - let index = self.binary_search_index(prefix, Some(&key)); - let mut result = if self.match_prefix_key_at_index(prefix, &key, index) { - self.scan_prefix(prefix, (0..=index).rev()) - } else { - self.scan_prefix(prefix, (0..index).rev()) - }; - result.reverse(); + for index in start_index..end_index { + result.push(( + K::get(self.data.column(1), index), + V::get(self.data.column(2), index), + )); + } result } diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index a11b6703128..faf1a2e8d3e 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -17,6 +17,7 @@ use futures::future::join_all; use parking_lot::Mutex; use std::collections::HashSet; use std::mem::transmute; +use std::ops::Bound; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use uuid::Uuid; @@ -509,7 +510,10 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me return Err(Box::new(ArrowBlockfileError::BlockNotFound)); } }; - result.extend(block.get_gt(prefix, key.clone())); + result.extend(block.get_range( + prefix..=prefix, + (Bound::Excluded(key.clone()), Bound::Unbounded), + )); } Ok(result) } @@ -544,7 +548,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me return Err(Box::new(ArrowBlockfileError::BlockNotFound)); } }; - result.extend(block.get_lt(prefix, key.clone())); + result.extend(block.get_range(prefix..=prefix, ..key.clone())); } Ok(result) } @@ -579,7 +583,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me return Err(Box::new(ArrowBlockfileError::BlockNotFound)); } }; - result.extend(block.get_gte(prefix, key.clone())); + result.extend(block.get_range(prefix..=prefix, key.clone()..)); } Ok(result) } @@ -614,7 +618,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me return Err(Box::new(ArrowBlockfileError::BlockNotFound)); } }; - result.extend(block.get_lte(prefix, key.clone())); + result.extend(block.get_range(prefix..=prefix, ..=key.clone())); } Ok(result) } @@ -647,7 +651,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me } }; - result.extend(block.get_prefix(prefix)); + result.extend(block.get_range(prefix..=prefix, ..)); } Ok(result) } From 5b5b11c04f6a3ea38719415e65c6384edb3d3a40 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Thu, 10 Oct 2024 16:56:06 -0700 Subject: [PATCH 2/7] Propagate --- rust/blockstore/src/arrow/blockfile.rs | 164 ++----------------------- rust/blockstore/src/types.rs | 23 +++- 2 files changed, 30 insertions(+), 157 deletions(-) diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index faf1a2e8d3e..5cfba2d8dd7 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -17,7 +17,7 @@ use futures::future::join_all; use parking_lot::Mutex; use std::collections::HashSet; use std::mem::transmute; -use std::ops::Bound; +use std::ops::RangeBounds; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use uuid::Uuid; @@ -477,130 +477,22 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me } } - /// Returns all arrow records whose key > supplied key. - pub(crate) async fn get_gt( + // Returns all Arrow records in the specified range. + pub(crate) async fn get_range<'prefix, PrefixRange, KeyRange>( &'me self, - prefix: &str, - key: K, - ) -> Result, Box> { - // Get all block ids that contain keys > key from sparse index for this prefix. - let block_ids = self.root.sparse_index.get_block_ids_range( - prefix..=prefix, - ( - std::ops::Bound::Excluded(key.clone()), - std::ops::Bound::Unbounded, - ), - ); - let mut result: Vec<(K, V)> = vec![]; - // Read all the blocks individually to get keys > key. - for block_id in block_ids { - let block_opt = match self.get_block(block_id).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; - - let block = match block_opt { - Some(b) => b, - None => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - }; - result.extend(block.get_range( - prefix..=prefix, - (Bound::Excluded(key.clone()), Bound::Unbounded), - )); - } - Ok(result) - } - - /// Returns all arrow records whose key < supplied key. - pub(crate) async fn get_lt( - &'me self, - prefix: &str, - key: K, - ) -> Result, Box> { - // Get all block ids that contain keys < key from sparse index. - let block_ids = self - .root - .sparse_index - .get_block_ids_range(prefix..=prefix, ..key.clone()); - let mut result: Vec<(K, V)> = vec![]; - // Read all the blocks individually to get keys < key. - for block_id in block_ids { - let block_opt = match self.get_block(block_id).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; - - let block = match block_opt { - Some(b) => b, - None => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - }; - result.extend(block.get_range(prefix..=prefix, ..key.clone())); - } - Ok(result) - } - - /// Returns all arrow records whose key >= supplied key. - pub(crate) async fn get_gte( - &'me self, - prefix: &str, - key: K, - ) -> Result, Box> { - // Get all block ids that contain keys >= key from sparse index. + prefix_range: PrefixRange, + key_range: KeyRange, + ) -> Result, Box> + where + PrefixRange: RangeBounds<&'prefix str> + Clone, + KeyRange: RangeBounds + Clone, + { let block_ids = self .root .sparse_index - .get_block_ids_range(prefix..=prefix, key.clone()..); - let mut result: Vec<(K, V)> = vec![]; - // Read all the blocks individually to get keys >= key. - for block_id in block_ids { - let block_opt = match self.get_block(block_id).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; - - let block = match block_opt { - Some(b) => b, - None => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - }; - result.extend(block.get_range(prefix..=prefix, key.clone()..)); - } - Ok(result) - } + .get_block_ids_range(prefix_range.clone(), key_range.clone()); - /// Returns all arrow records whose key <= supplied key. - pub(crate) async fn get_lte( - &'me self, - prefix: &str, - key: K, - ) -> Result, Box> { - // Get all block ids that contain keys <= key from sparse index. - let block_ids = self - .root - .sparse_index - .get_block_ids_range(prefix..=prefix, ..=key.clone()); let mut result: Vec<(K, V)> = vec![]; - // Read all the blocks individually to get keys <= key. for block_id in block_ids { let block_opt = match self.get_block(block_id).await { Ok(Some(block)) => Some(block), @@ -618,41 +510,9 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me return Err(Box::new(ArrowBlockfileError::BlockNotFound)); } }; - result.extend(block.get_range(prefix..=prefix, ..=key.clone())); + result.extend(block.get_range(prefix_range.clone(), key_range.clone())); } - Ok(result) - } - /// Returns all arrow records whose prefix is same as supplied prefix. - pub(crate) async fn get_by_prefix( - &'me self, - prefix: &str, - ) -> Result, Box> { - let block_ids = self - .root - .sparse_index - .get_block_ids_range::(prefix..=prefix, ..); - let mut result: Vec<(K, V)> = vec![]; - for block_id in block_ids { - let block_opt = match self.get_block(block_id).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; - - let block = match block_opt { - Some(b) => b, - None => { - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - }; - - result.extend(block.get_range(prefix..=prefix, ..)); - } Ok(result) } diff --git a/rust/blockstore/src/types.rs b/rust/blockstore/src/types.rs index 78025b2486a..9972b1c8b41 100644 --- a/rust/blockstore/src/types.rs +++ b/rust/blockstore/src/types.rs @@ -13,6 +13,7 @@ use chroma_types::DataRecord; use roaring::RoaringBitmap; use std::fmt::{Debug, Display}; use std::mem::size_of; +use std::ops::Bound; use thiserror::Error; #[derive(Debug, Error)] @@ -282,7 +283,9 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_by_prefix(prefix), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_by_prefix(prefix).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader.get_range(prefix..=prefix, ..).await + } } } @@ -293,7 +296,11 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_gt(prefix, key), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_gt(prefix, key).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader + .get_range(prefix..=prefix, (Bound::Excluded(key), Bound::Unbounded)) + .await + } } } @@ -304,7 +311,9 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_lt(prefix, key), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_lt(prefix, key).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader.get_range(prefix..=prefix, ..key).await + } } } @@ -315,7 +324,9 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_gte(prefix, key), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_gte(prefix, key).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader.get_range(prefix..=prefix, key..).await + } } } @@ -326,7 +337,9 @@ impl< ) -> Result, Box> { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.get_lte(prefix, key), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_lte(prefix, key).await, + BlockfileReader::ArrowBlockfileReader(reader) => { + reader.get_range(prefix..=prefix, ..=key).await + } } } From b7759f4ec728fd027c226df8e93286f75336cac8 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Fri, 11 Oct 2024 15:47:43 -0700 Subject: [PATCH 3/7] Fix bug --- rust/blockstore/src/arrow/block/types.rs | 56 +++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index 01ed7341b57..f98906b617f 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -119,6 +119,57 @@ impl Block { delta } + /// Binary search the block to find the last index of the specified prefix. + /// Returns None if prefix does not exist in the block. + /// [`std::slice::partition_point`]: std::slice::partition_point + #[inline] + fn binary_search_last_index(&self, prefix: &str) -> Option { + let mut size = self.len(); + if size == 0 { + return None; + } + + let prefix_array = self + .data + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let mut base = self.len() - 1; + + // This loop intentionally doesn't have an early exit if the comparison + // returns Equal. We want the number of loop iterations to depend *only* + // on the size of the input slice so that the CPU can reliably predict + // the loop count. + while size > 1 { + let half = size / 2; + let mid = base - half; + + // SAFETY: the call is made safe by the following inconstants: + // - `mid >= 0`: by definition + // - `mid < size`: `mid = size / 2 - size / 4 - size / 8 ...` + let cmp = prefix_array.value(mid).cmp(prefix); + + base = if cmp == Greater { mid } else { base }; + size -= half; + } + + // SAFETY: `base` is always in [0, size) because `base < size` by init. + // `base` should be the last index where the element matches the target prefix, + // or 0 if the first element is already larger than the target prefix. + match prefix_array.value(base).cmp(prefix) { + Less => None, + Equal => Some(base), + Greater => { + if prefix_array.value(base - 1) == prefix { + Some(base - 1) + } else { + None + } + } + } + } + /// Binary search the blockfile to find the partition point of the specified prefix and key. /// The implementation is based on [`std::slice::partition_point`]. /// @@ -285,7 +336,10 @@ impl Block { } } Bound::Excluded(key) => self.binary_search_index(prefix, Some(key)), - Bound::Unbounded => self.len(), + Bound::Unbounded => match self.binary_search_last_index(prefix) { + Some(last_index_of_prefix) => last_index_of_prefix + 1, // (add 1 because end_index is exclusive below) + None => start_index, // prefix does not exist in the block so we shouldn't return anything + }, }, Bound::Excluded(_) => { unimplemented!("Excluded prefix range is not currently supported") From 93685ea4289f8e6ed1cb21840515b40278cf3132 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Wed, 16 Oct 2024 11:10:37 -0700 Subject: [PATCH 4/7] Fix test --- rust/blockstore/src/arrow/block/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index f98906b617f..ba92f3663ed 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -161,7 +161,7 @@ impl Block { Less => None, Equal => Some(base), Greater => { - if prefix_array.value(base - 1) == prefix { + if base > 0 && prefix_array.value(base - 1) == prefix { Some(base - 1) } else { None From fd992068b0c015f10217465a1f7ba82011c267dc Mon Sep 17 00:00:00 2001 From: Max Isom Date: Fri, 18 Oct 2024 10:08:32 -0700 Subject: [PATCH 5/7] Update comments --- rust/blockstore/src/arrow/block/types.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index ba92f3663ed..6e487146913 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -121,7 +121,8 @@ impl Block { /// Binary search the block to find the last index of the specified prefix. /// Returns None if prefix does not exist in the block. - /// [`std::slice::partition_point`]: std::slice::partition_point + /// + /// Partly based on `std::slice::binary_search_by`: https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770 #[inline] fn binary_search_last_index(&self, prefix: &str) -> Option { let mut size = self.len(); @@ -146,8 +147,8 @@ impl Block { let mid = base - half; // SAFETY: the call is made safe by the following inconstants: - // - `mid >= 0`: by definition - // - `mid < size`: `mid = size / 2 - size / 4 - size / 8 ...` + // - `mid < size`: by definition + // - `mid >= 0`: `mid = size - 1 - size / 2 - size / 4 ...` let cmp = prefix_array.value(mid).cmp(prefix); base = if cmp == Greater { mid } else { base }; @@ -171,7 +172,6 @@ impl Block { } /// Binary search the blockfile to find the partition point of the specified prefix and key. - /// The implementation is based on [`std::slice::partition_point`]. /// /// `(prefix, key)` serves as the search key, and it is sorted in ascending order. /// The partition predicate is defined by: `|x| x < (prefix, key)`. @@ -179,7 +179,7 @@ impl Block { /// The code is a result of inlining this predicate in [`std::slice::partition_point`]. /// If the key is unspecified (i.e. `None`), we find the first index of the prefix. /// - /// [`std::slice::partition_point`]: std::slice::partition_point + /// Partly based on `std::slice::binary_search_by`: https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770 #[inline] fn binary_search_index<'me, K: ArrowReadableKey<'me>>( &'me self, From 95fb27b129901e47f9d392e0e6123db4bd1437b1 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Thu, 24 Oct 2024 15:13:36 -0700 Subject: [PATCH 6/7] Use common binary_search_by() method --- rust/blockstore/src/arrow/block/types.rs | 217 +++++++++++++---------- 1 file changed, 124 insertions(+), 93 deletions(-) diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index 6e487146913..eef18065efe 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -1,4 +1,4 @@ -use std::cmp::Ordering::{Equal, Greater, Less}; +use std::cmp::Ordering; use std::collections::HashMap; use std::io::SeekFrom; use std::ops::{Bound, RangeBounds}; @@ -119,15 +119,35 @@ impl Block { delta } - /// Binary search the block to find the last index of the specified prefix. - /// Returns None if prefix does not exist in the block. + /// Binary searches this slice with a comparator function. /// - /// Partly based on `std::slice::binary_search_by`: https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770 + /// The comparator function should return an order code that indicates + /// whether its argument is `Less`, `Equal` or `Greater` the desired + /// target. + /// If the slice is not sorted or if the comparator function does not + /// implement an order consistent with the sort order of the underlying + /// slice, the returned result is unspecified and meaningless. + /// + /// If the value is found then [`Result::Ok`] is returned, containing the + /// index of the matching element. If there are multiple matches, then any + /// one of the matches could be returned. The index is chosen + /// deterministically, but is subject to change in future versions of Rust. + /// If the value is not found then [`Result::Err`] is returned, containing + /// the index where a matching element could be inserted while maintaining + /// sorted order. + /// + /// Based on std::slice::binary_search_by with minimal modifications (https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770). #[inline] - fn binary_search_last_index(&self, prefix: &str) -> Option { + fn binary_search_by<'me, K: ArrowReadableKey<'me>, F>( + &'me self, + mut f: F, + ) -> Result + where + F: FnMut((&'me str, K)) -> Ordering, + { let mut size = self.len(); if size == 0 { - return None; + return Err(0); } let prefix_array = self @@ -136,7 +156,7 @@ impl Block { .as_any() .downcast_ref::() .unwrap(); - let mut base = self.len() - 1; + let mut base = 0; // This loop intentionally doesn't have an early exit if the comparison // returns Equal. We want the number of loop iterations to depend *only* @@ -144,51 +164,58 @@ impl Block { // the loop count. while size > 1 { let half = size / 2; - let mid = base - half; + let mid = base + half; // SAFETY: the call is made safe by the following inconstants: - // - `mid < size`: by definition - // - `mid >= 0`: `mid = size - 1 - size / 2 - size / 4 ...` - let cmp = prefix_array.value(mid).cmp(prefix); - - base = if cmp == Greater { mid } else { base }; + // - `mid >= 0`: by definition + // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...` + let prefix = prefix_array.value(mid); + let key = K::get(self.data.column(1), mid); + let cmp = f((prefix, key)); + + base = if cmp == Ordering::Greater { base } else { mid }; + + // This is imprecise in the case where `size` is odd and the + // comparison returns Greater: the mid element still gets included + // by `size` even though it's known to be larger than the element + // being searched for. + // + // This is fine though: we gain more performance by keeping the + // loop iteration count invariant (and thus predictable) than we + // lose from considering one additional element. size -= half; } - // SAFETY: `base` is always in [0, size) because `base < size` by init. - // `base` should be the last index where the element matches the target prefix, - // or 0 if the first element is already larger than the target prefix. - match prefix_array.value(base).cmp(prefix) { - Less => None, - Equal => Some(base), - Greater => { - if base > 0 && prefix_array.value(base - 1) == prefix { - Some(base - 1) - } else { - None - } - } + // SAFETY: base is always in [0, size) because base <= mid. + let prefix = prefix_array.value(base); + let key = K::get(self.data.column(1), base); + let cmp = f((prefix, key)); + if cmp == Ordering::Equal { + Ok(base) + } else { + let result = base + (cmp == Ordering::Less) as usize; + Err(result) } } - /// Binary search the blockfile to find the partition point of the specified prefix and key. - /// - /// `(prefix, key)` serves as the search key, and it is sorted in ascending order. - /// The partition predicate is defined by: `|x| x < (prefix, key)`. - /// The partition point is the first index where the partition precidate evaluates to `false` - /// The code is a result of inlining this predicate in [`std::slice::partition_point`]. - /// If the key is unspecified (i.e. `None`), we find the first index of the prefix. - /// - /// Partly based on `std::slice::binary_search_by`: https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770 + /// Returns the largest index where `prefixes[index] == prefix` or None if the provided prefix does not exist in the block. #[inline] - fn binary_search_index<'me, K: ArrowReadableKey<'me>>( + fn find_largest_index_of_prefix<'me, K: ArrowReadableKey<'me>>( &'me self, prefix: &str, - key: Option<&K>, - ) -> usize { - let mut size = self.len(); - if size == 0 { - return 0; + ) -> Option { + // By design, will never find an exact match (comparator never evaluates to Equal). This finds the index of the first element that is greater than the prefix. If no element is greater, it returns the length of the array. + let result = self + .binary_search_by::(|(p, _)| match p.cmp(prefix) { + Ordering::Less => Ordering::Less, + Ordering::Equal => Ordering::Less, + Ordering::Greater => Ordering::Greater, + }) + .expect_err("Never returns Ok because the comparator never evaluates to Equal."); + + if result == 0 { + // The first element is greater than the target prefix, so the target prefix does not exist in the block. + return None; } let prefix_array = self @@ -197,49 +224,48 @@ impl Block { .as_any() .downcast_ref::() .unwrap(); - let mut base = 0; - // This loop intentionally doesn't have an early exit if the comparison - // returns Equal. We want the number of loop iterations to depend *only* - // on the size of the input slice so that the CPU can reliably predict - // the loop count. - while size > 1 { - let half = size / 2; - let mid = base + half; - - // SAFETY: the call is made safe by the following inconstants: - // - `mid >= 0`: by definition - // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...` - let mut cmp = prefix_array.value(mid).cmp(prefix); - - // Continue to compare the key if prefix matches - if let (Equal, Some(k)) = (cmp, key) { - cmp = K::get(self.data.column(1), mid) - // Key type do not have total order because of floating point values - // But in our case NaN should not be allowed so we should always have total order - .partial_cmp(k) - .expect("Array values should be comparable."); - } - - base = if cmp == Less { mid } else { base }; - size -= half; + // `result` is the first index where the prefix is larger than the input (or the length of the array) so we want one element before this. + match prefix_array.value(result - 1).cmp(prefix) { + // We're at the end of the array, so the prefix does not exist in the block (all values are less than the prefix) + Ordering::Less => None, + // The prefix exists + Ordering::Equal => Some(result - 1), + // This is impossible + Ordering::Greater => None, } + } - // SAFETY: `base` is always in [0, size) because `base <= mid`. - // `base` should be the last index where the element is smaller than the target, - // or 0 if the first element is already larger than the target. - match prefix_array.value(base).cmp(prefix) { - Less => base + 1, - Equal => match key { - // Key type do not have total order because of floating point values - // But in our case NaN should not be allowed so we should always have total order - Some(k) => match K::get(self.data.column(1), base).partial_cmp(k) { - Some(Less) => base + 1, - _ => base, - }, - None => base, - }, - Greater => base, + /// Finds the partition point of the prefix and key. + /// Returns the index of the first element that matches the target prefix and key. If no element matches, returns the index at which the target prefix and key could be inserted to maintain sorted order. + #[inline] + fn get_key_prefix_partition_point<'me, K: ArrowReadableKey<'me>>( + &'me self, + prefix: &str, + key: Option<&K>, + ) -> usize { + // By design, will never find an exact match (comparator never evaluates to Equal). This finds the index of the first element that matches the target prefix and key. If no element matches, it returns the index at which the target prefix and key could be inserted to maintain sorted order. + if let Some(key) = key { + self.binary_search_by::(|(p, k)| { + match p.cmp(prefix).then_with(|| { + k.partial_cmp(key) + // The key type does not have a total order because of floating point values. + // But in our case NaN is not allowed, so we should always have total order. + .expect("Array values should be comparable.") + }) { + Ordering::Less => Ordering::Less, + Ordering::Equal => Ordering::Greater, + Ordering::Greater => Ordering::Greater, + } + }) + .expect_err("Never returns Ok because the comparator never evaluates to Equal.") + } else { + self.binary_search_by::(|(p, _)| match p.cmp(prefix) { + Ordering::Less => Ordering::Less, + Ordering::Equal => Ordering::Greater, + Ordering::Greater => Ordering::Greater, + }) + .expect_err("Never returns Ok because the comparator never evaluates to Equal.") } } @@ -262,7 +288,7 @@ impl Block { prefix_array.value(index).cmp(prefix), K::get(self.data.column(1), index).partial_cmp(key), ), - (Equal, Some(Equal)) + (Ordering::Equal, Some(Ordering::Equal)) ) } @@ -278,11 +304,16 @@ impl Block { prefix: &str, key: K, ) -> Option { - let index = self.binary_search_index(prefix, Some(&key)); - if self.match_prefix_key_at_index(prefix, &key, index) { - Some(V::get(self.data.column(2), index)) - } else { - None + match self.binary_search_by::(|(p, k)| { + p.cmp(prefix).then_with(|| { + k.partial_cmp(&key) + // The key type does not have a total order because of floating point values. + // But in our case NaN is not allowed, so we should always have total order. + .expect("Array values should be comparable.") + }) + }) { + Ok(index) => Some(V::get(self.data.column(2), index)), + Err(_) => None, } } @@ -308,16 +339,16 @@ impl Block { { let start_index = match prefix_range.start_bound() { Bound::Included(prefix) => match key_range.start_bound() { - Bound::Included(key) => self.binary_search_index(prefix, Some(key)), + Bound::Included(key) => self.get_key_prefix_partition_point(prefix, Some(key)), Bound::Excluded(key) => { - let index = self.binary_search_index(prefix, Some(key)); + let index = self.get_key_prefix_partition_point(prefix, Some(key)); if self.match_prefix_key_at_index(prefix, key, index) { index + 1 } else { index } } - Bound::Unbounded => self.binary_search_index::(prefix, None), + Bound::Unbounded => self.get_key_prefix_partition_point::(prefix, None), }, Bound::Excluded(_) => { unimplemented!("Excluded prefix range is not currently supported") @@ -328,15 +359,15 @@ impl Block { let end_index = match prefix_range.end_bound() { Bound::Included(prefix) => match key_range.end_bound() { Bound::Included(key) => { - let index = self.binary_search_index(prefix, Some(key)); + let index = self.get_key_prefix_partition_point(prefix, Some(key)); if self.match_prefix_key_at_index(prefix, key, index) { index + 1 } else { index } } - Bound::Excluded(key) => self.binary_search_index(prefix, Some(key)), - Bound::Unbounded => match self.binary_search_last_index(prefix) { + Bound::Excluded(key) => self.get_key_prefix_partition_point(prefix, Some(key)), + Bound::Unbounded => match self.find_largest_index_of_prefix::(prefix) { Some(last_index_of_prefix) => last_index_of_prefix + 1, // (add 1 because end_index is exclusive below) None => start_index, // prefix does not exist in the block so we shouldn't return anything }, From a6a50e8d9c91ccd6f8a201239c5af3796491cbcb Mon Sep 17 00:00:00 2001 From: Max Isom Date: Fri, 25 Oct 2024 16:09:25 -0700 Subject: [PATCH 7/7] Decompose --- rust/blockstore/src/arrow/block/types.rs | 81 ++++++++++++++++-------- 1 file changed, 54 insertions(+), 27 deletions(-) diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index eef18065efe..f22ba3515f0 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -214,7 +214,7 @@ impl Block { .expect_err("Never returns Ok because the comparator never evaluates to Equal."); if result == 0 { - // The first element is greater than the target prefix, so the target prefix does not exist in the block. + // The first element is greater than the target prefix, so the target prefix does not exist in the block. (Or the block is empty.) return None; } @@ -236,37 +236,61 @@ impl Block { } } + /// Returns the smallest index where `prefixes[index] == prefix` or None if the provided prefix does not exist in the block. + #[inline] + fn find_smallest_index_of_prefix<'me, K: ArrowReadableKey<'me>>( + &'me self, + prefix: &str, + ) -> Option { + let result = self + .binary_search_by::(|(p, _)| match p.cmp(prefix) { + Ordering::Less => Ordering::Less, + Ordering::Equal => Ordering::Greater, + Ordering::Greater => Ordering::Greater, + }) + .expect_err("Never returns Ok because the comparator never evaluates to Equal."); + + let prefix_array = self + .data + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + if result == self.len() { + // The target prefix is greater than all elements in the block. + return None; + } + + match prefix_array.value(result).cmp(prefix) { + Ordering::Greater => None, + Ordering::Less => None, + Ordering::Equal => Some(result), + } + } + /// Finds the partition point of the prefix and key. /// Returns the index of the first element that matches the target prefix and key. If no element matches, returns the index at which the target prefix and key could be inserted to maintain sorted order. #[inline] - fn get_key_prefix_partition_point<'me, K: ArrowReadableKey<'me>>( + fn binary_search_prefix_key<'me, K: ArrowReadableKey<'me>>( &'me self, prefix: &str, - key: Option<&K>, + key: &K, ) -> usize { // By design, will never find an exact match (comparator never evaluates to Equal). This finds the index of the first element that matches the target prefix and key. If no element matches, it returns the index at which the target prefix and key could be inserted to maintain sorted order. - if let Some(key) = key { - self.binary_search_by::(|(p, k)| { - match p.cmp(prefix).then_with(|| { - k.partial_cmp(key) - // The key type does not have a total order because of floating point values. - // But in our case NaN is not allowed, so we should always have total order. - .expect("Array values should be comparable.") - }) { - Ordering::Less => Ordering::Less, - Ordering::Equal => Ordering::Greater, - Ordering::Greater => Ordering::Greater, - } - }) - .expect_err("Never returns Ok because the comparator never evaluates to Equal.") - } else { - self.binary_search_by::(|(p, _)| match p.cmp(prefix) { + self.binary_search_by::(|(p, k)| { + match p.cmp(prefix).then_with(|| { + k.partial_cmp(key) + // The key type does not have a total order because of floating point values. + // But in our case NaN is not allowed, so we should always have total order. + .expect("Array values should be comparable.") + }) { Ordering::Less => Ordering::Less, Ordering::Equal => Ordering::Greater, Ordering::Greater => Ordering::Greater, - }) - .expect_err("Never returns Ok because the comparator never evaluates to Equal.") - } + } + }) + .expect_err("Never returns Ok because the comparator never evaluates to Equal.") } #[inline] @@ -339,16 +363,19 @@ impl Block { { let start_index = match prefix_range.start_bound() { Bound::Included(prefix) => match key_range.start_bound() { - Bound::Included(key) => self.get_key_prefix_partition_point(prefix, Some(key)), + Bound::Included(key) => self.binary_search_prefix_key(prefix, key), Bound::Excluded(key) => { - let index = self.get_key_prefix_partition_point(prefix, Some(key)); + let index = self.binary_search_prefix_key(prefix, key); if self.match_prefix_key_at_index(prefix, key, index) { index + 1 } else { index } } - Bound::Unbounded => self.get_key_prefix_partition_point::(prefix, None), + Bound::Unbounded => match self.find_smallest_index_of_prefix::(prefix) { + Some(first_index_of_prefix) => first_index_of_prefix, + None => self.len(), // prefix does not exist in the block so we shouldn't return anything + }, }, Bound::Excluded(_) => { unimplemented!("Excluded prefix range is not currently supported") @@ -359,14 +386,14 @@ impl Block { let end_index = match prefix_range.end_bound() { Bound::Included(prefix) => match key_range.end_bound() { Bound::Included(key) => { - let index = self.get_key_prefix_partition_point(prefix, Some(key)); + let index = self.binary_search_prefix_key(prefix, key); if self.match_prefix_key_at_index(prefix, key, index) { index + 1 } else { index } } - Bound::Excluded(key) => self.get_key_prefix_partition_point(prefix, Some(key)), + Bound::Excluded(key) => self.binary_search_prefix_key::(prefix, key), Bound::Unbounded => match self.find_largest_index_of_prefix::(prefix) { Some(last_index_of_prefix) => last_index_of_prefix + 1, // (add 1 because end_index is exclusive below) None => start_index, // prefix does not exist in the block so we shouldn't return anything