diff --git a/rust/index/src/hnsw_provider.rs b/rust/index/src/hnsw_provider.rs index cc1c510d941..a69d7778731 100644 --- a/rust/index/src/hnsw_provider.rs +++ b/rust/index/src/hnsw_provider.rs @@ -33,11 +33,7 @@ const FILES: [&str; 4] = [ "link_lists.bin", ]; -pub type HnswIndexParams = ( - usize, /* m */ - usize, /* ef_construction */ - usize, /* ef_search */ -); +type CacheKey = CollectionUuid; // The key of the cache is the collection id and the value is // the HNSW index for that collection. This restricts the cache to @@ -132,12 +128,8 @@ impl HnswIndexProvider { } } - pub async fn get( - &self, - index_id: &IndexUuid, - collection_id: &CollectionUuid, - ) -> Option { - match self.cache.get(collection_id).await.ok().flatten() { + pub async fn get(&self, index_id: &IndexUuid, cache_key: &CacheKey) -> Option { + match self.cache.get(cache_key).await.ok().flatten() { Some(index) => { let index_with_lock = index.inner.read(); if index_with_lock.id == *index_id { @@ -158,7 +150,7 @@ impl HnswIndexProvider { pub async fn fork( &self, source_id: &IndexUuid, - collection_id: &CollectionUuid, + cache_key: &CacheKey, dimensionality: i32, distance_function: DistanceFunction, ) -> Result> { @@ -197,13 +189,13 @@ impl HnswIndexProvider { match HnswIndex::load(storage_path_str, &index_config, new_id) { Ok(index) => { let _guard = self.write_mutex.lock().await; - match self.get(&new_id, collection_id).await { + match self.get(&new_id, cache_key).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(*collection_id, index.clone()).await; + self.cache.insert(*cache_key, index.clone()).await; Ok(index) } } @@ -288,7 +280,7 @@ impl HnswIndexProvider { pub async fn open( &self, id: &IndexUuid, - collection_id: &CollectionUuid, + cache_key: &CacheKey, dimensionality: i32, distance_function: DistanceFunction, ) -> Result> { @@ -327,13 +319,13 @@ impl HnswIndexProvider { match HnswIndex::load(index_storage_path_str, &index_config, *id) { Ok(index) => { let _guard = self.write_mutex.lock().await; - match self.get(id, collection_id).await { + match self.get(id, cache_key).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(*collection_id, index.clone()).await; + self.cache.insert(*cache_key, index.clone()).await; Ok(index) } } @@ -354,9 +346,10 @@ impl HnswIndexProvider { // A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id pub async fn create( &self, - collection_id: &CollectionUuid, - hnsw_params: HnswIndexParams, - persist_path: &std::path::Path, + cache_key: &CacheKey, + m: usize, + ef_construction: usize, + ef_search: usize, dimensionality: i32, distance_function: DistanceFunction, ) -> Result> { @@ -373,7 +366,7 @@ impl HnswIndexProvider { let index_config = IndexConfig::new(dimensionality, distance_function); let hnsw_config = - match HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, persist_path) { + match HnswIndexConfig::new(m, ef_construction, ef_search, &index_storage_path) { Ok(hnsw_config) => hnsw_config, Err(e) => { return Err(Box::new(HnswIndexProviderCreateError::HnswConfigError(*e))); @@ -385,13 +378,13 @@ impl HnswIndexProvider { .map_err(|e| Box::new(HnswIndexProviderCreateError::IndexInitError(e)))?; let _guard = self.write_mutex.lock().await; - match self.get(&id, collection_id).await { + match self.get(&id, cache_key).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(*collection_id, index.clone()).await; + self.cache.insert(*cache_key, index.clone()).await; Ok(index) } } @@ -430,8 +423,8 @@ impl HnswIndexProvider { } /// Purge entries from the cache by index ID and remove temporary files from disk. - pub async fn purge_by_id(&mut self, collection_uuids: &[CollectionUuid]) { - for collection_uuid in collection_uuids { + pub async fn purge_by_id(&mut self, cache_keys: &[CacheKey]) { + for collection_uuid in cache_keys { let Some(index_id) = self .cache .get(collection_uuid) @@ -615,17 +608,13 @@ mod tests { let collection_id = CollectionUuid(Uuid::new_v4()); let dimensionality = 128; - let hnsw_params = ( - DEFAULT_HNSW_M, - DEFAULT_HNSW_EF_CONSTRUCTION, - DEFAULT_HNSW_EF_SEARCH, - ); let distance_function = DistanceFunction::Euclidean; let created_index = provider .create( &collection_id, - hnsw_params, - &provider.temporary_storage_path, + DEFAULT_HNSW_M, + DEFAULT_HNSW_EF_CONSTRUCTION, + DEFAULT_HNSW_EF_SEARCH, dimensionality, distance_function.clone(), ) diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index c79d32028b9..bff4a208b86 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use chroma_distance::{DistanceFunction, DistanceFunctionError}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::hnsw_provider::{ - HnswIndexParams, HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, + HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, HnswIndexProviderOpenError, HnswIndexRef, }; use chroma_index::{Index, IndexUuid}; @@ -18,6 +18,12 @@ use uuid::Uuid; const HNSW_INDEX: &str = "hnsw_index"; +pub struct HnswIndexParamsFromSegment { + pub m: usize, + pub ef_construction: usize, + pub ef_search: usize, +} + #[derive(Clone)] pub(crate) struct DistributedHNSWSegmentWriter { index: HnswIndexRef, @@ -65,15 +71,15 @@ impl ChromaError for DistributedHNSWSegmentFromSegmentError { } } -fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams { +fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParamsFromSegment { let metadata = match &segment.metadata { Some(metadata) => metadata, None => { - return ( - DEFAULT_HNSW_M, - DEFAULT_HNSW_EF_CONSTRUCTION, - DEFAULT_HNSW_EF_SEARCH, - ); + return HnswIndexParamsFromSegment { + m: DEFAULT_HNSW_M, + ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION, + ef_search: DEFAULT_HNSW_EF_SEARCH, + }; } }; @@ -90,7 +96,11 @@ fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams { Err(_) => DEFAULT_HNSW_EF_SEARCH, }; - (m, ef_construction, ef_search) + HnswIndexParamsFromSegment { + m, + ef_construction, + ef_search, + } } pub fn distance_function_from_segment( @@ -130,7 +140,6 @@ impl DistributedHNSWSegmentWriter { hnsw_index_provider: HnswIndexProvider, ) -> Result, Box> { - let persist_path = &hnsw_index_provider.temporary_storage_path; // TODO: this is hacky, we use the presence of files to determine if we need to load or create the index // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this @@ -205,8 +214,9 @@ impl DistributedHNSWSegmentWriter { let index = match hnsw_index_provider .create( &segment.collection, - hnsw_params, - persist_path, + hnsw_params.m, + hnsw_params.ef_construction, + hnsw_params.ef_search, dimensionality as i32, distance_function, ) @@ -445,9 +455,13 @@ pub mod test { }; let hnsw_params = hnsw_params_from_segment(&segment); - let config = - HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, &persist_path) - .expect("Error creating hnsw index config"); + let config = HnswIndexConfig::new( + hnsw_params.m, + hnsw_params.ef_construction, + hnsw_params.ef_search, + &persist_path, + ) + .expect("Error creating hnsw index config"); assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); assert_eq!(config.m, DEFAULT_HNSW_M); @@ -470,9 +484,13 @@ pub mod test { }; let hnsw_params = hnsw_params_from_segment(&segment); - let config = - HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, &persist_path) - .expect("Error creating hnsw index config"); + let config = HnswIndexConfig::new( + hnsw_params.m, + hnsw_params.ef_construction, + hnsw_params.ef_search, + &persist_path, + ) + .expect("Error creating hnsw index config"); assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); assert_eq!(config.m, 10);