Skip to content

Commit

Permalink
Add reassign test
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Dec 4, 2024
1 parent 8505627 commit a929726
Showing 1 changed file with 241 additions and 0 deletions.
241 changes: 241 additions & 0 deletions rust/index/src/spann/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1965,4 +1965,245 @@ mod tests {
assert_eq!(pl.2.len(), 158);
}
}

#[tokio::test]
async fn test_reassign() {
let tmp_dir = tempfile::tempdir().unwrap();
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = new_cache_for_test();
let sparse_index_cache = new_cache_for_test();
let arrow_blockfile_provider = ArrowBlockfileProvider::new(
storage.clone(),
TEST_MAX_BLOCK_SIZE_BYTES,
block_cache,
sparse_index_cache,
);
let blockfile_provider =
BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider);
let hnsw_cache = new_non_persistent_cache_for_test();
let (_, rx) = tokio::sync::mpsc::unbounded_channel();
let hnsw_provider = HnswIndexProvider::new(
storage.clone(),
PathBuf::from(tmp_dir.path().to_str().unwrap()),
hnsw_cache,
rx,
);
let m = 16;
let ef_construction = 200;
let ef_search = 200;
let collection_id = CollectionUuid::new();
let distance_function = chroma_distance::DistanceFunction::Euclidean;
let dimensionality = 2;
let writer = SpannIndexWriter::from_id(
&hnsw_provider,
None,
None,
None,
None,
Some(m),
Some(ef_construction),
Some(ef_search),
&collection_id,
distance_function,
dimensionality,
&blockfile_provider,
)
.await
.expect("Error creating spann index writer");
// Create three centers with ill placed points.
{
let hnsw_guard = writer.hnsw_index.inner.write();
hnsw_guard
.add(1, &[0.0, 0.0])
.expect("Error adding to hnsw index");
hnsw_guard
.add(2, &[1000.0, 1000.0])
.expect("Error adding to hnsw index");
hnsw_guard
.add(3, &[10000.0, 10000.0])
.expect("Error adding to hnsw index");
}
// Insert 50 points within a radius of 1 to center 1.
let mut split_doc_offset_ids1 = vec![0u32; 50];
let mut split_doc_versions1 = vec![0u32; 50];
let mut split_doc_embeddings1 = vec![0.0; 100];
let mut split_doc_offset_ids2 = vec![0u32; 50];
let mut split_doc_versions2 = vec![0u32; 50];
let mut split_doc_embeddings2 = vec![0.0; 100];
let mut split_doc_offset_ids3 = vec![0u32; 50];
let mut split_doc_versions3 = vec![0u32; 50];
let mut split_doc_embeddings3 = vec![0.0; 100];
{
let mut rng = rand::thread_rng();
let pl_guard = writer.posting_list_writer.lock().await;
for i in 1..=50 {
// Generate random radius between 0 and 1
let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution

// Generate random angle between 0 and 2π
let theta = rng.gen::<f32>() * 2.0 * PI;

// Convert to Cartesian coordinates
let x = r * theta.cos();
let y = r * theta.sin();

split_doc_offset_ids1[i - 1] = i as u32;
split_doc_versions1[i - 1] = 1;
split_doc_embeddings1[(i - 1) * 2] = x;
split_doc_embeddings1[(i - 1) * 2 + 1] = y;
}
let posting_list = SpannPostingList {
doc_offset_ids: &split_doc_offset_ids1,
doc_versions: &split_doc_versions1,
doc_embeddings: &split_doc_embeddings1,
};
pl_guard
.set("", 1, &posting_list)
.await
.expect("Error writing to posting list");
// Insert 50 points within a radius of 1 to center 3 to center 2 and vice versa.
// This ensures that we test reassignment and that it shuffles the two fully.
for i in 1..=50 {
// Generate random radius between 0 and 1
let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution

// Generate random angle between 0 and 2π
let theta = rng.gen::<f32>() * 2.0 * PI;

// Convert to Cartesian coordinates
let x = r * theta.cos() + 1000.0;
let y = r * theta.sin() + 1000.0;

split_doc_offset_ids3[i - 1] = 50 + i as u32;
split_doc_versions3[i - 1] = 1;
split_doc_embeddings3[(i - 1) * 2] = x;
split_doc_embeddings3[(i - 1) * 2 + 1] = y;
}
let posting_list = SpannPostingList {
doc_offset_ids: &split_doc_offset_ids3,
doc_versions: &split_doc_versions3,
doc_embeddings: &split_doc_embeddings3,
};
pl_guard
.set("", 3, &posting_list)
.await
.expect("Error writing to posting list");
// Do the same for 10000.
for i in 1..=50 {
// Generate random radius between 0 and 1
let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution

// Generate random angle between 0 and 2π
let theta = rng.gen::<f32>() * 2.0 * PI;

// Convert to Cartesian coordinates
let x = r * theta.cos() + 10000.0;
let y = r * theta.sin() + 10000.0;

split_doc_offset_ids2[i - 1] = 100 + i as u32;
split_doc_versions2[i - 1] = 1;
split_doc_embeddings2[(i - 1) * 2] = x;
split_doc_embeddings2[(i - 1) * 2 + 1] = y;
}
let posting_list = SpannPostingList {
doc_offset_ids: &split_doc_offset_ids2,
doc_versions: &split_doc_versions2,
doc_embeddings: &split_doc_embeddings2,
};
pl_guard
.set("", 2, &posting_list)
.await
.expect("Error writing to posting list");
}
// Insert these 150 points to version map.
{
let mut version_map_guard = writer.versions_map.write();
for i in 1..=150 {
version_map_guard.versions_map.insert(i as u32, 1);
}
}
// Trigger reassign and see the results.
// Carefully construct the old head embedding so that NPA
// is violated for the second center.
writer
.collect_and_reassign(
&[1, 2],
&[Some(&vec![0.0, 0.0]), Some(&vec![1000.0, 1000.0])],
&[5000.0, 5000.0],
&[split_doc_offset_ids1.clone(), split_doc_offset_ids2.clone()],
&[split_doc_versions1.clone(), split_doc_versions2.clone()],
&[split_doc_embeddings1.clone(), split_doc_embeddings2.clone()],
)
.await
.expect("Expected reassign to succeed");
// See the reassigned points.
{
let pl_guard = writer.posting_list_writer.lock().await;
// Center 1 should remain unchanged.
let pl = pl_guard
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
.await
.expect("Error getting posting list")
.unwrap();
assert_eq!(pl.0.len(), 50);
assert_eq!(pl.1.len(), 50);
assert_eq!(pl.2.len(), 100);
for i in 1..=50 {
assert_eq!(pl.0[i - 1], i as u32);
assert_eq!(pl.1[i - 1], 1);
assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings1[(i - 1) * 2]);
assert_eq!(
pl.2[(i - 1) * 2 + 1],
split_doc_embeddings1[(i - 1) * 2 + 1]
);
}
// Center 2 should get 50 points, all with version 2 migrating from center 3.
let pl = pl_guard
.get_owned::<u32, &SpannPostingList<'_>>("", 2)
.await
.expect("Error getting posting list")
.unwrap();
assert_eq!(pl.0.len(), 50);
assert_eq!(pl.1.len(), 50);
assert_eq!(pl.2.len(), 100);
for i in 1..=50 {
assert_eq!(pl.0[i - 1], 50 + i as u32);
assert_eq!(pl.1[i - 1], 2);
assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 1) * 2]);
assert_eq!(
pl.2[(i - 1) * 2 + 1],
split_doc_embeddings3[(i - 1) * 2 + 1]
);
}
// Center 3 should get 100 points. 50 points with version 1 which weere
// originally in center 3 and 50 points with version 2 which were originally
// in center 2.
let pl = pl_guard
.get_owned::<u32, &SpannPostingList<'_>>("", 3)
.await
.expect("Error getting posting list")
.unwrap();
assert_eq!(pl.0.len(), 100);
assert_eq!(pl.1.len(), 100);
assert_eq!(pl.2.len(), 200);
for i in 1..=100 {
assert_eq!(pl.0[i - 1], 50 + i as u32);
if i <= 50 {
assert_eq!(pl.1[i - 1], 1);
assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 1) * 2]);
assert_eq!(
pl.2[(i - 1) * 2 + 1],
split_doc_embeddings3[(i - 1) * 2 + 1]
);
} else {
assert_eq!(pl.1[i - 1], 2);
assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings2[(i - 51) * 2]);
assert_eq!(
pl.2[(i - 1) * 2 + 1],
split_doc_embeddings2[(i - 51) * 2 + 1]
);
}
}
}
}
}

0 comments on commit a929726

Please sign in to comment.