diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index c5b315c6a30f..453b3596a52a 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -1,10 +1,11 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use arrow::error; use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter}; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::SpannPostingList; +use parking_lot::RwLock; use thiserror::Error; use uuid::Uuid; @@ -19,7 +20,8 @@ pub struct SpannIndexWriter { // The blockfile also contains next id for the head. posting_list_writer: BlockfileWriter, // Version number of each point. - versions_map: HashMap, + // TODO(Sanket): Finer grained locking for this map in future. + versions_map: Arc>>, } #[derive(Error, Debug)] @@ -53,7 +55,7 @@ impl SpannIndexWriter { hnsw_index, hnsw_provider, posting_list_writer, - versions_map, + versions_map: Arc::new(RwLock::new(versions_map)), } } @@ -194,4 +196,11 @@ impl SpannIndexWriter { versions_map, )) } + + pub fn add_versions_map(&self, id: u32) { + // 0 means deleted. Version counting starts from 1. + self.versions_map.write().insert(id, 1); + } + + pub async fn add_new_record_to_postings_list(&self, id: u32, embeddings: &[f32]) {} } diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 2daaf8bd16c7..d44a3261d280 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -4,11 +4,15 @@ use arrow::error; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; -use chroma_types::{Segment, SegmentScope, SegmentType}; +use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType}; use thiserror::Error; use uuid::Uuid; -use super::utils::{distance_function_from_segment, hnsw_params_from_segment}; +use super::{ + record_segment::ApplyMaterializedLogError, + utils::{distance_function_from_segment, hnsw_params_from_segment}, + MaterializedLogRecord, SegmentFlusher, SegmentWriter, +}; const HNSW_PATH: &str = "hnsw_path"; const VERSION_MAP_PATH: &str = "version_map_path"; @@ -143,4 +147,35 @@ impl SpannSegmentWriter { id: segment.id, }) } + + async fn add(&self, record: &MaterializedLogRecord<'_>) { + // Initialize the record with a version. + self.index.add_new_record_to_versions_map(record.offset_id); + self.index + .add_new_record_to_postings_list(record.offset_id, record.merged_embeddings()); + } +} + +impl<'a> SegmentWriter<'a> for SpannSegmentWriter { + async fn apply_materialized_log_chunk( + &self, + records: chroma_types::Chunk>, + ) -> Result<(), ApplyMaterializedLogError> { + for (record, idx) in records.iter() { + match record.final_operation { + MaterializedLogOperation::AddNew => { + self.add(record).await; + } + // TODO(Sanket): Implement other operations. + _ => { + todo!() + } + } + } + Ok(()) + } + + async fn commit(self) -> Result> { + todo!() + } }