[ENH] Introduce spann segment reader #3212

Merged: 3 commits, Dec 5, 2024
33 changes: 33 additions & 0 deletions rust/blockstore/src/memory/storage.rs
@@ -680,6 +680,39 @@ impl<'referred_data> Readable<'referred_data> for DataRecord<'referred_data> {
}
}

// SpannPostingList reads are not implemented for the in-memory blockstore;
// these methods are `todo!()` stubs and will panic if called.
impl<'referred_data> Readable<'referred_data> for SpannPostingList<'referred_data> {
fn read_from_storage(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> Option<Self> {
todo!()
}

fn read_range_from_storage<'prefix, PrefixRange, KeyRange>(
_: PrefixRange,
_: KeyRange,
_: &'referred_data Storage,
) -> Vec<(&'referred_data CompositeKey, Self)>
where
PrefixRange: std::ops::RangeBounds<&'prefix str>,
KeyRange: std::ops::RangeBounds<KeyWrapper>,
{
todo!()
}

fn get_at_index(
_: &'referred_data Storage,
_: usize,
) -> Option<(&'referred_data CompositeKey, Self)> {
todo!()
}

fn count(_: &Storage) -> Result<usize, Box<dyn ChromaError>> {
todo!()
}

fn contains(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> bool {
todo!()
}
}

#[derive(Clone)]
pub struct StorageBuilder {
bool_storage: Arc<RwLock<Option<BTreeMap<CompositeKey, bool>>>>,
6 changes: 6 additions & 0 deletions rust/blockstore/src/types/value.rs
@@ -61,6 +61,12 @@ impl Value for &DataRecord<'_> {
}
}

impl Value for SpannPostingList<'_> {
fn get_size(&self) -> usize {
self.compute_size()
}
}

impl Value for &SpannPostingList<'_> {
fn get_size(&self) -> usize {
self.compute_size()
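
For context, `Value::get_size` is what the blockstore uses to account for the space a value occupies, and both the owned and borrowed impls above delegate to the posting list's `compute_size`. A rough illustrative sketch (not part of the diff; the actual `compute_size` defined alongside `SpannPostingList` is authoritative): the payload is dominated by the byte length of the borrowed slices.

// Illustrative sketch only: approximate a posting list's payload as the byte
// length of its parallel slices (u32 offset ids, u32 versions, f32 embeddings).
fn approx_posting_list_size(ids: &[u32], versions: &[u32], embeddings: &[f32]) -> usize {
    std::mem::size_of_val(ids) + std::mem::size_of_val(versions) + std::mem::size_of_val(embeddings)
}
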
148 changes: 138 additions & 10 deletions rust/index/src/spann/types.rs
@@ -5,7 +5,7 @@ use std::{

use chroma_blockstore::{
provider::{BlockfileProvider, CreateError, OpenError},
- BlockfileFlusher, BlockfileWriter, BlockfileWriterOptions,
+ BlockfileFlusher, BlockfileReader, BlockfileWriter, BlockfileWriterOptions,
};
use chroma_distance::{normalize, DistanceFunction};
use chroma_error::{ChromaError, ErrorCodes};
@@ -1427,6 +1427,124 @@ impl SpannIndexFlusher {
}
}

#[derive(Error, Debug)]
pub enum SpannIndexReaderError {
#[error("Error creating/opening hnsw index")]
HnswIndexConstructionError,
#[error("Error creating/opening blockfile reader")]
BlockfileReaderConstructionError,
#[error("Spann index uninitialized")]
UninitializedIndex,
}

impl ChromaError for SpannIndexReaderError {
fn code(&self) -> ErrorCodes {
match self {
Self::HnswIndexConstructionError => ErrorCodes::Internal,
Self::BlockfileReaderConstructionError => ErrorCodes::Internal,
Self::UninitializedIndex => ErrorCodes::Internal,
}
}
}

#[derive(Clone)]
pub struct SpannIndexReader<'me> {
pub posting_lists: BlockfileReader<'me, u32, SpannPostingList<'me>>,
pub hnsw_index: HnswIndexRef,
pub versions_map: BlockfileReader<'me, u32, u32>,
}

impl<'me> SpannIndexReader<'me> {
async fn hnsw_index_from_id(
hnsw_provider: &HnswIndexProvider,
id: &IndexUuid,
cache_key: &CollectionUuid,
distance_function: DistanceFunction,
dimensionality: usize,
) -> Result<HnswIndexRef, SpannIndexReaderError> {
match hnsw_provider.get(id, cache_key).await {
Some(index) => Ok(index),
None => {
match hnsw_provider
.open(id, cache_key, dimensionality as i32, distance_function)
.await
{
Ok(index) => Ok(index),
Err(_) => Err(SpannIndexReaderError::HnswIndexConstructionError),
}
}
}
}

async fn posting_list_reader_from_id(
blockfile_id: &Uuid,
blockfile_provider: &BlockfileProvider,
) -> Result<BlockfileReader<'me, u32, SpannPostingList<'me>>, SpannIndexReaderError> {
match blockfile_provider
.read::<u32, SpannPostingList<'me>>(blockfile_id)
.await
{
Ok(reader) => Ok(reader),
Err(_) => Err(SpannIndexReaderError::BlockfileReaderConstructionError),
}
}

async fn versions_map_reader_from_id(
blockfile_id: &Uuid,
blockfile_provider: &BlockfileProvider,
) -> Result<BlockfileReader<'me, u32, u32>, SpannIndexReaderError> {
match blockfile_provider.read::<u32, u32>(blockfile_id).await {
Ok(reader) => Ok(reader),
Err(_) => Err(SpannIndexReaderError::BlockfileReaderConstructionError),
}
}

#[allow(clippy::too_many_arguments)]
pub async fn from_id(
[Review comment from a Collaborator on this function: "split this into two"]

hnsw_id: Option<&IndexUuid>,
hnsw_provider: &HnswIndexProvider,
hnsw_cache_key: &CollectionUuid,
distance_function: DistanceFunction,
dimensionality: usize,
pl_blockfile_id: Option<&Uuid>,
versions_map_blockfile_id: Option<&Uuid>,
blockfile_provider: &BlockfileProvider,
) -> Result<SpannIndexReader<'me>, SpannIndexReaderError> {
let hnsw_reader = match hnsw_id {
Some(hnsw_id) => {
Self::hnsw_index_from_id(
hnsw_provider,
hnsw_id,
hnsw_cache_key,
distance_function,
dimensionality,
)
.await?
}
None => {
return Err(SpannIndexReaderError::UninitializedIndex);
}
};
let postings_list_reader = match pl_blockfile_id {
Some(pl_id) => Self::posting_list_reader_from_id(pl_id, blockfile_provider).await?,
None => return Err(SpannIndexReaderError::UninitializedIndex),
};

let versions_map_reader = match versions_map_blockfile_id {
Some(versions_id) => {
Self::versions_map_reader_from_id(versions_id, blockfile_provider).await?
}
None => return Err(SpannIndexReaderError::UninitializedIndex),
};

Ok(Self {
posting_lists: postings_list_reader,
hnsw_index: hnsw_reader,
versions_map: versions_map_reader,
})
}
}

#[cfg(test)]
mod tests {
use std::{f32::consts::PI, path::PathBuf};
@@ -1556,22 +1674,32 @@ mod tests {
{
// Posting list should have 100 points.
let pl_read_guard = writer.posting_list_writer.lock().await;
- let pl = pl_read_guard
+ let pl1 = pl_read_guard
.get_owned::<u32, &SpannPostingList<'_>>("", emb_1_id)
.await
.expect("Error getting posting list")
.unwrap();
- assert_eq!(pl.0.len(), 100);
- assert_eq!(pl.1.len(), 100);
- assert_eq!(pl.2.len(), 200);
- let pl = pl_read_guard
+ let pl2 = pl_read_guard
.get_owned::<u32, &SpannPostingList<'_>>("", emb_2_id)
.await
.expect("Error getting posting list")
.unwrap();
- assert_eq!(pl.0.len(), 1);
- assert_eq!(pl.1.len(), 1);
- assert_eq!(pl.2.len(), 2);
+ // Only two combinations possible.
+ if pl1.0.len() == 100 {
+     assert_eq!(pl1.1.len(), 100);
+     assert_eq!(pl1.2.len(), 200);
+     assert_eq!(pl2.0.len(), 1);
+     assert_eq!(pl2.1.len(), 1);
+     assert_eq!(pl2.2.len(), 2);
+ } else if pl2.0.len() == 100 {
+     assert_eq!(pl2.1.len(), 100);
+     assert_eq!(pl2.2.len(), 200);
+     assert_eq!(pl1.0.len(), 1);
+     assert_eq!(pl1.1.len(), 1);
+     assert_eq!(pl1.2.len(), 2);
+ } else {
+     panic!("Invalid posting list lengths");
+ }
}
// Next insert 99 points in the region of (1000.0, 1000.0)
for i in 102..=200 {
@@ -1911,7 +2039,7 @@ mod tests {
version_map_guard.versions_map.insert(100 + point as u32, 1);
}
}
- // Delete 60 points each from the centers. Since merge_threshold is 40, this should
+ // Delete 60 points each from the centers. Since merge_threshold is 50, this should
// trigger a merge between the two centers.
for point in 1..=60 {
writer
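
The `from_id` constructor in this diff only succeeds when the HNSW index id, the posting-list blockfile id, and the versions-map blockfile id are all present; any missing id surfaces as `SpannIndexReaderError::UninitializedIndex` instead of a partially usable reader. Below is a minimal usage sketch, not part of the diff: the `open_spann_reader` helper and the `use` paths are assumptions for illustration, while `SpannIndexReader::from_id` and its error type come from this PR.

// Sketch only: module paths are assumed and may differ from the actual crate layout.
use chroma_blockstore::provider::BlockfileProvider;
use chroma_distance::DistanceFunction;
use chroma_index::hnsw_provider::HnswIndexProvider;
use chroma_index::spann::types::{SpannIndexReader, SpannIndexReaderError};
use chroma_index::IndexUuid;
use chroma_types::CollectionUuid;
use uuid::Uuid;

// Hypothetical helper that opens a reader over a previously flushed SPANN index.
#[allow(clippy::too_many_arguments)]
async fn open_spann_reader<'me>(
    hnsw_id: &IndexUuid,
    hnsw_provider: &HnswIndexProvider,
    collection_id: &CollectionUuid,
    distance_function: DistanceFunction,
    dimensionality: usize,
    pl_blockfile_id: &Uuid,
    versions_blockfile_id: &Uuid,
    blockfile_provider: &BlockfileProvider,
) -> Result<SpannIndexReader<'me>, SpannIndexReaderError> {
    // All three ids are wrapped in Some(..); from_id fails fast with
    // UninitializedIndex if any of them were None.
    SpannIndexReader::from_id(
        Some(hnsw_id),
        hnsw_provider,
        collection_id,
        distance_function,
        dimensionality,
        Some(pl_blockfile_id),
        Some(versions_blockfile_id),
        blockfile_provider,
    )
    .await
}

The `#[allow(clippy::too_many_arguments)]` on the helper mirrors the one already on `from_id` itself.
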
1 change: 1 addition & 0 deletions rust/types/src/spann_posting_list.rs
@@ -1,3 +1,4 @@
#[derive(Clone, Debug)]
pub struct SpannPostingList<'referred_data> {
pub doc_offset_ids: &'referred_data [u32],
pub doc_versions: &'referred_data [u32],
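
For orientation, a posting list is a set of parallel borrowed slices: one offset id and one version per document, plus a flattened embedding slice carrying `dimensionality` values per document (the test above checks 100 ids, 100 versions, and 200 embedding values at dimensionality 2). The remaining field(s) and closing brace of the struct are elided in this excerpt. A construction sketch follows, not part of the diff; the field name `doc_embeddings` for the elided slice is an assumption.

// Hypothetical constructor for illustration; `doc_embeddings` as the name of the
// elided third slice is an assumption inferred from the tests' length checks.
use chroma_types::SpannPostingList; // path assumed

fn make_posting_list<'a>(
    ids: &'a [u32],
    versions: &'a [u32],
    embeddings: &'a [f32],
    dimensionality: usize,
) -> SpannPostingList<'a> {
    assert_eq!(ids.len(), versions.len());
    assert_eq!(embeddings.len(), ids.len() * dimensionality);
    SpannPostingList {
        doc_offset_ids: ids,
        doc_versions: versions,
        doc_embeddings: embeddings,
    }
}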