diff --git a/rust/blockstore/src/types/writer.rs b/rust/blockstore/src/types/writer.rs index db85d60e675..e251b745454 100644 --- a/rust/blockstore/src/types/writer.rs +++ b/rust/blockstore/src/types/writer.rs @@ -80,6 +80,23 @@ impl BlockfileWriter { } } + pub async fn get_owned< + K: Key + Into + ArrowWriteableKey, + V: Value + Writeable + ArrowWriteableValue, + >( + &self, + prefix: &str, + key: K, + ) -> Result, Box> { + match self { + BlockfileWriter::MemoryBlockfileWriter(_) => todo!(), + BlockfileWriter::ArrowUnorderedBlockfileWriter(writer) => { + writer.get_owned::(prefix, key).await + } + BlockfileWriter::ArrowOrderedBlockfileWriter(_) => todo!(), + } + } + pub fn id(&self) -> uuid::Uuid { match self { BlockfileWriter::MemoryBlockfileWriter(writer) => writer.id(), diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index c78c64c36c0..1810ee98190 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -1,65 +1,164 @@ -use std::collections::HashMap; +use std::{ + collections::HashMap, + sync::{atomic::AtomicU32, Arc}, +}; -use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter, BlockfileWriterOptions}; -use chroma_distance::DistanceFunction; +use chroma_blockstore::{ + provider::{BlockfileProvider, CreateError, OpenError}, + BlockfileFlusher, BlockfileWriter, BlockfileWriterOptions, +}; +use chroma_distance::{normalize, DistanceFunction}; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{CollectionUuid, SpannPostingList}; +use chroma_types::CollectionUuid; +use chroma_types::SpannPostingList; use thiserror::Error; use uuid::Uuid; use crate::{ - hnsw_provider::{HnswIndexProvider, HnswIndexRef}, - IndexUuid, + hnsw_provider::{ + HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, HnswIndexRef, + }, + Index, IndexUuid, }; -// TODO(Sanket): Add locking structures as necessary. +pub struct VersionsMapInner { + pub versions_map: HashMap, +} + #[allow(dead_code)] +// Note: Fields of this struct are public for testing. pub struct SpannIndexWriter { // HNSW index and its provider for centroid search. - hnsw_index: HnswIndexRef, + pub hnsw_index: HnswIndexRef, hnsw_provider: HnswIndexProvider, + blockfile_provider: BlockfileProvider, // Posting list of the centroids. - // The blockfile also contains next id for the head. - posting_list_writer: BlockfileWriter, + // TODO(Sanket): For now the lock is very coarse grained. But this should + // be changed in future if perf is not satisfactory. + pub posting_list_writer: Arc>, + pub next_head_id: Arc, // Version number of each point. - versions_map: HashMap, + // TODO(Sanket): Finer grained locking for this map in future if perf is not satisfactory. + pub versions_map: Arc>, + pub distance_function: DistanceFunction, + pub dimensionality: usize, } +// TODO(Sanket): Can compose errors whenever downstream returns Box. #[derive(Error, Debug)] -pub enum SpannIndexWriterConstructionError { - #[error("HNSW index construction error")] - HnswIndexConstructionError, - #[error("Blockfile reader construction error")] - BlockfileReaderConstructionError, - #[error("Blockfile writer construction error")] - BlockfileWriterConstructionError, - #[error("Error loading version data from blockfile")] - BlockfileVersionDataLoadError, +pub enum SpannIndexWriterError { + #[error("Error forking hnsw index {0}")] + HnswIndexForkError(#[from] HnswIndexProviderForkError), + #[error("Error creating hnsw index {0}")] + HnswIndexCreateError(#[from] HnswIndexProviderCreateError), + #[error("Error opening reader for versions map blockfile {0}")] + VersionsMapOpenError(#[from] OpenError), + #[error("Error creating/forking postings list writer {0}")] + PostingsListCreateError(#[from] CreateError), + #[error("Error loading version data from blockfile {0}")] + VersionsMapDataLoadError(#[from] Box), + #[error("Error reading max offset id for heads")] + MaxHeadOffsetIdBlockfileGetError, + #[error("Error resizing hnsw index")] + HnswIndexResizeError, + #[error("Error adding to hnsw index")] + HnswIndexAddError, + #[error("Error searching from hnsw")] + HnswIndexSearchError, + #[error("Error adding posting list for a head")] + PostingListSetError, + #[error("Error getting the posting list for a head")] + PostingListGetError, + #[error("Did not find the version for head id")] + VersionNotFound, + #[error("Error committing postings list blockfile")] + PostingListCommitError, + #[error("Error creating blockfile writer for versions map")] + VersionsMapWriterCreateError, + #[error("Error writing data to versions map blockfile")] + VersionsMapSetError, + #[error("Error committing versions map blockfile")] + VersionsMapCommitError, + #[error("Error creating blockfile writer for max head id")] + MaxHeadIdWriterCreateError, + #[error("Error writing data to max head id blockfile")] + MaxHeadIdSetError, + #[error("Error committing max head id blockfile")] + MaxHeadIdCommitError, + #[error("Error committing hnsw index")] + HnswIndexCommitError, + #[error("Error flushing postings list blockfile")] + PostingListFlushError, + #[error("Error flushing versions map blockfile")] + VersionsMapFlushError, + #[error("Error flushing max head id blockfile")] + MaxHeadIdFlushError, + #[error("Error flushing hnsw index")] + HnswIndexFlushError, } -impl ChromaError for SpannIndexWriterConstructionError { +impl ChromaError for SpannIndexWriterError { fn code(&self) -> ErrorCodes { match self { - Self::HnswIndexConstructionError => ErrorCodes::Internal, - Self::BlockfileReaderConstructionError => ErrorCodes::Internal, - Self::BlockfileWriterConstructionError => ErrorCodes::Internal, - Self::BlockfileVersionDataLoadError => ErrorCodes::Internal, + Self::HnswIndexForkError(e) => e.code(), + Self::HnswIndexCreateError(e) => e.code(), + Self::VersionsMapOpenError(e) => e.code(), + Self::PostingsListCreateError(e) => e.code(), + Self::VersionsMapDataLoadError(e) => e.code(), + Self::MaxHeadOffsetIdBlockfileGetError => ErrorCodes::Internal, + Self::HnswIndexResizeError => ErrorCodes::Internal, + Self::HnswIndexAddError => ErrorCodes::Internal, + Self::PostingListSetError => ErrorCodes::Internal, + Self::HnswIndexSearchError => ErrorCodes::Internal, + Self::PostingListGetError => ErrorCodes::Internal, + Self::VersionNotFound => ErrorCodes::Internal, + Self::PostingListCommitError => ErrorCodes::Internal, + Self::VersionsMapSetError => ErrorCodes::Internal, + Self::VersionsMapCommitError => ErrorCodes::Internal, + Self::MaxHeadIdSetError => ErrorCodes::Internal, + Self::MaxHeadIdCommitError => ErrorCodes::Internal, + Self::HnswIndexCommitError => ErrorCodes::Internal, + Self::PostingListFlushError => ErrorCodes::Internal, + Self::VersionsMapFlushError => ErrorCodes::Internal, + Self::MaxHeadIdFlushError => ErrorCodes::Internal, + Self::HnswIndexFlushError => ErrorCodes::Internal, + Self::VersionsMapWriterCreateError => ErrorCodes::Internal, + Self::MaxHeadIdWriterCreateError => ErrorCodes::Internal, } } } +const MAX_HEAD_OFFSET_ID: &str = "max_head_offset_id"; + +// TODO(Sanket): Make these configurable. +#[allow(dead_code)] +const NUM_CENTROIDS_TO_SEARCH: u32 = 64; +#[allow(dead_code)] +const RNG_FACTOR: f32 = 1.0; +#[allow(dead_code)] +const SPLIT_THRESHOLD: usize = 100; + impl SpannIndexWriter { + #[allow(clippy::too_many_arguments)] pub fn new( hnsw_index: HnswIndexRef, hnsw_provider: HnswIndexProvider, + blockfile_provider: BlockfileProvider, posting_list_writer: BlockfileWriter, - versions_map: HashMap, + next_head_id: u32, + versions_map: VersionsMapInner, + distance_function: DistanceFunction, + dimensionality: usize, ) -> Self { SpannIndexWriter { hnsw_index, hnsw_provider, - posting_list_writer, - versions_map, + blockfile_provider, + posting_list_writer: Arc::new(tokio::sync::Mutex::new(posting_list_writer)), + next_head_id: Arc::new(AtomicU32::new(next_head_id)), + versions_map: Arc::new(parking_lot::RwLock::new(versions_map)), + distance_function, + dimensionality, } } @@ -69,13 +168,13 @@ impl SpannIndexWriter { collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, - ) -> Result { + ) -> Result { match hnsw_provider .fork(id, collection_id, dimensionality as i32, distance_function) .await { Ok(index) => Ok(index), - Err(_) => Err(SpannIndexWriterConstructionError::HnswIndexConstructionError), + Err(e) => Err(SpannIndexWriterError::HnswIndexForkError(*e)), } } @@ -87,7 +186,7 @@ impl SpannIndexWriter { m: usize, ef_construction: usize, ef_search: usize, - ) -> Result { + ) -> Result { match hnsw_provider .create( collection_id, @@ -100,37 +199,35 @@ impl SpannIndexWriter { .await { Ok(index) => Ok(index), - Err(_) => Err(SpannIndexWriterConstructionError::HnswIndexConstructionError), + Err(e) => Err(SpannIndexWriterError::HnswIndexCreateError(*e)), } } async fn load_versions_map( blockfile_id: &Uuid, blockfile_provider: &BlockfileProvider, - ) -> Result, SpannIndexWriterConstructionError> { + ) -> Result { // Create a reader for the blockfile. Load all the data into the versions map. let mut versions_map = HashMap::new(); let reader = match blockfile_provider.read::(blockfile_id).await { Ok(reader) => reader, - Err(_) => { - return Err(SpannIndexWriterConstructionError::BlockfileReaderConstructionError) - } + Err(e) => return Err(SpannIndexWriterError::VersionsMapOpenError(*e)), }; // Load data using the reader. let versions_data = reader .get_range(.., ..) .await - .map_err(|_| SpannIndexWriterConstructionError::BlockfileVersionDataLoadError)?; + .map_err(SpannIndexWriterError::VersionsMapDataLoadError)?; versions_data.iter().for_each(|(key, value)| { versions_map.insert(*key, *value); }); - Ok(versions_map) + Ok(VersionsMapInner { versions_map }) } async fn fork_postings_list( blockfile_id: &Uuid, blockfile_provider: &BlockfileProvider, - ) -> Result { + ) -> Result { let mut bf_options = BlockfileWriterOptions::new(); bf_options = bf_options.unordered_mutations(); bf_options = bf_options.fork(*blockfile_id); @@ -139,13 +236,13 @@ impl SpannIndexWriter { .await { Ok(writer) => Ok(writer), - Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError), + Err(e) => Err(SpannIndexWriterError::PostingsListCreateError(*e)), } } async fn create_posting_list( blockfile_provider: &BlockfileProvider, - ) -> Result { + ) -> Result { let mut bf_options = BlockfileWriterOptions::new(); bf_options = bf_options.unordered_mutations(); match blockfile_provider @@ -153,7 +250,7 @@ impl SpannIndexWriter { .await { Ok(writer) => Ok(writer), - Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError), + Err(e) => Err(SpannIndexWriterError::PostingsListCreateError(*e)), } } @@ -163,6 +260,7 @@ impl SpannIndexWriter { hnsw_id: Option<&IndexUuid>, versions_map_id: Option<&Uuid>, posting_list_id: Option<&Uuid>, + max_head_id_bf_id: Option<&Uuid>, m: Option, ef_construction: Option, ef_search: Option, @@ -170,7 +268,7 @@ impl SpannIndexWriter { distance_function: DistanceFunction, dimensionality: usize, blockfile_provider: &BlockfileProvider, - ) -> Result { + ) -> Result { // Create the HNSW index. let hnsw_index = match hnsw_id { Some(hnsw_id) => { @@ -178,7 +276,7 @@ impl SpannIndexWriter { hnsw_provider, hnsw_id, collection_id, - distance_function, + distance_function.clone(), dimensionality, ) .await? @@ -187,7 +285,7 @@ impl SpannIndexWriter { Self::create_hnsw_index( hnsw_provider, collection_id, - distance_function, + distance_function.clone(), dimensionality, m.unwrap(), // Safe since caller should always provide this. ef_construction.unwrap(), // Safe since caller should always provide this. @@ -201,7 +299,9 @@ impl SpannIndexWriter { Some(versions_map_id) => { Self::load_versions_map(versions_map_id, blockfile_provider).await? } - None => HashMap::new(), + None => VersionsMapInner { + versions_map: HashMap::new(), + }, }; // Fork the posting list writer. let posting_list_writer = match posting_list_id { @@ -210,11 +310,351 @@ impl SpannIndexWriter { } None => Self::create_posting_list(blockfile_provider).await?, }; + + let max_head_id = match max_head_id_bf_id { + Some(max_head_id_bf_id) => { + let reader = blockfile_provider + .read::<&str, u32>(max_head_id_bf_id) + .await; + match reader { + Ok(reader) => reader + .get("", MAX_HEAD_OFFSET_ID) + .await + .map_err(|_| SpannIndexWriterError::MaxHeadOffsetIdBlockfileGetError)? + .unwrap(), + Err(_) => 1, + } + } + None => 1, + }; Ok(Self::new( hnsw_index, hnsw_provider.clone(), + blockfile_provider.clone(), posting_list_writer, + max_head_id, versions_map, + distance_function, + dimensionality, )) } + + fn add_versions_map(&self, id: u32) -> u32 { + // 0 means deleted. Version counting starts from 1. + let mut write_lock = self.versions_map.write(); + write_lock.versions_map.insert(id, 1); + *write_lock.versions_map.get(&id).unwrap() + } + + #[allow(dead_code)] + async fn rng_query( + &self, + query: &[f32], + ) -> Result<(Vec, Vec), SpannIndexWriterError> { + let ids; + let distances; + let mut embeddings: Vec> = vec![]; + { + let read_guard = self.hnsw_index.inner.read(); + let allowed_ids = vec![]; + let disallowed_ids = vec![]; + // Query is already normalized so no need to normalize again. + (ids, distances) = read_guard + .query( + query, + NUM_CENTROIDS_TO_SEARCH as usize, + &allowed_ids, + &disallowed_ids, + ) + .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)?; + // Get the embeddings also for distance computation. + // Normalization is idempotent and since we write normalized embeddings + // to the hnsw index, we'll get the same embeddings after denormalization. + for id in ids.iter() { + let emb = read_guard + .get(*id) + .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)? + .ok_or(SpannIndexWriterError::HnswIndexSearchError)?; + embeddings.push(emb); + } + } + // Apply the RNG rule to prune. + let mut res_ids = vec![]; + let mut res_distances = vec![]; + let mut res_embeddings: Vec<&Vec> = vec![]; + for (id, (distance, embedding)) in ids.iter().zip(distances.iter().zip(embeddings.iter())) { + let mut rng_accepted = true; + for nbr_embedding in res_embeddings.iter() { + // Embeddings are already normalized so no need to normalize again. + let dist = self + .distance_function + .distance(&embedding[..], &nbr_embedding[..]); + if RNG_FACTOR * dist <= *distance { + rng_accepted = false; + break; + } + } + if !rng_accepted { + continue; + } + res_ids.push(*id); + res_distances.push(*distance); + res_embeddings.push(embedding); + } + + Ok((res_ids, res_distances)) + } + + #[allow(dead_code)] + async fn append( + &self, + head_id: u32, + id: u32, + version: u32, + embedding: &[f32], + ) -> Result<(), SpannIndexWriterError> { + { + let write_guard = self.posting_list_writer.lock().await; + // TODO(Sanket): Check if head is deleted, can happen if another concurrent thread + // deletes it. + let current_pl = write_guard + .get_owned::>("", head_id) + .await + .map_err(|_| SpannIndexWriterError::PostingListGetError)? + .ok_or(SpannIndexWriterError::PostingListGetError)?; + // Cleanup this posting list and append the new point to it. + // TODO(Sanket): There is an order in which we are acquiring locks here. Need + // to ensure the same order in the other places as well. + let mut updated_doc_offset_ids = vec![]; + let mut updated_versions = vec![]; + let mut updated_embeddings = vec![]; + { + let version_map_guard = self.versions_map.read(); + for (index, doc_version) in current_pl.1.iter().enumerate() { + let current_version = version_map_guard + .versions_map + .get(¤t_pl.0[index]) + .ok_or(SpannIndexWriterError::VersionNotFound)?; + // disregard if either deleted or on an older version. + if *current_version == 0 || doc_version < current_version { + continue; + } + updated_doc_offset_ids.push(current_pl.0[index]); + updated_versions.push(*doc_version); + // Slice. index*dimensionality to index*dimensionality + dimensionality + updated_embeddings.push( + ¤t_pl.2[index * self.dimensionality + ..index * self.dimensionality + self.dimensionality], + ); + } + } + // Add the new point. + updated_doc_offset_ids.push(id); + updated_versions.push(version); + updated_embeddings.push(embedding); + // TODO(Sanket): Trigger a split and reassign if the size exceeds threshold. + // Write the PL back to the blockfile and release the lock. + let posting_list = SpannPostingList { + doc_offset_ids: &updated_doc_offset_ids, + doc_versions: &updated_versions, + doc_embeddings: &updated_embeddings.concat(), + }; + // TODO(Sanket): Split if the size exceeds threshold. + write_guard + .set("", head_id, &posting_list) + .await + .map_err(|_| SpannIndexWriterError::PostingListSetError)?; + } + Ok(()) + } + + #[allow(dead_code)] + async fn add_to_postings_list( + &self, + id: u32, + version: u32, + embeddings: &[f32], + ) -> Result<(), SpannIndexWriterError> { + let (ids, _) = self.rng_query(embeddings).await?; + // The only cases when this can happen is initially when no data exists in the + // index or if all the data that was added to the index was deleted later. + // In both the cases, in the worst case, it can happen that ids is empty + // for the first few points getting inserted concurrently by different threads. + // It's fine to create new centers for each of them since the number of such points + // will be very small and we can also run GC to merge them later if needed. + if ids.is_empty() { + let next_id = self + .next_head_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + // First add to postings list then to hnsw. This order is important + // to ensure that if and when the center is discoverable, it also exists + // in the postings list. Otherwise, it will be a dangling center. + { + let posting_list = SpannPostingList { + doc_offset_ids: &[id], + doc_versions: &[version], + doc_embeddings: embeddings, + }; + let write_guard = self.posting_list_writer.lock().await; + write_guard + .set("", next_id, &posting_list) + .await + .map_err(|_| SpannIndexWriterError::PostingListSetError)?; + } + // Next add to hnsw. + // This shouldn't exceed the capacity since this will happen only for the first few points + // so no need to check and increase the capacity. + { + let write_guard = self.hnsw_index.inner.write(); + write_guard + .add(next_id as usize, embeddings) + .map_err(|_| SpannIndexWriterError::HnswIndexAddError)?; + } + return Ok(()); + } + // Otherwise add to the posting list of these arrays. + for head_id in ids.iter() { + self.append(*head_id as u32, id, version, embeddings) + .await?; + } + + Ok(()) + } + + pub async fn add(&self, id: u32, embedding: &[f32]) -> Result<(), SpannIndexWriterError> { + let version = self.add_versions_map(id); + // Normalize the embedding in case of cosine. + let mut normalized_embedding = embedding.to_vec(); + if self.distance_function == DistanceFunction::Cosine { + normalized_embedding = normalize(embedding); + } + // Add to the posting list. + self.add_to_postings_list(id, version, &normalized_embedding) + .await + } + + pub async fn commit(self) -> Result { + // Pl list. + let pl_flusher = match Arc::try_unwrap(self.posting_list_writer) { + Ok(writer) => writer + .into_inner() + .commit::>() + .await + .map_err(|_| SpannIndexWriterError::PostingListCommitError)?, + Err(_) => { + // This should never happen. + panic!("Failed to unwrap posting list writer"); + } + }; + // Versions map. Create a writer, write all the data and commit. + let mut bf_options = BlockfileWriterOptions::new(); + bf_options = bf_options.unordered_mutations(); + let versions_map_bf_writer = self + .blockfile_provider + .write::(bf_options) + .await + .map_err(|_| SpannIndexWriterError::VersionsMapWriterCreateError)?; + let versions_map_flusher = match Arc::try_unwrap(self.versions_map) { + Ok(writer) => { + let writer = writer.into_inner(); + for (doc_offset_id, doc_version) in writer.versions_map.into_iter() { + versions_map_bf_writer + .set("", doc_offset_id, doc_version) + .await + .map_err(|_| SpannIndexWriterError::VersionsMapSetError)?; + } + versions_map_bf_writer + .commit::() + .await + .map_err(|_| SpannIndexWriterError::VersionsMapCommitError)? + } + Err(_) => { + // This should never happen. + panic!("Failed to unwrap posting list writer"); + } + }; + // Next head. + let mut bf_options = BlockfileWriterOptions::new(); + bf_options = bf_options.unordered_mutations(); + let max_head_id_bf = self + .blockfile_provider + .write::<&str, u32>(bf_options) + .await + .map_err(|_| SpannIndexWriterError::MaxHeadIdWriterCreateError)?; + let max_head_id_flusher = match Arc::try_unwrap(self.next_head_id) { + Ok(value) => { + let value = value.into_inner(); + max_head_id_bf + .set("", MAX_HEAD_OFFSET_ID, value) + .await + .map_err(|_| SpannIndexWriterError::MaxHeadIdSetError)?; + max_head_id_bf + .commit::<&str, u32>() + .await + .map_err(|_| SpannIndexWriterError::MaxHeadIdCommitError)? + } + Err(_) => { + // This should never happen. + panic!("Failed to unwrap next head id"); + } + }; + + let hnsw_id = self.hnsw_index.inner.read().id; + + // Hnsw. + self.hnsw_provider + .commit(self.hnsw_index) + .map_err(|_| SpannIndexWriterError::HnswIndexCommitError)?; + + Ok(SpannIndexFlusher { + pl_flusher, + versions_map_flusher, + max_head_id_flusher, + hnsw_id, + hnsw_flusher: self.hnsw_provider, + }) + } +} + +pub struct SpannIndexFlusher { + pl_flusher: BlockfileFlusher, + versions_map_flusher: BlockfileFlusher, + max_head_id_flusher: BlockfileFlusher, + hnsw_id: IndexUuid, + hnsw_flusher: HnswIndexProvider, +} + +pub struct SpannIndexIds { + pub pl_id: Uuid, + pub versions_map_id: Uuid, + pub max_head_id_id: Uuid, + pub hnsw_id: IndexUuid, +} + +impl SpannIndexFlusher { + pub async fn flush(self) -> Result { + let res = SpannIndexIds { + pl_id: self.pl_flusher.id(), + versions_map_id: self.versions_map_flusher.id(), + max_head_id_id: self.max_head_id_flusher.id(), + hnsw_id: self.hnsw_id, + }; + self.pl_flusher + .flush::>() + .await + .map_err(|_| SpannIndexWriterError::PostingListFlushError)?; + self.versions_map_flusher + .flush::() + .await + .map_err(|_| SpannIndexWriterError::VersionsMapFlushError)?; + self.max_head_id_flusher + .flush::<&str, u32>() + .await + .map_err(|_| SpannIndexWriterError::MaxHeadIdFlushError)?; + self.hnsw_flusher + .flush(&self.hnsw_id) + .await + .map_err(|_| SpannIndexWriterError::HnswIndexFlushError)?; + Ok(res) + } } diff --git a/rust/worker/src/execution/operators/brute_force_knn.rs b/rust/worker/src/execution/operators/brute_force_knn.rs index ab1e0f5f2ec..b83aff738ed 100644 --- a/rust/worker/src/execution/operators/brute_force_knn.rs +++ b/rust/worker/src/execution/operators/brute_force_knn.rs @@ -1,9 +1,9 @@ use crate::execution::operator::Operator; -use crate::execution::operators::normalize_vectors::normalize; use crate::segment::record_segment::RecordSegmentReader; use crate::segment::{materialize_logs, LogMaterializerError}; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; +use chroma_distance::normalize; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::Chunk; diff --git a/rust/worker/src/execution/operators/normalize_vectors.rs b/rust/worker/src/execution/operators/normalize_vectors.rs index 631d787ff0c..03e884aa0d3 100644 --- a/rust/worker/src/execution/operators/normalize_vectors.rs +++ b/rust/worker/src/execution/operators/normalize_vectors.rs @@ -1,7 +1,6 @@ use crate::execution::operator::Operator; use async_trait::async_trait; - -const EPS: f32 = 1e-30; +use chroma_distance::normalize; #[derive(Debug)] pub struct NormalizeVectorOperator {} @@ -14,15 +13,6 @@ pub struct NormalizeVectorOperatorOutput { pub _normalized_vectors: Vec>, } -pub fn normalize(vector: &[f32]) -> Vec { - let mut norm = 0.0; - for x in vector { - norm += x * x; - } - let norm = 1.0 / (norm.sqrt() + EPS); - vector.iter().map(|x| x * norm).collect() -} - #[async_trait] impl Operator for NormalizeVectorOperator @@ -75,7 +65,7 @@ mod tests { let expected_output = NormalizeVectorOperatorOutput { _normalized_vectors: vec![ vec![0.26726124, 0.5345225, 0.8017837], - vec![0.45584232, 0.5698029, 0.6837635], + vec![0.45584232, 0.5698029, 0.68376344], vec![0.5025707, 0.5743665, 0.64616233], ], }; diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index 7cc7c955815..f7dba01e825 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -17,7 +17,6 @@ use crate::execution::operators::merge_knn_results::{ MergeKnnBruteForceResultInput, MergeKnnResultsOperator, MergeKnnResultsOperatorInput, MergeKnnResultsOperatorOutput, }; -use crate::execution::operators::normalize_vectors::normalize; use crate::execution::operators::pull_log::PullLogsOutput; use crate::execution::operators::record_segment_prefetch::{ Keys, OffsetIdToDataKeys, OffsetIdToUserIdKeys, RecordSegmentPrefetchIoInput, @@ -37,6 +36,7 @@ use crate::{ }; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; +use chroma_distance::normalize; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::hnsw_provider::HnswIndexProvider; diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs index 379cc918749..b27f6361162 100644 --- a/rust/worker/src/segment/record_segment.rs +++ b/rust/worker/src/segment/record_segment.rs @@ -1,3 +1,4 @@ +use super::spann_segment::SpannSegmentWriterError; use super::types::{MaterializedLogRecord, SegmentWriter}; use super::SegmentFlusher; use async_trait::async_trait; @@ -322,6 +323,8 @@ pub enum ApplyMaterializedLogError { FullTextIndex(#[from] FullTextIndexError), #[error("Error writing to hnsw index")] HnswIndex(#[from] Box), + #[error("Error applying materialized records to spann segment: {0}")] + SpannSegmentError(#[from] SpannSegmentWriterError), } impl ChromaError for ApplyMaterializedLogError { @@ -333,6 +336,7 @@ impl ChromaError for ApplyMaterializedLogError { ApplyMaterializedLogError::Allocation => ErrorCodes::Internal, ApplyMaterializedLogError::FullTextIndex(e) => e.code(), ApplyMaterializedLogError::HnswIndex(_) => ErrorCodes::Internal, + ApplyMaterializedLogError::SpannSegmentError(e) => e.code(), } } } diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index f15bee2953e..1a037a0c007 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -1,41 +1,61 @@ +use std::collections::HashMap; + use chroma_blockstore::provider::BlockfileProvider; +use chroma_distance::DistanceFunctionError; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter, IndexUuid}; -use chroma_types::{Segment, SegmentScope, SegmentType, SegmentUuid}; +use chroma_index::spann::types::{SpannIndexFlusher, SpannIndexWriterError}; +use chroma_index::IndexUuid; +use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; +use chroma_types::SegmentUuid; +use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType}; use thiserror::Error; +use tonic::async_trait; use uuid::Uuid; -use super::utils::{distance_function_from_segment, hnsw_params_from_segment}; +use super::{ + record_segment::ApplyMaterializedLogError, + utils::{distance_function_from_segment, hnsw_params_from_segment}, + MaterializedLogRecord, SegmentFlusher, SegmentWriter, +}; -#[allow(dead_code)] const HNSW_PATH: &str = "hnsw_path"; -#[allow(dead_code)] const VERSION_MAP_PATH: &str = "version_map_path"; -#[allow(dead_code)] const POSTING_LIST_PATH: &str = "posting_list_path"; +const MAX_HEAD_ID_BF_PATH: &str = "max_head_id_path"; -#[allow(dead_code)] pub(crate) struct SpannSegmentWriter { index: SpannIndexWriter, + #[allow(dead_code)] id: SegmentUuid, } +// TODO(Sanket): Better error composability here. #[derive(Error, Debug)] pub enum SpannSegmentWriterError { #[error("Invalid argument")] InvalidArgument, - #[error("Distance function not found")] - DistanceFunctionNotFound, - #[error("Hnsw index id parsing error")] + #[error("Segment metadata does not contain distance function {0}")] + DistanceFunctionNotFound(#[from] DistanceFunctionError), + #[error("Error parsing index uuid from string")] IndexIdParsingError, - #[error("Hnsw Invalid file path")] + #[error("Invalid file path for HNSW index")] HnswInvalidFilePath, - #[error("Version map Invalid file path")] + #[error("Invalid file path for version map")] VersionMapInvalidFilePath, - #[error("Postings list invalid file path")] + #[error("Invalid file path for posting list")] PostingListInvalidFilePath, - #[error("Spann index creation error")] - SpannIndexWriterConstructionError, + #[error("Invalid file path for max head id")] + MaxHeadIdInvalidFilePath, + #[error("Error constructing spann index writer")] + SpannSegmentWriterCreateError, + #[error("Error adding record to spann index writer {0}")] + SpannSegmentWriterAddRecordError(#[from] SpannIndexWriterError), + #[error("Error committing spann index writer")] + SpannSegmentWriterCommitError, + #[error("Error flushing spann index writer")] + SpannSegmentWriterFlushError, + #[error("Not implemented")] + NotImplemented, } impl ChromaError for SpannSegmentWriterError { @@ -43,11 +63,16 @@ impl ChromaError for SpannSegmentWriterError { match self { Self::InvalidArgument => ErrorCodes::InvalidArgument, Self::IndexIdParsingError => ErrorCodes::Internal, - Self::DistanceFunctionNotFound => ErrorCodes::Internal, + Self::DistanceFunctionNotFound(e) => e.code(), Self::HnswInvalidFilePath => ErrorCodes::Internal, Self::VersionMapInvalidFilePath => ErrorCodes::Internal, Self::PostingListInvalidFilePath => ErrorCodes::Internal, - Self::SpannIndexWriterConstructionError => ErrorCodes::Internal, + Self::SpannSegmentWriterCreateError => ErrorCodes::Internal, + Self::MaxHeadIdInvalidFilePath => ErrorCodes::Internal, + Self::NotImplemented => ErrorCodes::Internal, + Self::SpannSegmentWriterCommitError => ErrorCodes::Internal, + Self::SpannSegmentWriterFlushError => ErrorCodes::Internal, + Self::SpannSegmentWriterAddRecordError(e) => e.code(), } } } @@ -65,8 +90,8 @@ impl SpannSegmentWriter { } let distance_function = match distance_function_from_segment(segment) { Ok(distance_function) => distance_function, - Err(_) => { - return Err(SpannSegmentWriterError::DistanceFunctionNotFound); + Err(e) => { + return Err(SpannSegmentWriterError::DistanceFunctionNotFound(*e)); } }; let (hnsw_id, m, ef_construction, ef_search) = match segment.file_path.get(HNSW_PATH) { @@ -78,19 +103,21 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::IndexIdParsingError); } }; - let hnsw_params = hnsw_params_from_segment(segment); - ( - Some(IndexUuid(index_uuid)), - Some(hnsw_params.m), - Some(hnsw_params.ef_construction), - Some(hnsw_params.ef_search), - ) + (Some(IndexUuid(index_uuid)), None, None, None) } None => { return Err(SpannSegmentWriterError::HnswInvalidFilePath); } }, - None => (None, None, None, None), + None => { + let hnsw_params = hnsw_params_from_segment(segment); + ( + None, + Some(hnsw_params.m), + Some(hnsw_params.ef_construction), + Some(hnsw_params.ef_search), + ) + } }; let versions_map_id = match segment.file_path.get(VERSION_MAP_PATH) { Some(version_map_path) => match version_map_path.first() { @@ -127,11 +154,30 @@ impl SpannSegmentWriter { None => None, }; + let max_head_id_bf_id = match segment.file_path.get(MAX_HEAD_ID_BF_PATH) { + Some(max_head_id_bf_path) => match max_head_id_bf_path.first() { + Some(max_head_id_bf_id) => { + let max_head_id_bf_uuid = match Uuid::parse_str(max_head_id_bf_id) { + Ok(uuid) => uuid, + Err(_) => { + return Err(SpannSegmentWriterError::IndexIdParsingError); + } + }; + Some(max_head_id_bf_uuid) + } + None => { + return Err(SpannSegmentWriterError::MaxHeadIdInvalidFilePath); + } + }, + None => None, + }; + let index_writer = match SpannIndexWriter::from_id( hnsw_provider, hnsw_id.as_ref(), versions_map_id.as_ref(), posting_list_id.as_ref(), + max_head_id_bf_id.as_ref(), m, ef_construction, ef_search, @@ -144,7 +190,7 @@ impl SpannSegmentWriter { { Ok(index_writer) => index_writer, Err(_) => { - return Err(SpannSegmentWriterError::SpannIndexWriterConstructionError); + return Err(SpannSegmentWriterError::SpannSegmentWriterCreateError); } }; @@ -153,4 +199,287 @@ impl SpannSegmentWriter { id: segment.id, }) } + + async fn add(&self, record: &MaterializedLogRecord<'_>) -> Result<(), SpannSegmentWriterError> { + self.index + .add(record.offset_id, record.merged_embeddings()) + .await + .map_err(SpannSegmentWriterError::SpannSegmentWriterAddRecordError) + } +} + +struct SpannSegmentFlusher { + index_flusher: SpannIndexFlusher, +} + +impl<'referred_data> SegmentWriter<'referred_data> for SpannSegmentWriter { + async fn apply_materialized_log_chunk( + &self, + records: chroma_types::Chunk>, + ) -> Result<(), ApplyMaterializedLogError> { + for (record, _) in records.iter() { + match record.final_operation { + MaterializedLogOperation::AddNew => { + self.add(record) + .await + .map_err(ApplyMaterializedLogError::SpannSegmentError)?; + } + // TODO(Sanket): Implement other operations. + _ => { + todo!() + } + } + } + Ok(()) + } + + async fn commit(self) -> Result> { + let index_flusher = self + .index + .commit() + .await + .map_err(|_| SpannSegmentWriterError::SpannSegmentWriterCommitError); + match index_flusher { + Err(e) => Err(Box::new(e)), + Ok(index_flusher) => Ok(SpannSegmentFlusher { index_flusher }), + } + } +} + +#[async_trait] +impl SegmentFlusher for SpannSegmentFlusher { + async fn flush(self) -> Result>, Box> { + let index_flusher_res = self + .index_flusher + .flush() + .await + .map_err(|_| SpannSegmentWriterError::SpannSegmentWriterFlushError); + match index_flusher_res { + Err(e) => Err(Box::new(e)), + Ok(index_ids) => { + let mut index_id_map = HashMap::new(); + index_id_map.insert(HNSW_PATH.to_string(), vec![index_ids.hnsw_id.to_string()]); + index_id_map.insert( + VERSION_MAP_PATH.to_string(), + vec![index_ids.versions_map_id.to_string()], + ); + index_id_map.insert( + POSTING_LIST_PATH.to_string(), + vec![index_ids.pl_id.to_string()], + ); + index_id_map.insert( + MAX_HEAD_ID_BF_PATH.to_string(), + vec![index_ids.max_head_id_id.to_string()], + ); + Ok(index_id_map) + } + } + } +} + +#[cfg(test)] +mod test { + use std::{collections::HashMap, path::PathBuf}; + + use chroma_blockstore::{ + arrow::{config::TEST_MAX_BLOCK_SIZE_BYTES, provider::ArrowBlockfileProvider}, + provider::BlockfileProvider, + }; + use chroma_cache::{new_cache_for_test, new_non_persistent_cache_for_test}; + use chroma_distance::DistanceFunction; + use chroma_index::{hnsw_provider::HnswIndexProvider, Index}; + use chroma_storage::{local::LocalStorage, Storage}; + use chroma_types::{ + Chunk, CollectionUuid, LogRecord, Metadata, MetadataValue, Operation, OperationRecord, + SegmentUuid, SpannPostingList, + }; + + use crate::segment::{ + materialize_logs, spann_segment::SpannSegmentWriter, SegmentFlusher, SegmentWriter, + }; + + #[tokio::test] + async fn test_spann_segment_writer() { + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let collection_id = CollectionUuid::new(); + let segment_id = SegmentUuid::new(); + let mut metadata_hash_map = Metadata::new(); + metadata_hash_map.insert( + "hnsw:space".to_string(), + MetadataValue::Str("l2".to_string()), + ); + metadata_hash_map.insert("hnsw:M".to_string(), MetadataValue::Int(16)); + metadata_hash_map.insert("hnsw:construction_ef".to_string(), MetadataValue::Int(100)); + metadata_hash_map.insert("hnsw:search_ef".to_string(), MetadataValue::Int(100)); + let mut spann_segment = chroma_types::Segment { + id: segment_id, + collection: collection_id, + r#type: chroma_types::SegmentType::Spann, + scope: chroma_types::SegmentScope::VECTOR, + metadata: Some(metadata_hash_map), + file_path: HashMap::new(), + }; + let spann_writer = SpannSegmentWriter::from_segment( + &spann_segment, + &blockfile_provider, + &hnsw_provider, + 3, + ) + .await + .expect("Error creating spann segment writer"); + let data = vec![ + LogRecord { + log_offset: 1, + record: OperationRecord { + id: "embedding_id_1".to_string(), + embedding: Some(vec![1.0, 2.0, 3.0]), + encoding: None, + metadata: None, + document: Some(String::from("This is a document about cats.")), + operation: Operation::Add, + }, + }, + LogRecord { + log_offset: 2, + record: OperationRecord { + id: "embedding_id_2".to_string(), + embedding: Some(vec![4.0, 5.0, 6.0]), + encoding: None, + metadata: None, + document: Some(String::from("This is a document about dogs.")), + operation: Operation::Add, + }, + }, + ]; + let chunked_log = Chunk::new(data.into()); + // Materialize the logs. + let materialized_log = materialize_logs(&None, &chunked_log, None) + .await + .expect("Error materializing logs"); + spann_writer + .apply_materialized_log_chunk(materialized_log) + .await + .expect("Error applying materialized log"); + let flusher = spann_writer + .commit() + .await + .expect("Error committing spann writer"); + spann_segment.file_path = flusher.flush().await.expect("Error flushing spann writer"); + assert_eq!(spann_segment.file_path.len(), 4); + assert!(spann_segment.file_path.contains_key("hnsw_path")); + assert!(spann_segment.file_path.contains_key("version_map_path"),); + assert!(spann_segment.file_path.contains_key("posting_list_path"),); + assert!(spann_segment.file_path.contains_key("max_head_id_path"),); + // Load this segment and check if the embeddings are present. New cache + // so that the previous cache is not used. + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage, + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let spann_writer = SpannSegmentWriter::from_segment( + &spann_segment, + &blockfile_provider, + &hnsw_provider, + 3, + ) + .await + .expect("Error creating spann segment writer"); + assert_eq!(spann_writer.index.dimensionality, 3); + assert_eq!( + spann_writer.index.distance_function, + DistanceFunction::Euclidean + ); + // Next head id should be 2 since one centroid is already taken up. + assert_eq!( + spann_writer + .index + .next_head_id + .load(std::sync::atomic::Ordering::SeqCst), + 2 + ); + { + let read_guard = spann_writer.index.versions_map.read(); + assert_eq!(read_guard.versions_map.len(), 2); + assert_eq!( + *read_guard + .versions_map + .get(&1) + .expect("Doc offset id 1 not found"), + 1 + ); + assert_eq!( + *read_guard + .versions_map + .get(&2) + .expect("Doc offset id 2 not found"), + 1 + ); + } + { + // Test HNSW. + let hnsw_index = spann_writer.index.hnsw_index.inner.read(); + assert_eq!(hnsw_index.len(), 1); + let r = hnsw_index + .get(1) + .expect("Expect one centroid") + .expect("Expect centroid embedding"); + assert_eq!(r.len(), 3); + assert_eq!(r[0], 1.0); + assert_eq!(r[1], 2.0); + assert_eq!(r[2], 3.0); + } + // Test PL. + let read_guard = spann_writer.index.posting_list_writer.lock().await; + let res = read_guard + .get_owned::>("", 1) + .await + .expect("Expected posting list to be present") + .expect("Expected posting list to be present"); + assert_eq!(res.0.len(), 2); + assert_eq!(res.1.len(), 2); + assert_eq!(res.2.len(), 6); + assert_eq!(res.0[0], 1); + assert_eq!(res.0[1], 2); + assert_eq!(res.1[0], 1); + assert_eq!(res.1[1], 1); + assert_eq!(res.2[0], 1.0); + assert_eq!(res.2[1], 2.0); + assert_eq!(res.2[2], 3.0); + assert_eq!(res.2[3], 4.0); + assert_eq!(res.2[4], 5.0); + assert_eq!(res.2[5], 6.0); + } }