NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Line breaks follow rustfmt conventions; the indentation of context lines in the
first two hunks is a best guess -- apply with `git apply --ignore-whitespace`.

diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs
index 6c116e5ae884..d30476326cf8 100644
--- a/rust/index/src/spann/types.rs
+++ b/rust/index/src/spann/types.rs
@@ -1114,7 +1114,7 @@ impl SpannIndexWriter {
             let (nearest_head_ids, _, nearest_head_embeddings) = self
                 .get_nearby_heads(head_embedding, NUM_CENTERS_TO_MERGE_TO)
                 .await?;
-            for (nearest_head_id, head_embedding) in nearest_head_ids
+            for (nearest_head_id, nearest_head_embedding) in nearest_head_ids
                 .into_iter()
                 .zip(nearest_head_embeddings.into_iter())
             {
@@ -1183,7 +1183,7 @@ impl SpannIndexWriter {
                 }
                 // This center is now merged with a neighbor.
                 target_head = nearest_head_id;
-                target_embedding = head_embedding;
+                target_embedding = nearest_head_embedding;
                 merged_with_a_nbr = true;
                 break;
             }
@@ -2206,4 +2206,244 @@ mod tests {
             }
         }
     }
+
+    #[tokio::test]
+    async fn test_reassign_merge() {
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
+        let block_cache = new_cache_for_test();
+        let sparse_index_cache = new_cache_for_test();
+        let arrow_blockfile_provider = ArrowBlockfileProvider::new(
+            storage.clone(),
+            TEST_MAX_BLOCK_SIZE_BYTES,
+            block_cache,
+            sparse_index_cache,
+        );
+        let blockfile_provider =
+            BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider);
+        let hnsw_cache = new_non_persistent_cache_for_test();
+        let (_, rx) = tokio::sync::mpsc::unbounded_channel();
+        let hnsw_provider = HnswIndexProvider::new(
+            storage.clone(),
+            PathBuf::from(tmp_dir.path().to_str().unwrap()),
+            hnsw_cache,
+            rx,
+        );
+        let m = 16;
+        let ef_construction = 200;
+        let ef_search = 200;
+        let collection_id = CollectionUuid::new();
+        let distance_function = chroma_distance::DistanceFunction::Euclidean;
+        let dimensionality = 2;
+        let writer = SpannIndexWriter::from_id(
+            &hnsw_provider,
+            None,
+            None,
+            None,
+            None,
+            Some(m),
+            Some(ef_construction),
+            Some(ef_search),
+            &collection_id,
+            distance_function,
+            dimensionality,
+            &blockfile_provider,
+        )
+        .await
+        .expect("Error creating spann index writer");
+        // Create three centers. Heads 1 and 3 sit on top of their points; head 2 is
+        // ill placed: the 20 points in its posting list actually lie near head 3.
+        {
+            let hnsw_guard = writer.hnsw_index.inner.write();
+            hnsw_guard
+                .add(1, &[0.0, 0.0])
+                .expect("Error adding to hnsw index");
+            hnsw_guard
+                .add(2, &[1000.0, 1000.0])
+                .expect("Error adding to hnsw index");
+            hnsw_guard
+                .add(3, &[10000.0, 10000.0])
+                .expect("Error adding to hnsw index");
+        }
+        let mut doc_offset_ids1 = vec![0u32; 70];
+        let mut doc_versions1 = vec![0u32; 70];
+        let mut doc_embeddings1 = vec![0.0; 140];
+        let mut doc_offset_ids2 = vec![0u32; 20];
+        let mut doc_versions2 = vec![0u32; 20];
+        let mut doc_embeddings2 = vec![0.0; 40];
+        let mut doc_offset_ids3 = vec![0u32; 70];
+        let mut doc_versions3 = vec![0u32; 70];
+        let mut doc_embeddings3 = vec![0.0; 140];
+        {
+            let mut rng = rand::thread_rng();
+            let pl_guard = writer.posting_list_writer.lock().await;
+            // Insert 70 points within a radius of 1 of center 1 into posting list 1.
+            for i in 1..=70 {
+                // Generate random radius between 0 and 1
+                let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution
+
+                // Generate random angle between 0 and 2π
+                let theta = rng.gen::<f32>() * 2.0 * PI;
+
+                // Convert to Cartesian coordinates
+                let x = r * theta.cos();
+                let y = r * theta.sin();
+
+                doc_offset_ids1[i - 1] = i as u32;
+                doc_versions1[i - 1] = 1;
+                doc_embeddings1[(i - 1) * 2] = x;
+                doc_embeddings1[(i - 1) * 2 + 1] = y;
+            }
+            // Insert 20 points within a radius of 1 of center 3 into posting list 2.
+            for i in 71..=90 {
+                // Generate random radius between 0 and 1
+                let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution
+
+                // Generate random angle between 0 and 2π
+                let theta = rng.gen::<f32>() * 2.0 * PI;
+
+                // Convert to Cartesian coordinates
+                let x = r * theta.cos() + 10000.0;
+                let y = r * theta.sin() + 10000.0;
+
+                doc_offset_ids2[i - 71] = i as u32;
+                doc_versions2[i - 71] = 1;
+                doc_embeddings2[(i - 71) * 2] = x;
+                doc_embeddings2[(i - 71) * 2 + 1] = y;
+            }
+            // Insert 70 points within a radius of 1 of center 3 into posting list 3.
+            for i in 91..=160 {
+                // Generate random radius between 0 and 1
+                let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution
+
+                // Generate random angle between 0 and 2π
+                let theta = rng.gen::<f32>() * 2.0 * PI;
+
+                // Convert to Cartesian coordinates
+                let x = r * theta.cos() + 10000.0;
+                let y = r * theta.sin() + 10000.0;
+
+                doc_offset_ids3[i - 91] = i as u32;
+                doc_versions3[i - 91] = 1;
+                doc_embeddings3[(i - 91) * 2] = x;
+                doc_embeddings3[(i - 91) * 2 + 1] = y;
+            }
+            let spann_posting_list = SpannPostingList {
+                doc_offset_ids: &doc_offset_ids1,
+                doc_versions: &doc_versions1,
+                doc_embeddings: &doc_embeddings1,
+            };
+            pl_guard
+                .set("", 1, &spann_posting_list)
+                .await
+                .expect("Error writing to posting list");
+            let spann_posting_list = SpannPostingList {
+                doc_offset_ids: &doc_offset_ids2,
+                doc_versions: &doc_versions2,
+                doc_embeddings: &doc_embeddings2,
+            };
+            pl_guard
+                .set("", 2, &spann_posting_list)
+                .await
+                .expect("Error writing to posting list");
+            let spann_posting_list = SpannPostingList {
+                doc_offset_ids: &doc_offset_ids3,
+                doc_versions: &doc_versions3,
+                doc_embeddings: &doc_embeddings3,
+            };
+            pl_guard
+                .set("", 3, &spann_posting_list)
+                .await
+                .expect("Error writing to posting list");
+        }
+        // Initialize the versions map appropriately.
+        {
+            let mut version_map_guard = writer.versions_map.write();
+            for i in 1..=160 {
+                version_map_guard.versions_map.insert(i as u32, 1);
+            }
+        }
+        // Run a GC now.
+        writer
+            .garbage_collect()
+            .await
+            .expect("Error garbage collecting");
+        // Run GC again to clean up the outdated points.
+        writer
+            .garbage_collect()
+            .await
+            .expect("Error garbage collecting");
+        // Check the posting lists.
+        {
+            let pl_guard = writer.posting_list_writer.lock().await;
+            let pl = pl_guard
+                .get_owned::<u32, &SpannPostingList<'_>>("", 1)
+                .await
+                .expect("Error getting posting list")
+                .unwrap();
+            assert_eq!(pl.0.len(), 70);
+            assert_eq!(pl.1.len(), 70);
+            assert_eq!(pl.2.len(), 140);
+            for point in 1..=70 {
+                assert_eq!(pl.0[point - 1], point as u32);
+                assert_eq!(pl.1[point - 1], 1);
+                assert_eq!(pl.2[(point - 1) * 2], doc_embeddings1[(point - 1) * 2]);
+                assert_eq!(
+                    pl.2[(point - 1) * 2 + 1],
+                    doc_embeddings1[(point - 1) * 2 + 1]
+                );
+            }
+            let pl = pl_guard
+                .get_owned::<u32, &SpannPostingList<'_>>("", 3)
+                .await
+                .expect("Error getting posting list")
+                .unwrap();
+            // PL3 should hold 90 points: its own 70 followed by the 20 merged from PL2.
+            assert_eq!(pl.0.len(), 90);
+            assert_eq!(pl.1.len(), 90);
+            assert_eq!(pl.2.len(), 180);
+            for point in 1..=70 {
+                assert_eq!(pl.0[point - 1], 90 + point as u32);
+                assert_eq!(pl.1[point - 1], 1);
+                assert_eq!(pl.2[(point - 1) * 2], doc_embeddings3[(point - 1) * 2]);
+                assert_eq!(
+                    pl.2[(point - 1) * 2 + 1],
+                    doc_embeddings3[(point - 1) * 2 + 1]
+                );
+            }
+            for point in 71..=90 {
+                assert_eq!(pl.0[point - 1], point as u32);
+                assert_eq!(pl.1[point - 1], 2);
+                assert_eq!(pl.2[(point - 1) * 2], doc_embeddings2[(point - 71) * 2]);
+                assert_eq!(
+                    pl.2[(point - 1) * 2 + 1],
+                    doc_embeddings2[(point - 71) * 2 + 1]
+                );
+            }
+        }
+        // There should only be two heads left: head 2 is marked deleted.
+        {
+            let hnsw_read_guard = writer.hnsw_index.inner.read();
+            assert_eq!(hnsw_read_guard.len(), 2);
+            let (mut non_deleted_ids, deleted_ids) = hnsw_read_guard
+                .get_all_ids()
+                .expect("Error getting all ids");
+            non_deleted_ids.sort();
+            assert_eq!(non_deleted_ids.len(), 2);
+            assert_eq!(deleted_ids.len(), 1);
+            assert_eq!(non_deleted_ids[0], 1);
+            assert_eq!(non_deleted_ids[1], 3);
+            assert_eq!(deleted_ids[0], 2);
+            let emb = hnsw_read_guard
+                .get(non_deleted_ids[0])
+                .expect("Error getting hnsw index")
+                .unwrap();
+            assert_eq!(emb, &[0.0, 0.0]);
+            let emb = hnsw_read_guard
+                .get(non_deleted_ids[1])
+                .expect("Error getting hnsw index")
+                .unwrap();
+            assert_eq!(emb, &[10000.0, 10000.0]);
+        }
+    }
 }