From 5a588a9046b8dc3dfbd74acd8fcabbb36ba49523 Mon Sep 17 00:00:00 2001 From: Macronova <60079945+Sicheng-Pan@users.noreply.github.com> Date: Wed, 11 Dec 2024 20:25:31 +0000 Subject: [PATCH] [CLN] Cleanup query node after pushdown (#3280) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Remove unused impls in query node after query pushdown landed - New functionality - N/A ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* N/A --- rust/worker/src/execution/operator.rs | 2 + .../execution/operators/brute_force_knn.rs | 533 ----------- .../operators/get_vectors_operator.rs | 179 ---- .../src/execution/operators/hnsw_knn.rs | 248 ----- .../execution/operators/merge_knn_results.rs | 289 ------ rust/worker/src/execution/operators/mod.rs | 6 - .../execution/operators/normalize_vectors.rs | 81 -- .../operators/record_segment_prefetch.rs | 128 --- .../src/execution/orchestration/common.rs | 157 +--- .../execution/orchestration/get_vectors.rs | 348 ------- .../src/execution/orchestration/hnsw.rs | 864 ------------------ .../worker/src/execution/orchestration/mod.rs | 3 - rust/worker/src/segment/record_segment.rs | 20 - 13 files changed, 4 insertions(+), 2854 deletions(-) delete mode 100644 rust/worker/src/execution/operators/brute_force_knn.rs delete mode 100644 rust/worker/src/execution/operators/get_vectors_operator.rs delete mode 100644 rust/worker/src/execution/operators/hnsw_knn.rs delete mode 100644 rust/worker/src/execution/operators/merge_knn_results.rs delete mode 100644 rust/worker/src/execution/operators/normalize_vectors.rs delete mode 100644 rust/worker/src/execution/operators/record_segment_prefetch.rs delete mode 100644 rust/worker/src/execution/orchestration/get_vectors.rs delete mode 100644 rust/worker/src/execution/orchestration/hnsw.rs diff --git a/rust/worker/src/execution/operator.rs b/rust/worker/src/execution/operator.rs index aed4570721a..d82aaaec501 100644 --- a/rust/worker/src/execution/operator.rs +++ b/rust/worker/src/execution/operator.rs @@ -73,6 +73,7 @@ impl TaskResult { self.result } + #[allow(dead_code)] pub(super) fn id(&self) -> Uuid { self.task_id } @@ -101,6 +102,7 @@ pub(crate) type TaskMessage = Box; pub(crate) trait TaskWrapper: Send + Debug { fn get_name(&self) -> &'static str; async fn run(&self); + #[allow(dead_code)] fn id(&self) -> Uuid; fn get_type(&self) -> OperatorType; } diff --git a/rust/worker/src/execution/operators/brute_force_knn.rs b/rust/worker/src/execution/operators/brute_force_knn.rs deleted file mode 100644 index b83aff738ed..00000000000 --- a/rust/worker/src/execution/operators/brute_force_knn.rs +++ /dev/null @@ -1,533 +0,0 @@ -use crate::execution::operator::Operator; -use crate::segment::record_segment::RecordSegmentReader; -use crate::segment::{materialize_logs, LogMaterializerError}; -use async_trait::async_trait; -use chroma_blockstore::provider::BlockfileProvider; -use chroma_distance::normalize; -use chroma_distance::DistanceFunction; -use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::Chunk; -use chroma_types::{LogRecord, MaterializedLogOperation, Segment}; -use std::cmp::Ordering; -use std::collections::BinaryHeap; -use std::sync::Arc; -use thiserror::Error; -use tracing::Instrument; -use 
tracing::Span; - -/// The brute force k-nearest neighbors operator is responsible for computing the k-nearest neighbors -/// of a given query vector against a set of vectors using brute force calculation. -/// # Note -/// - Callers should ensure that the input vectors are normalized if using the cosine similarity metric. -#[derive(Debug)] -pub struct BruteForceKnnOperator {} - -/// The input to the brute force k-nearest neighbors operator. -/// # Parameters -/// * `data` - The vectors to query against. -/// * `query` - The query vector. -/// * `k` - The number of nearest neighbors to find. -/// * `distance_metric` - The distance metric to use. -#[derive(Debug)] -pub struct BruteForceKnnOperatorInput { - pub log: Chunk, - pub query: Vec, - pub k: usize, - pub distance_metric: DistanceFunction, - pub allowed_ids: Arc<[String]>, - // Deps to create the log materializer - pub record_segment_definition: Segment, - pub blockfile_provider: BlockfileProvider, -} - -/// The output of the brute force k-nearest neighbors operator. -/// # Parameters -/// * `user_ids` - The user ids of the nearest neighbors. -/// * `embeddings` - The embeddings of the nearest neighbors. -/// * `distances` - The distances of the nearest neighbors. -/// One row for each query vector. -#[derive(Debug)] -pub struct BruteForceKnnOperatorOutput { - pub user_ids: Vec, - pub embeddings: Vec>, - pub distances: Vec, -} - -#[derive(Debug)] -struct Entry<'record> { - user_id: &'record str, - embedding: &'record [f32], - distance: f32, -} - -impl Ord for Entry<'_> { - fn cmp(&self, other: &Self) -> Ordering { - if self.distance == other.distance { - Ordering::Equal - } else if self.distance > other.distance { - // This is a min heap, so we need to reverse the ordering. - Ordering::Less - } else { - // This is a min heap, so we need to reverse the ordering. 
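For reference, the comparator being removed here inverts the natural distance ordering so that Rust's `BinaryHeap`, which is a max-heap, pops the smallest distance first. A minimal standalone sketch of the same top-k pattern, with a hypothetical `Neighbor` type standing in for `Entry`:

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

// Hypothetical stand-in for `Entry`: an id plus its distance to the query.
#[derive(Debug)]
struct Neighbor {
    id: u32,
    distance: f32,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        self.distance == other.distance
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed comparison: a larger distance compares as Less, so the
        // max-heap surfaces the nearest neighbor at the top.
        other
            .distance
            .partial_cmp(&self.distance)
            .unwrap_or(Ordering::Equal)
    }
}

fn main() {
    let mut heap = BinaryHeap::new();
    for (id, distance) in [(0u32, 0.5f32), (1, 0.1), (2, 0.9)] {
        heap.push(Neighbor { id, distance });
    }
    // Popping k = 2 entries yields ids 1 then 0, in ascending distance order.
    for _ in 0..2 {
        if let Some(n) = heap.pop() {
            println!("id={} distance={}", n.id, n.distance);
        }
    }
}
```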
- Ordering::Greater - } - } -} - -impl PartialOrd for Entry<'_> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl PartialEq for Entry<'_> { - fn eq(&self, other: &Self) -> bool { - self.distance == other.distance - } -} - -impl Eq for Entry<'_> {} - -#[derive(Debug, Error)] -pub enum BruteForceKnnOperatorError { - #[error(transparent)] - RecordSegmentReaderCreationError( - #[from] crate::segment::record_segment::RecordSegmentReaderCreationError, - ), - #[error("Error while materializing log records: {0}")] - LogMaterializationError(#[from] LogMaterializerError), -} - -impl ChromaError for BruteForceKnnOperatorError { - fn code(&self) -> ErrorCodes { - match self { - BruteForceKnnOperatorError::RecordSegmentReaderCreationError(e) => e.code(), - BruteForceKnnOperatorError::LogMaterializationError(e) => e.code(), - } - } -} - -#[async_trait] -impl Operator for BruteForceKnnOperator { - type Error = BruteForceKnnOperatorError; - - fn get_name(&self) -> &'static str { - "BruteForceKnnOperator" - } - - async fn run( - &self, - input: &BruteForceKnnOperatorInput, - ) -> Result { - // Materialize the log records - let record_segment_reader = match RecordSegmentReader::from_segment( - &input.record_segment_definition, - &input.blockfile_provider, - ) - .await { - Ok(reader) => Some(reader), - Err(e) => { - match *e { - crate::segment::record_segment::RecordSegmentReaderCreationError::UninitializedSegment => None, - _ => return Err(BruteForceKnnOperatorError::RecordSegmentReaderCreationError(*e)) - } - } - }; - let logs = match materialize_logs(&record_segment_reader, &input.log, None) - .instrument(tracing::trace_span!(parent: Span::current(), "Materialize logs")) - .await - { - Ok(logs) => logs, - Err(e) => { - return Err(BruteForceKnnOperatorError::LogMaterializationError(e)); - } - }; - - let should_normalize = matches!(input.distance_metric, DistanceFunction::Cosine); - let normalized_query = match should_normalize { - true => Some(normalize(&input.query)), - false => None, - }; - - let mut heap = BinaryHeap::with_capacity(input.k); - let data_chunk = logs; - for data in data_chunk.iter() { - let log_record = data.0; - - if log_record.final_operation == MaterializedLogOperation::DeleteExisting { - // Explicitly skip deleted records. - continue; - } - - // Skip records that are disallowed. If allowed list is empty then - // don't exclude anything. - // Empty allowed list is passed when where filtering is absent. - // TODO: This should not need to use merged_user_id, which clones the id. - if !input.allowed_ids.is_empty() - && !input.allowed_ids.contains(&log_record.merged_user_id()) - { - continue; - } - let embedding = &log_record.merged_embeddings(); - if should_normalize { - let normalized_query = normalized_query.as_ref().expect("Invariant violation. 
Should have set normalized query if should_normalize is true."); - let normalized_embedding = normalize(&embedding[..]); - let distance = input - .distance_metric - .distance(&normalized_embedding[..], &normalized_query[..]); - heap.push(Entry { - user_id: log_record.merged_user_id_ref(), - embedding, - distance, - }); - } else { - let distance = input.distance_metric.distance(&embedding[..], &input.query); - heap.push(Entry { - user_id: log_record.merged_user_id_ref(), - embedding, - distance, - }); - } - } - - let mut sorted_embeddings = Vec::with_capacity(input.k); - let mut sorted_distances = Vec::with_capacity(input.k); - let mut sorted_user_ids = Vec::with_capacity(input.k); - let mut i = 0; - while i < input.k { - let entry = match heap.pop() { - Some(entry) => entry, - None => { - break; - } - }; - sorted_user_ids.push(entry.user_id.to_string()); - sorted_embeddings.push(entry.embedding.to_vec()); - sorted_distances.push(entry.distance); - i += 1; - } - - tracing::info!("Brute force Knn result. distances: {:?}", sorted_distances); - Ok(BruteForceKnnOperatorOutput { - user_ids: sorted_user_ids, - embeddings: sorted_embeddings, - distances: sorted_distances, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use chroma_types::{CollectionUuid, LogRecord, Operation, OperationRecord, SegmentUuid}; - use std::collections::HashMap; - use uuid::uuid; - - // Helper for tests - fn get_blockfile_provider_and_record_segment_definition() -> (BlockfileProvider, Segment) { - // Create a blockfile provider for the log materializer - let blockfile_provider = BlockfileProvider::new_memory(); - - // Create an empty record segment definition - let record_segment_definition = Segment { - id: SegmentUuid(uuid!("00000000-0000-0000-0000-000000000000")), - r#type: chroma_types::SegmentType::BlockfileRecord, - scope: chroma_types::SegmentScope::RECORD, - collection: CollectionUuid(uuid!("00000000-0000-0000-0000-000000000000")), - metadata: None, - file_path: HashMap::new(), - }; - (blockfile_provider, record_segment_definition) - } - - #[tokio::test] - async fn test_brute_force_knn_l2sqr() { - let operator = BruteForceKnnOperator {}; - let (blockfile_provider, record_segment_definition) = - get_blockfile_provider_and_record_segment_definition(); - let data = vec![ - LogRecord { - log_offset: 1, - record: OperationRecord { - id: "embedding_id_1".to_string(), - embedding: Some(vec![0.0, 0.0, 0.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - LogRecord { - log_offset: 2, - record: OperationRecord { - id: "embedding_id_2".to_string(), - embedding: Some(vec![0.0, 1.0, 1.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - LogRecord { - log_offset: 3, - record: OperationRecord { - id: "embedding_id_3".to_string(), - embedding: Some(vec![7.0, 8.0, 9.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - ]; - let data_chunk = Chunk::new(data.into()); - - let input = BruteForceKnnOperatorInput { - log: data_chunk, - query: vec![0.0, 0.0, 0.0], - k: 2, - distance_metric: DistanceFunction::Euclidean, - allowed_ids: Arc::new([]), - blockfile_provider, - record_segment_definition, - }; - - let output = operator.run(&input).await.unwrap(); - assert_eq!(output.user_ids, vec!["embedding_id_1", "embedding_id_2"]); - let distance_1 = 0.0_f32.powi(2) + 1.0_f32.powi(2) + 1.0_f32.powi(2); - assert_eq!(output.distances, vec![0.0, distance_1]); - assert_eq!( - output.embeddings, - 
vec![vec![0.0, 0.0, 0.0], vec![0.0, 1.0, 1.0]] - ); - } - - #[tokio::test] - async fn test_brute_force_knn_cosine() { - let operator = BruteForceKnnOperator {}; - let (blockfile_provider, record_segment_definition) = - get_blockfile_provider_and_record_segment_definition(); - - let norm_1 = (1.0_f32.powi(2) + 2.0_f32.powi(2) + 3.0_f32.powi(2)).sqrt(); - let data_1 = vec![1.0 / norm_1, 2.0 / norm_1, 3.0 / norm_1]; - - let norm_2 = (0.0_f32.powi(2) + -(1.0_f32.powi(2)) + 6.0_f32.powi(2)).sqrt(); - let data_2 = vec![0.0 / norm_2, -1.0 / norm_2, 6.0 / norm_2]; - let data = vec![ - LogRecord { - log_offset: 1, - record: OperationRecord { - id: "embedding_id_1".to_string(), - embedding: Some(vec![0.0, 1.0, 0.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - LogRecord { - log_offset: 2, - record: OperationRecord { - id: "embedding_id_2".to_string(), - embedding: Some(data_1.clone()), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - LogRecord { - log_offset: 3, - record: OperationRecord { - id: "embedding_id_3".to_string(), - embedding: Some(data_2.clone()), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - ]; - let data_chunk = Chunk::new(data.into()); - - let input = BruteForceKnnOperatorInput { - log: data_chunk, - query: vec![0.0, 1.0, 0.0], - k: 2, - distance_metric: DistanceFunction::InnerProduct, - allowed_ids: Arc::new([]), - blockfile_provider, - record_segment_definition, - }; - let output = operator.run(&input).await.unwrap(); - - assert_eq!(output.user_ids, vec!["embedding_id_1", "embedding_id_2"]); - let expected_distance_1 = 1.0 - ((data_1[0] * 0.0) + (data_1[1] * 1.0) + (data_1[2] * 0.0)); - assert_eq!(output.distances, vec![0.0, expected_distance_1]); - assert_eq!( - output.embeddings, - vec![ - vec![0.0, 1.0, 0.0], - vec![1.0 / norm_1, 2.0 / norm_1, 3.0 / norm_1] - ] - ); - } - - #[tokio::test] - async fn test_data_less_than_k() { - let (blockfile_provider, record_segment_definition) = - get_blockfile_provider_and_record_segment_definition(); - - // If we have less data than k, we should return all the data, sorted by distance. 
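For reference, the cosine test removed above pre-normalizes its vectors and relies on the identity that, for unit vectors, cosine distance equals 1 - dot(a, b). A tiny standalone check of that identity, using hypothetical helpers rather than the real `chroma_distance` API:

```rust
// Hypothetical helpers (not the chroma_distance API) checking the identity
// the cosine test relies on: for unit vectors, cosine distance = 1 - dot(a, b).
fn normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    v.iter().map(|x| x / norm).collect()
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn main() {
    let a = normalize(&[1.0, 2.0, 3.0]);
    let b = normalize(&[0.0, 1.0, 0.0]);
    let cosine_distance = 1.0 - dot(&a, &b);
    // dot(a, b) = 2 / sqrt(14) ~= 0.5345, so the distance is ~= 0.4655.
    assert!((cosine_distance - 0.4655).abs() < 1e-3);
}
```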
- let operator = BruteForceKnnOperator {}; - let data = vec![LogRecord { - log_offset: 1, - record: OperationRecord { - id: "embedding_id_1".to_string(), - embedding: Some(vec![0.0, 0.0, 0.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }]; - - let data_chunk = Chunk::new(data.into()); - - let input = BruteForceKnnOperatorInput { - log: data_chunk, - query: vec![0.0, 0.0, 0.0], - k: 2, - distance_metric: DistanceFunction::Euclidean, - allowed_ids: Arc::new([]), - blockfile_provider, - record_segment_definition, - }; - let output = operator.run(&input).await.unwrap(); - - assert_eq!(output.user_ids, vec!["embedding_id_1"]); - assert_eq!(output.distances, vec![0.0]); - assert_eq!(output.embeddings, vec![vec![0.0, 0.0, 0.0]]); - } - - #[tokio::test] - async fn test_malformed_record_errors() { - let operator = BruteForceKnnOperator {}; - let (blockfile_provider, record_segment_definition) = - get_blockfile_provider_and_record_segment_definition(); - let data = vec![ - LogRecord { - log_offset: 1, - record: OperationRecord { - id: "embedding_id_1".to_string(), - embedding: Some(vec![7.0, 8.0, 9.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - LogRecord { - log_offset: 2, - record: OperationRecord { - id: "embedding_id_2".to_string(), - embedding: None, - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - LogRecord { - log_offset: 3, - record: OperationRecord { - id: "embedding_id_3".to_string(), - embedding: Some(vec![7.0, 8.0, 9.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - ]; - let data_chunk = Chunk::new(data.into()); - - let input = BruteForceKnnOperatorInput { - log: data_chunk, - query: vec![0.0, 0.0, 0.0], - k: 2, - distance_metric: DistanceFunction::Euclidean, - allowed_ids: Arc::new([]), - blockfile_provider, - record_segment_definition, - }; - let res = operator.run(&input).await; - match res { - Ok(_) => panic!("Expected error"), - Err(e) => match e { - BruteForceKnnOperatorError::LogMaterializationError(_) => { - // We expect an error here because the second record is malformed. 
- } - _ => panic!("Unexpected error"), - }, - } - } - - #[tokio::test] - async fn test_skip_deleted_record() { - let operator = BruteForceKnnOperator {}; - let (blockfile_provider, record_segment_definition) = - get_blockfile_provider_and_record_segment_definition(); - let data = vec![ - LogRecord { - log_offset: 1, - record: OperationRecord { - id: "embedding_id_1".to_string(), - embedding: Some(vec![0.0, 0.0, 0.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - LogRecord { - log_offset: 2, - record: OperationRecord { - id: "embedding_id_1".to_string(), - embedding: None, - encoding: None, - metadata: None, - document: None, - operation: Operation::Delete, - }, - }, - LogRecord { - log_offset: 3, - record: OperationRecord { - id: "embedding_id_3".to_string(), - embedding: Some(vec![0.0, 0.0, 0.0]), - encoding: None, - metadata: None, - document: None, - operation: Operation::Add, - }, - }, - ]; - let data_chunk = Chunk::new(data.into()); - - let input = BruteForceKnnOperatorInput { - log: data_chunk, - query: vec![0.0, 0.0, 0.0], - k: 2, - distance_metric: DistanceFunction::Euclidean, - allowed_ids: Arc::new([]), - blockfile_provider, - record_segment_definition, - }; - let output = operator.run(&input).await.unwrap(); - - assert_eq!(output.user_ids, vec!["embedding_id_3"]); - assert_eq!(output.distances, vec![0.0]); - assert_eq!(output.embeddings, vec![vec![0.0, 0.0, 0.0]]); - } -} diff --git a/rust/worker/src/execution/operators/get_vectors_operator.rs b/rust/worker/src/execution/operators/get_vectors_operator.rs deleted file mode 100644 index fb06e267c9d..00000000000 --- a/rust/worker/src/execution/operators/get_vectors_operator.rs +++ /dev/null @@ -1,179 +0,0 @@ -use crate::{ - execution::operator::Operator, - segment::{ - materialize_logs, - record_segment::{self, RecordSegmentReader}, - LogMaterializerError, - }, -}; -use async_trait::async_trait; -use chroma_blockstore::provider::BlockfileProvider; -use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{Chunk, LogRecord, MaterializedLogOperation, Segment}; -use std::collections::{HashMap, HashSet}; -use thiserror::Error; - -#[derive(Debug)] -pub struct GetVectorsOperator {} - -impl GetVectorsOperator { - pub fn new() -> Box { - Box::new(GetVectorsOperator {}) - } -} - -/// The input to the get vectors operator. -/// # Parameters -/// * `record_segment_definition` - The segment definition for the record segment. -/// * `blockfile_provider` - The blockfile provider. -/// * `log_records` - The log records. -/// * `search_user_ids` - The user ids to search for. -#[derive(Debug)] -pub struct GetVectorsOperatorInput { - record_segment_definition: Segment, - blockfile_provider: BlockfileProvider, - log_records: Chunk, - search_user_ids: Vec, -} - -impl GetVectorsOperatorInput { - pub fn new( - record_segment_definition: Segment, - blockfile_provider: BlockfileProvider, - log_records: Chunk, - search_user_ids: Vec, - ) -> Self { - GetVectorsOperatorInput { - record_segment_definition, - blockfile_provider, - log_records, - search_user_ids, - } - } -} - -/// The output of the get vectors operator. -/// # Parameters -/// * `ids` - The ids of the vectors. -/// * `vectors` - The vectors. -/// # Notes -/// The vectors are in the same order as the ids. 
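For reference, the deleted operator below preserves the caller's id order by collecting hits into a map and then draining that map in the order of the requested ids, skipping ids that were deleted or never found. A minimal sketch of that ordering pattern, with hypothetical names:

```rust
use std::collections::HashMap;

// Hypothetical reduction of the ordering logic in the deleted operator:
// matches are gathered into a map, then emitted in the caller's id order.
fn order_results(
    search_ids: &[String],
    mut found: HashMap<String, Vec<f32>>,
) -> (Vec<String>, Vec<Vec<f32>>) {
    let mut ids = Vec::new();
    let mut vectors = Vec::new();
    for id in search_ids {
        // Ids that were never found (e.g. deleted records) are skipped.
        if let Some(vector) = found.remove(id) {
            ids.push(id.clone());
            vectors.push(vector);
        }
    }
    (ids, vectors)
}

fn main() {
    let mut found = HashMap::new();
    found.insert("b".to_string(), vec![1.0, 2.0]);
    found.insert("a".to_string(), vec![3.0, 4.0]);
    let search: Vec<String> = vec!["a".into(), "b".into(), "c".into()];
    let (ids, vectors) = order_results(&search, found);
    assert_eq!(ids, vec!["a", "b"]);
    assert_eq!(vectors, vec![vec![3.0, 4.0], vec![1.0, 2.0]]);
}
```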
-#[derive(Debug)] -pub struct GetVectorsOperatorOutput { - pub(crate) ids: Vec, - pub(crate) vectors: Vec>, -} - -#[derive(Debug, Error)] -pub enum GetVectorsOperatorError { - #[error("Error creating record segment reader {0}")] - RecordSegmentReaderCreation( - #[from] crate::segment::record_segment::RecordSegmentReaderCreationError, - ), - #[error(transparent)] - RecordSegmentReader(#[from] Box), - #[error("Error materializing logs {0}")] - LogMaterialization(#[from] LogMaterializerError), -} - -impl ChromaError for GetVectorsOperatorError { - fn code(&self) -> ErrorCodes { - ErrorCodes::Internal - } -} - -#[async_trait] -impl Operator for GetVectorsOperator { - type Error = GetVectorsOperatorError; - - fn get_name(&self) -> &'static str { - "GetVectorsOperator" - } - - async fn run( - &self, - input: &GetVectorsOperatorInput, - ) -> Result { - let mut output_vectors = HashMap::new(); - - // Materialize logs. - let record_segment_reader = match RecordSegmentReader::from_segment( - &input.record_segment_definition, - &input.blockfile_provider, - ) - .await - { - Ok(reader) => Some(reader), - Err(e) => match *e { - record_segment::RecordSegmentReaderCreationError::UninitializedSegment => None, - record_segment::RecordSegmentReaderCreationError::BlockfileOpenError(_) => { - return Err(GetVectorsOperatorError::RecordSegmentReaderCreation(*e)) - } - record_segment::RecordSegmentReaderCreationError::InvalidNumberOfFiles => { - return Err(GetVectorsOperatorError::RecordSegmentReaderCreation(*e)) - } - record_segment::RecordSegmentReaderCreationError::DataRecordNotFound(_) => { - return Err(GetVectorsOperatorError::RecordSegmentReaderCreation(*e)) - } - record_segment::RecordSegmentReaderCreationError::UserRecordNotFound(_) => { - return Err(GetVectorsOperatorError::RecordSegmentReaderCreation(*e)) - } - }, - }; - // Step 1: Materialize the logs. - let mat_records = - match materialize_logs(&record_segment_reader, &input.log_records, None).await { - Ok(records) => records, - Err(e) => { - return Err(GetVectorsOperatorError::LogMaterialization(e)); - } - }; - - // Search the log records for the user ids - let mut remaining_search_user_ids: HashSet = - HashSet::from_iter(input.search_user_ids.iter().cloned()); - for (log_record, _) in mat_records.iter() { - // Log is the source of truth for these so don't consider these for - // reading from the segment. 
- let mut removed = false; - if remaining_search_user_ids.contains(log_record.merged_user_id_ref()) { - removed = true; - remaining_search_user_ids.remove(log_record.merged_user_id_ref()); - } - if removed && log_record.final_operation != MaterializedLogOperation::DeleteExisting { - output_vectors.insert( - log_record.merged_user_id(), - log_record.merged_embeddings().to_vec(), - ); - } - } - - // Search the record segment for the remaining user ids - if !remaining_search_user_ids.is_empty() { - if let Some(reader) = record_segment_reader { - for user_id in remaining_search_user_ids.iter() { - let read_data = reader.get_data_and_offset_id_for_user_id(user_id).await; - match read_data { - Ok(Some((record, _))) => { - output_vectors.insert(record.id.to_string(), record.embedding.to_vec()); - } - Ok(None) => {} - Err(_) => { - // If the user id is not found in the record segment, we do not add it to the output - } - } - } - } - } - - let mut ids = Vec::new(); - let mut vectors = Vec::new(); - for id in &input.search_user_ids { - if output_vectors.contains_key(id) { - ids.push(id.clone()); - vectors.push(output_vectors.remove(id).unwrap()); - } - } - return Ok(GetVectorsOperatorOutput { ids, vectors }); - } -} diff --git a/rust/worker/src/execution/operators/hnsw_knn.rs b/rust/worker/src/execution/operators/hnsw_knn.rs deleted file mode 100644 index c5e57467d5d..00000000000 --- a/rust/worker/src/execution/operators/hnsw_knn.rs +++ /dev/null @@ -1,248 +0,0 @@ -use crate::segment::record_segment::RecordSegmentReaderCreationError; -use crate::segment::{materialize_logs, LogMaterializerError, MaterializedLogRecord}; -use crate::{ - execution::operator::Operator, - segment::{ - distributed_hnsw_segment::DistributedHNSWSegmentReader, record_segment::RecordSegmentReader, - }, -}; -use async_trait::async_trait; -use chroma_blockstore::provider::BlockfileProvider; -use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::Segment; -use chroma_types::{Chunk, LogRecord, MaterializedLogOperation}; -use std::collections::HashSet; -use std::sync::Arc; -use thiserror::Error; -use tracing::{Instrument, Span}; - -#[derive(Debug)] -pub struct HnswKnnOperator {} - -#[derive(Debug)] -pub struct HnswKnnOperatorInput { - pub segment: Box, - pub query: Vec, - pub k: usize, - pub record_segment: Segment, - pub blockfile_provider: BlockfileProvider, - pub allowed_ids: Arc<[String]>, - pub logs: Chunk, -} - -#[derive(Debug)] -pub struct HnswKnnOperatorOutput { - pub offset_ids: Vec, - pub distances: Vec, -} - -#[derive(Error, Debug)] -pub enum HnswKnnOperatorError { - #[error("Error creating Record Segment")] - RecordSegmentError, - #[error("Error reading Record Segment")] - RecordSegmentReadError, - #[error("Invalid allowed and disallowed ids")] - InvalidAllowedAndDisallowedIds, - #[error("Error materializing logs {0}")] - LogMaterializationError(#[from] LogMaterializerError), - #[error("Error querying HNSW {0}")] - QueryError(#[from] Box), -} - -impl ChromaError for HnswKnnOperatorError { - fn code(&self) -> ErrorCodes { - match self { - HnswKnnOperatorError::RecordSegmentError => ErrorCodes::Internal, - HnswKnnOperatorError::RecordSegmentReadError => ErrorCodes::Internal, - HnswKnnOperatorError::InvalidAllowedAndDisallowedIds => ErrorCodes::InvalidArgument, - HnswKnnOperatorError::LogMaterializationError(e) => e.code(), - HnswKnnOperatorError::QueryError(e) => e.code(), - } - } -} - -impl HnswKnnOperator { - async fn get_disallowed_ids<'referred_data>( - &self, - logs: Chunk>, - record_segment_reader: 
&RecordSegmentReader<'_>, - ) -> Result, Box> { - let mut disallowed_ids = Vec::new(); - for item in logs.iter() { - let log = item.0; - // This means that even if an embedding is not updated on the log, - // we brute force it. Can use the HNSW index also. - if log.final_operation == MaterializedLogOperation::DeleteExisting - || log.final_operation == MaterializedLogOperation::UpdateExisting - || log.final_operation == MaterializedLogOperation::OverwriteExisting - { - let offset_id = record_segment_reader - .get_offset_id_for_user_id(log.merged_user_id_ref()) - .await; - match offset_id { - Ok(Some(offset_id)) => disallowed_ids.push(offset_id), - Ok(None) => { - return Err(Box::new(HnswKnnOperatorError::RecordSegmentReadError)); - } - Err(e) => { - return Err(e); - } - } - } - } - Ok(disallowed_ids) - } - - // Validate that the allowed ids are not in the disallowed ids - fn validate_allowed_and_disallowed_ids( - &self, - allowed_ids: &[u32], - disallowed_ids: &[u32], - ) -> Result<(), Box> { - for allowed_id in allowed_ids { - if disallowed_ids.contains(allowed_id) { - return Err(Box::new( - HnswKnnOperatorError::InvalidAllowedAndDisallowedIds, - )); - } - } - Ok(()) - } -} - -#[async_trait] -impl Operator for HnswKnnOperator { - type Error = Box; - - fn get_name(&self) -> &'static str { - "HnswKnnOperator" - } - - async fn run( - &self, - input: &HnswKnnOperatorInput, - ) -> Result { - let record_segment_reader = match RecordSegmentReader::from_segment( - &input.record_segment, - &input.blockfile_provider, - ) - .await - { - Ok(reader) => reader, - Err(e) => match *e { - RecordSegmentReaderCreationError::UninitializedSegment => { - tracing::error!( - "[HnswKnnOperation]: Error creating record segment reader {:?}", - *e - ); - return Ok(HnswKnnOperatorOutput { - offset_ids: vec![], - distances: vec![], - }); - } - _ => { - tracing::error!("[HnswKnnOperation]: Error creating record segment {:?}", e); - return Err(Box::new(HnswKnnOperatorError::RecordSegmentError)); - } - }, - }; - let some_reader = Some(record_segment_reader.clone()); - let logs = match materialize_logs(&some_reader, &input.logs, None) - .instrument(tracing::trace_span!(parent: Span::current(), "Materialize logs")) - .await - { - Ok(logs) => logs, - Err(e) => { - tracing::error!("[HnswKnnOperation]: Error materializing logs {:?}", e); - return Err(Box::new(HnswKnnOperatorError::LogMaterializationError(e))); - } - }; - let mut remaining_allowed_ids: HashSet<&String> = - HashSet::from_iter(input.allowed_ids.iter()); - for (log, _) in logs.iter() { - #[allow(clippy::unnecessary_to_owned)] - remaining_allowed_ids.remove(&log.merged_user_id_ref().to_string()); - } - // If a filter list is supplied but it does not have anything for the segment, as it implies the data is all in the log - // then return an empty response. 
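For reference, `get_disallowed_ids` above encodes the rule that the log overrides the index: any record the log deleted, updated, or overwrote must be excluded from the HNSW result, since its newest version is served from the log by the brute-force path. A reduced sketch of that rule, with hypothetical types:

```rust
// Hypothetical sketch of the log-overrides-index rule applied above: any id
// the log deleted, updated, or overwrote is disallowed for the HNSW search,
// because its newest version is served by the brute-force path instead.
enum LogOp {
    Add,
    Update,
    Overwrite,
    Delete,
}

fn disallowed_ids(log: &[(u32, LogOp)]) -> Vec<u32> {
    log.iter()
        .filter(|(_, op)| matches!(op, LogOp::Update | LogOp::Overwrite | LogOp::Delete))
        .map(|(offset_id, _)| *offset_id)
        .collect()
}

fn main() {
    let log = [
        (1, LogOp::Add),
        (2, LogOp::Update),
        (3, LogOp::Delete),
        (4, LogOp::Overwrite),
    ];
    assert_eq!(disallowed_ids(&log), vec![2, 3, 4]);
}
```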
- if !input.allowed_ids.is_empty() && remaining_allowed_ids.is_empty() { - return Ok(HnswKnnOperatorOutput { - offset_ids: vec![], - distances: vec![], - }); - } - let mut allowed_offset_ids = Vec::new(); - for user_id in remaining_allowed_ids { - let offset_id = record_segment_reader - .get_offset_id_for_user_id(user_id) - .await; - match offset_id { - Ok(Some(offset_id)) => allowed_offset_ids.push(offset_id), - Ok(None) => { - return Err(Box::new(HnswKnnOperatorError::RecordSegmentReadError)); - } - Err(e) => { - tracing::error!( - "[HnswKnnOperation]: Record segment read error for allowed ids {:?}", - e - ); - return Err(Box::new(HnswKnnOperatorError::RecordSegmentReadError)); - } - } - } - tracing::info!( - "[HnswKnnOperation]: Allowed {} offset ids", - allowed_offset_ids.len() - ); - let disallowed_offset_ids = - match self.get_disallowed_ids(logs, &record_segment_reader).await { - Ok(disallowed_offset_ids) => disallowed_offset_ids, - Err(e) => { - tracing::error!("[HnswKnnOperation]: Error fetching disallowed ids {:?}", e); - return Err(Box::new(HnswKnnOperatorError::RecordSegmentReadError)); - } - }; - tracing::info!( - "[HnswKnnOperation]: Disallowed {} offset ids", - disallowed_offset_ids.len() - ); - - match self.validate_allowed_and_disallowed_ids(&allowed_offset_ids, &disallowed_offset_ids) - { - Ok(_) => {} - Err(e) => { - tracing::error!( - "[HnswKnnOperation]: Error validating allowed and disallowed ids {:?}", - e - ); - return Err(e); - } - }; - - // Convert to usize - let allowed_offset_ids: Vec = - allowed_offset_ids.iter().map(|&x| x as usize).collect(); - let disallowed_offset_ids: Vec = - disallowed_offset_ids.iter().map(|&x| x as usize).collect(); - - let query_results = input.segment.query( - &input.query, - input.k, - &allowed_offset_ids, - &disallowed_offset_ids, - ); - let (offset_ids, distances) = match query_results { - Ok(results) => results, - Err(e) => { - tracing::error!("[HnswKnnOperation]: Error querying HNSW {:?}", e); - return Err(Box::new(HnswKnnOperatorError::QueryError(e))); - } - }; - - Ok(HnswKnnOperatorOutput { - offset_ids, - distances, - }) - } -} diff --git a/rust/worker/src/execution/operators/merge_knn_results.rs b/rust/worker/src/execution/operators/merge_knn_results.rs deleted file mode 100644 index 25f0e1dcb41..00000000000 --- a/rust/worker/src/execution/operators/merge_knn_results.rs +++ /dev/null @@ -1,289 +0,0 @@ -use crate::{ - execution::operator::Operator, - segment::record_segment::{RecordSegmentReader, RecordSegmentReaderCreationError}, -}; -use async_trait::async_trait; -use chroma_blockstore::provider::BlockfileProvider; -use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::Segment; -use thiserror::Error; - -#[derive(Debug)] -pub struct MergeKnnResultsOperator {} - -#[derive(Debug)] -pub struct MergeKnnBruteForceResultInput { - pub user_ids: Vec, - pub distances: Vec, - pub vectors: Vec>, -} - -#[derive(Debug)] -pub struct MergeKnnResultsOperatorInput { - hnsw_result_offset_ids: Vec, - hnsw_result_distances: Vec, - brute_force_result: Option, - include_vectors: bool, - k: usize, - record_segment_definition: Segment, - blockfile_provider: BlockfileProvider, -} - -#[allow(dead_code)] -impl MergeKnnResultsOperatorInput { - pub fn new( - hnsw_result_offset_ids: Vec, - hnsw_result_distances: Vec, - brute_force_result: Option, - include_vectors: bool, - k: usize, - record_segment_definition: Segment, - blockfile_provider: BlockfileProvider, - ) -> Self { - Self { - hnsw_result_offset_ids, - hnsw_result_distances, - 
brute_force_result, - include_vectors, - k, - record_segment_definition, - blockfile_provider, - } - } -} - -#[derive(Debug)] -#[allow(dead_code)] -pub struct MergeKnnResultsOperatorOutput { - pub user_ids: Vec, - pub distances: Vec, - pub vectors: Option>>, -} - -#[derive(Error, Debug)] -pub enum MergeKnnResultsOperatorError {} - -impl ChromaError for MergeKnnResultsOperatorError { - fn code(&self) -> ErrorCodes { - ErrorCodes::Unknown - } -} - -#[async_trait] -impl Operator - for MergeKnnResultsOperator -{ - type Error = Box; - - fn get_name(&self) -> &'static str { - "MergeKnnResultsOperator" - } - - async fn run( - &self, - input: &MergeKnnResultsOperatorInput, - ) -> Result { - let (result_user_ids, result_distances, result_vectors) = - match RecordSegmentReader::from_segment( - &input.record_segment_definition, - &input.blockfile_provider, - ) - .await - { - Ok(reader) => { - // Convert the HNSW result offset IDs to user IDs - let mut hnsw_result_user_ids = Vec::new(); - let mut hnsw_result_vectors = None; - if input.include_vectors { - hnsw_result_vectors = Some(Vec::new()); - } - for offset_id in &input.hnsw_result_offset_ids { - let user_id = reader.get_user_id_for_offset_id(*offset_id as u32).await; - match user_id { - Ok(user_id) => hnsw_result_user_ids.push(user_id), - Err(e) => return Err(e), - } - if let Some(hnsw_result_vectors) = &mut hnsw_result_vectors { - let record = reader.get_data_for_offset_id(*offset_id as u32).await; - match record { - Ok(Some(record)) => { - hnsw_result_vectors.push(record.embedding.to_vec()) - } - Ok(None) => { - return Err(Box::new( - RecordSegmentReaderCreationError::DataRecordNotFound( - *offset_id as u32, - ), - )); - } - Err(e) => return Err(e), - } - } - } - - match &input.brute_force_result { - Some(brute_force_result) => merge_results( - &hnsw_result_user_ids, - &input.hnsw_result_distances, - &hnsw_result_vectors, - &brute_force_result.user_ids, - &brute_force_result.distances, - &brute_force_result.vectors, - input.include_vectors, - input.k, - ), - None => { - // There are no brute force results - ( - hnsw_result_user_ids - .iter() - .map(|x| x.to_string()) - .collect::>(), - input.hnsw_result_distances.clone(), - hnsw_result_vectors, - ) - } - } - } - Err(e) => match *e { - RecordSegmentReaderCreationError::BlockfileOpenError(e) => { - return Err(e); - } - RecordSegmentReaderCreationError::InvalidNumberOfFiles => { - return Err(e); - } - RecordSegmentReaderCreationError::DataRecordNotFound(_) => { - return Err(e); - } - RecordSegmentReaderCreationError::UserRecordNotFound(_) => { - return Err(e); - } - RecordSegmentReaderCreationError::UninitializedSegment => { - // The record segment doesn't exist - which implies no HNSW results - let hnsw_result_user_ids = Vec::new(); - let hnsw_result_distances = Vec::new(); - let hnsw_result_vectors = None; - - match &input.brute_force_result { - Some(brute_force_result) => merge_results( - &hnsw_result_user_ids, - &hnsw_result_distances, - &hnsw_result_vectors, - &brute_force_result.user_ids, - &brute_force_result.distances, - &brute_force_result.vectors, - input.include_vectors, - input.k, - ), - None => { - // There are no HNSW results and no brute force results - ( - Vec::new(), - Vec::new(), - if input.include_vectors { - Some(Vec::new()) - } else { - None - }, - ) - } - } - } - }, - }; - - Ok(MergeKnnResultsOperatorOutput { - user_ids: result_user_ids, - distances: result_distances, - vectors: result_vectors, - }) - } -} - -#[allow(clippy::too_many_arguments)] -fn merge_results( - 
hnsw_result_user_ids: &[&str], - hnsw_result_distances: &[f32], - hnsw_result_vectors: &Option>>, - brute_force_result_user_ids: &[String], - brute_force_result_distances: &[f32], - brute_force_result_vectors: &[Vec], - include_vectors: bool, - k: usize, -) -> (Vec, Vec, Option>>) { - let mut result_user_ids = Vec::with_capacity(k); - let mut result_distances = Vec::with_capacity(k); - let mut result_vectors = None; - if include_vectors { - result_vectors = Some(Vec::with_capacity(k)); - } - - // Merge the HNSW and brute force results together by the minimum distance top k - let mut hnsw_index = 0; - let mut brute_force_index = 0; - - // TODO: This doesn't have to clone the user IDs, but it's easier for now - while (result_user_ids.len() < k) - && (hnsw_index < hnsw_result_user_ids.len() - || brute_force_index < brute_force_result_user_ids.len()) - { - if hnsw_index < hnsw_result_user_ids.len() - && brute_force_index < brute_force_result_user_ids.len() - { - if hnsw_result_distances[hnsw_index] < brute_force_result_distances[brute_force_index] { - result_user_ids.push(hnsw_result_user_ids[hnsw_index].to_string()); - result_distances.push(hnsw_result_distances[hnsw_index]); - if include_vectors { - result_vectors - .as_mut() - .expect("Include vectors is true, result_vectors should be Some") - .push( - hnsw_result_vectors.as_ref().expect( - "Include vectors is true, hnsw_result_vectors should be Some", - )[hnsw_index] - .to_vec(), - ); - } - hnsw_index += 1; - } else { - result_user_ids.push(brute_force_result_user_ids[brute_force_index].to_string()); - result_distances.push(brute_force_result_distances[brute_force_index]); - if include_vectors { - result_vectors - .as_mut() - .expect("Include vectors is true, result_vectors should be Some") - .push(brute_force_result_vectors[brute_force_index].to_vec()); - } - brute_force_index += 1; - } - } else if hnsw_index < hnsw_result_user_ids.len() { - result_user_ids.push(hnsw_result_user_ids[hnsw_index].to_string()); - result_distances.push(hnsw_result_distances[hnsw_index]); - if include_vectors { - result_vectors - .as_mut() - .expect("Include vectors is true, result_vectors should be Some") - .push( - hnsw_result_vectors - .as_ref() - .expect("Include vectors is true, hnsw_result_vectors should be Some") - [hnsw_index] - .to_vec(), - ); - } - hnsw_index += 1; - } else if brute_force_index < brute_force_result_user_ids.len() { - result_user_ids.push(brute_force_result_user_ids[brute_force_index].to_string()); - result_distances.push(brute_force_result_distances[brute_force_index]); - if include_vectors { - result_vectors - .as_mut() - .expect("Include vectors is true, result_vectors should be Some") - .push(brute_force_result_vectors[brute_force_index].to_vec()); - } - brute_force_index += 1; - } - } - - (result_user_ids, result_distances, result_vectors) -} diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index c01db1a7b60..c7bc25660c9 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -1,13 +1,7 @@ -pub(super) mod brute_force_knn; pub(super) mod count_records; pub(super) mod flush_s3; -pub(super) mod get_vectors_operator; -pub(super) mod hnsw_knn; -pub(super) mod merge_knn_results; -pub(super) mod normalize_vectors; pub(super) mod partition; pub(super) mod pull_log; -pub(super) mod record_segment_prefetch; pub(super) mod register; pub mod spann_bf_pl; pub(super) mod spann_centers_search; diff --git 
a/rust/worker/src/execution/operators/normalize_vectors.rs b/rust/worker/src/execution/operators/normalize_vectors.rs deleted file mode 100644 index 03e884aa0d3..00000000000 --- a/rust/worker/src/execution/operators/normalize_vectors.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::execution::operator::Operator; -use async_trait::async_trait; -use chroma_distance::normalize; - -#[derive(Debug)] -pub struct NormalizeVectorOperator {} - -pub struct NormalizeVectorOperatorInput { - pub vectors: Vec>, -} - -pub struct NormalizeVectorOperatorOutput { - pub _normalized_vectors: Vec>, -} - -#[async_trait] -impl Operator - for NormalizeVectorOperator -{ - type Error = (); - - fn get_name(&self) -> &'static str { - "NormalizeVectorOperator" - } - - async fn run( - &self, - input: &NormalizeVectorOperatorInput, - ) -> Result { - // TODO: this should not have to reallocate the vectors. We can optimize this later. - let mut normalized_vectors = Vec::with_capacity(input.vectors.len()); - for vector in &input.vectors { - let normalized_vector = normalize(vector); - normalized_vectors.push(normalized_vector); - } - Ok(NormalizeVectorOperatorOutput { - _normalized_vectors: normalized_vectors, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - const COMPARE_EPS: f32 = 1e-9; - fn float_eps_eq(a: &[f32], b: &[f32]) -> bool { - a.iter() - .zip(b.iter()) - .all(|(a, b)| (a - b).abs() < COMPARE_EPS) - } - - #[tokio::test] - async fn test_normalize_vector() { - let operator = NormalizeVectorOperator {}; - let input = NormalizeVectorOperatorInput { - vectors: vec![ - vec![1.0, 2.0, 3.0], - vec![4.0, 5.0, 6.0], - vec![7.0, 8.0, 9.0], - ], - }; - - let output = operator.run(&input).await.unwrap(); - let expected_output = NormalizeVectorOperatorOutput { - _normalized_vectors: vec![ - vec![0.26726124, 0.5345225, 0.8017837], - vec![0.45584232, 0.5698029, 0.68376344], - vec![0.5025707, 0.5743665, 0.64616233], - ], - }; - - for (a, b) in output - ._normalized_vectors - .iter() - .zip(expected_output._normalized_vectors.iter()) - { - assert!(float_eps_eq(a, b), "{:?} != {:?}", a, b); - } - } -} diff --git a/rust/worker/src/execution/operators/record_segment_prefetch.rs b/rust/worker/src/execution/operators/record_segment_prefetch.rs deleted file mode 100644 index 0907fffbd0a..00000000000 --- a/rust/worker/src/execution/operators/record_segment_prefetch.rs +++ /dev/null @@ -1,128 +0,0 @@ -use crate::{ - execution::operator::{Operator, OperatorType}, - segment::record_segment::RecordSegmentReader, -}; -use chroma_blockstore::provider::BlockfileProvider; -use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::Segment; -use thiserror::Error; -use tonic::async_trait; - -#[derive(Debug)] -pub(crate) struct OffsetIdToDataKeys { - pub(crate) keys: Vec, -} - -#[derive(Debug)] -pub(crate) struct OffsetIdToUserIdKeys { - pub(crate) keys: Vec, -} - -#[derive(Debug)] -#[allow(dead_code)] -pub(crate) enum Keys { - OffsetIdToDataKeys(OffsetIdToDataKeys), - OffsetIdToUserIdKeys(OffsetIdToUserIdKeys), -} - -#[derive(Debug)] -pub(crate) struct RecordSegmentPrefetchIoInput { - pub(crate) keys: Keys, - pub(crate) segment: Segment, - pub(crate) provider: BlockfileProvider, -} - -#[derive(Debug)] -pub(crate) struct RecordSegmentPrefetchIoOutput { - // This is fire and forget so nothing to return. 
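Stepping back to the `merge_results` helper removed above: it is a two-pointer merge of two distance-sorted candidate lists, truncated at k. A reduced sketch of that merge, with hypothetical names and the vector payloads omitted:

```rust
// Hypothetical reduction of the deleted `merge_results`: merge two
// distance-sorted (user_id, distance) lists into the k overall-nearest.
fn merge_top_k(
    hnsw: &[(String, f32)],
    brute_force: &[(String, f32)],
    k: usize,
) -> Vec<(String, f32)> {
    let mut merged = Vec::with_capacity(k);
    let (mut i, mut j) = (0, 0);
    while merged.len() < k && (i < hnsw.len() || j < brute_force.len()) {
        let take_hnsw = match (hnsw.get(i), brute_force.get(j)) {
            (Some(h), Some(b)) => h.1 < b.1,
            (Some(_), None) => true,
            _ => false,
        };
        if take_hnsw {
            merged.push(hnsw[i].clone());
            i += 1;
        } else {
            merged.push(brute_force[j].clone());
            j += 1;
        }
    }
    merged
}

fn main() {
    let hnsw = vec![("a".to_string(), 0.1), ("c".to_string(), 0.7)];
    let brute_force = vec![("b".to_string(), 0.3)];
    let top = merge_top_k(&hnsw, &brute_force, 2);
    assert_eq!(top[0].0, "a");
    assert_eq!(top[1].0, "b");
}
```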
-} - -#[derive(Debug)] -pub(crate) struct RecordSegmentPrefetchIoOperator {} - -#[allow(dead_code)] -impl RecordSegmentPrefetchIoOperator { - pub fn new() -> Box { - Box::new(RecordSegmentPrefetchIoOperator {}) - } -} - -#[derive(Error, Debug)] -pub(crate) enum RecordSegmentPrefetchIoOperatorError { - #[error("Error creating Record Segment reader")] - RecordSegmentReaderCreationError, -} - -impl ChromaError for RecordSegmentPrefetchIoOperatorError { - fn code(&self) -> ErrorCodes { - match self { - Self::RecordSegmentReaderCreationError => ErrorCodes::Internal, - } - } -} - -#[async_trait] -impl Operator - for RecordSegmentPrefetchIoOperator -{ - type Error = RecordSegmentPrefetchIoOperatorError; - - fn get_name(&self) -> &'static str { - "RecordSegmentPrefetchIoOperator" - } - - async fn run( - &self, - input: &RecordSegmentPrefetchIoInput, - ) -> Result { - match &input.keys { - Keys::OffsetIdToDataKeys(keys) => { - if keys.keys.is_empty() { - return Ok(RecordSegmentPrefetchIoOutput {}); - } - // Construct record segment reader. - let record_segment_reader = match RecordSegmentReader::from_segment( - &input.segment, - &input.provider, - ) - .await - { - Ok(reader) => reader, - Err(_) => { - return Err( - RecordSegmentPrefetchIoOperatorError::RecordSegmentReaderCreationError, - ); - } - }; - record_segment_reader.prefetch_id_to_data(&keys.keys).await; - } - Keys::OffsetIdToUserIdKeys(keys) => { - if keys.keys.is_empty() { - return Ok(RecordSegmentPrefetchIoOutput {}); - } - // Construct record segment reader. - let record_segment_reader = match RecordSegmentReader::from_segment( - &input.segment, - &input.provider, - ) - .await - { - Ok(reader) => reader, - Err(_) => { - return Err( - RecordSegmentPrefetchIoOperatorError::RecordSegmentReaderCreationError, - ); - } - }; - record_segment_reader - .prefetch_id_to_user_id(&keys.keys) - .await; - } - } - Ok(RecordSegmentPrefetchIoOutput {}) - } - - fn get_type(&self) -> OperatorType { - OperatorType::IO - } -} diff --git a/rust/worker/src/execution/orchestration/common.rs b/rust/worker/src/execution/orchestration/common.rs index 82cb13c1835..ff72803f3e8 100644 --- a/rust/worker/src/execution/orchestration/common.rs +++ b/rust/worker/src/execution/orchestration/common.rs @@ -1,158 +1,5 @@ -use crate::{ - sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb}, - system::{Component, ComponentContext}, -}; -use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{Collection, CollectionUuid, Segment, SegmentType, SegmentUuid}; -use thiserror::Error; -use tracing::{trace_span, Instrument, Span}; -use uuid::Uuid; - -#[derive(Debug, Error)] -pub(super) enum GetHnswSegmentByIdError { - #[error("Hnsw segment with id: {0} not found")] - HnswSegmentNotFound(Uuid), - #[error("Get segments error: {0}")] - GetSegmentsError(#[from] GetSegmentsError), -} - -impl ChromaError for GetHnswSegmentByIdError { - fn code(&self) -> ErrorCodes { - match self { - GetHnswSegmentByIdError::HnswSegmentNotFound(_) => ErrorCodes::NotFound, - GetHnswSegmentByIdError::GetSegmentsError(e) => e.code(), - } - } -} - -pub(super) async fn get_hnsw_segment_by_id( - mut sysdb: Box, - hnsw_segment_id: &Uuid, - collection_id: &CollectionUuid, -) -> Result> { - let segments = sysdb - .get_segments( - Some(SegmentUuid(*hnsw_segment_id)), - None, - None, - *collection_id, - ) - .await; - let segment = match segments { - Ok(segments) => { - if segments.is_empty() { - return Err(Box::new(GetHnswSegmentByIdError::HnswSegmentNotFound( - *hnsw_segment_id, - ))); - } - 
segments[0].clone() - } - Err(e) => { - return Err(Box::new(GetHnswSegmentByIdError::GetSegmentsError(e))); - } - }; - - if segment.r#type != SegmentType::HnswDistributed { - return Err(Box::new(GetHnswSegmentByIdError::HnswSegmentNotFound( - *hnsw_segment_id, - ))); - } - Ok(segment) -} - -#[derive(Debug, Error)] -pub(super) enum GetCollectionByIdError { - #[error("Collection with id: {0} not found")] - CollectionNotFound(CollectionUuid), - #[error("Get collection error")] - GetCollectionError(#[from] GetCollectionsError), -} - -impl ChromaError for GetCollectionByIdError { - fn code(&self) -> ErrorCodes { - match self { - GetCollectionByIdError::CollectionNotFound(_) => ErrorCodes::NotFound, - GetCollectionByIdError::GetCollectionError(e) => e.code(), - } - } -} - -pub(super) async fn get_collection_by_id( - mut sysdb: Box, - collection_id: &CollectionUuid, -) -> Result> { - let child_span: tracing::Span = - trace_span!(parent: Span::current(), "get collection for collection id"); - let collections = sysdb - .get_collections(Some(*collection_id), None, None, None) - .instrument(child_span.clone()) - .await; - match collections { - Ok(mut collections) => { - if collections.is_empty() { - return Err(Box::new(GetCollectionByIdError::CollectionNotFound( - *collection_id, - ))); - } - Ok(collections.drain(..).next().unwrap()) - } - Err(e) => Err(Box::new(GetCollectionByIdError::GetCollectionError(e))), - } -} - -#[derive(Debug, Error)] -pub(super) enum GetRecordSegmentByCollectionIdError { - #[error("Record segment for collection with id: {0} not found")] - RecordSegmentNotFound(CollectionUuid), - #[error("Get segments error: {0}")] - GetSegmentsError(#[from] GetSegmentsError), -} - -impl ChromaError for GetRecordSegmentByCollectionIdError { - fn code(&self) -> ErrorCodes { - match self { - GetRecordSegmentByCollectionIdError::RecordSegmentNotFound(_) => ErrorCodes::NotFound, - GetRecordSegmentByCollectionIdError::GetSegmentsError(e) => e.code(), - } - } -} - -pub(super) async fn get_record_segment_by_collection_id( - mut sysdb: Box, - collection_id: &CollectionUuid, -) -> Result> { - let segments = sysdb - .get_segments( - None, - Some(SegmentType::BlockfileRecord.into()), - None, - *collection_id, - ) - .await; - - let segment = match segments { - Ok(mut segments) => { - if segments.is_empty() { - return Err(Box::new( - GetRecordSegmentByCollectionIdError::RecordSegmentNotFound(*collection_id), - )); - } - segments.drain(..).next().unwrap() - } - Err(e) => { - return Err(Box::new( - GetRecordSegmentByCollectionIdError::GetSegmentsError(e), - )); - } - }; - - if segment.r#type != SegmentType::BlockfileRecord { - return Err(Box::new( - GetRecordSegmentByCollectionIdError::RecordSegmentNotFound(*collection_id), - )); - } - Ok(segment) -} +use crate::system::{Component, ComponentContext}; +use chroma_error::ChromaError; /// Terminate the orchestrator with an error /// This function sends an error to the result channel and cancels the orchestrator diff --git a/rust/worker/src/execution/orchestration/get_vectors.rs b/rust/worker/src/execution/orchestration/get_vectors.rs deleted file mode 100644 index 5ed95d42efb..00000000000 --- a/rust/worker/src/execution/orchestration/get_vectors.rs +++ /dev/null @@ -1,348 +0,0 @@ -use super::common::{ - get_collection_by_id, get_hnsw_segment_by_id, get_record_segment_by_collection_id, -}; -use crate::{ - execution::{ - dispatcher::Dispatcher, - operator::{wrap, TaskResult}, - operators::{ - get_vectors_operator::{ - GetVectorsOperator, 
GetVectorsOperatorError, GetVectorsOperatorInput, - GetVectorsOperatorOutput, - }, - pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput}, - }, - orchestration::common::terminate_with_error, - }, - log::log::{Log, PullLogsError}, - sysdb::sysdb::SysDb, - system::{ - ChannelError, Component, ComponentContext, ComponentHandle, Handler, ReceiverForMessage, - System, - }, -}; -use async_trait::async_trait; -use chroma_blockstore::provider::BlockfileProvider; -use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{Chunk, Collection, CollectionUuid, GetVectorsResult, LogRecord, Segment}; -use std::time::{SystemTime, UNIX_EPOCH}; -use thiserror::Error; -use tracing::{trace, Span}; -use uuid::Uuid; - -#[derive(Debug)] -#[allow(dead_code)] -enum ExecutionState { - Pending, - PullLogs, - GetVectors, -} - -#[derive(Debug, Error)] -enum GetVectorsError { - #[error("Error sending task to dispatcher")] - TaskSendError(#[from] ChannelError), - #[error("System time error")] - SystemTimeError(#[from] std::time::SystemTimeError), - #[error("Collection version mismatch")] - CollectionVersionMismatch, -} - -impl ChromaError for GetVectorsError { - fn code(&self) -> ErrorCodes { - match self { - GetVectorsError::TaskSendError(e) => e.code(), - GetVectorsError::SystemTimeError(_) => ErrorCodes::Internal, - GetVectorsError::CollectionVersionMismatch => ErrorCodes::VersionMismatch, - } - } -} - -#[derive(Debug)] -#[allow(dead_code)] -pub struct GetVectorsOrchestrator { - state: ExecutionState, - // Component Execution - system: System, - // Query state - search_user_ids: Vec, - hnsw_segment_id: Uuid, - collection_id: CollectionUuid, - // State fetched or created for query execution - record_segment: Option, - collection: Option, - // Services - log: Box, - sysdb: Box, - dispatcher: ComponentHandle, - blockfile_provider: BlockfileProvider, - // Result channel - result_channel: - Option>>>, - collection_version: u32, - log_position: u64, -} - -#[allow(dead_code)] -impl GetVectorsOrchestrator { - #[allow(clippy::too_many_arguments)] - pub fn new( - system: System, - get_ids: Vec, - hnsw_segment_id: Uuid, - collection_id: CollectionUuid, - log: Box, - sysdb: Box, - dispatcher: ComponentHandle, - blockfile_provider: BlockfileProvider, - collection_version: u32, - log_position: u64, - ) -> Self { - Self { - state: ExecutionState::Pending, - system, - search_user_ids: get_ids, - hnsw_segment_id, - collection_id, - log, - sysdb, - dispatcher, - blockfile_provider, - record_segment: None, - collection: None, - result_channel: None, - collection_version, - log_position, - } - } - - async fn pull_logs( - &mut self, - self_address: Box>>, - ctx: &ComponentContext, - ) { - self.state = ExecutionState::PullLogs; - let operator = PullLogsOperator::new(self.log.clone()); - let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH); - let end_timestamp = match end_timestamp { - // TODO: change protobuf definition to use u64 instead of i64 - Ok(end_timestamp) => end_timestamp.as_nanos() as i64, - Err(e) => { - terminate_with_error( - self.result_channel.take(), - Box::new(GetVectorsError::SystemTimeError(e)), - ctx, - ); - return; - } - }; - - let collection = self - .collection - .as_ref() - .expect("State machine invariant violation. The collection is not set when pulling logs. 
This should never happen."); - - let input = PullLogsInput::new( - collection.collection_id, - // The collection log position is inclusive, and we want to start from the next log - // Note that we query using the incoming log position this is critical for correctness - // TODO: We should make all the log service code use u64 instead of i64 - (self.log_position as i64) + 1, - 100, - None, - Some(end_timestamp), - ); - - let task = wrap(operator, input, self_address); - // Wrap the task with current span as the parent. The worker then executes it - // inside a child span with this parent. - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - terminate_with_error( - self.result_channel.take(), - Box::new(GetVectorsError::TaskSendError(e)), - ctx, - ); - } - } - } - - async fn get_vectors( - &mut self, - self_address: Box< - dyn ReceiverForMessage>, - >, - log: Chunk, - ctx: &ComponentContext, - ) { - self.state = ExecutionState::GetVectors; - let record_segment = self - .record_segment - .as_ref() - .expect("Invariant violation. Record segment is not set."); - let blockfile_provider = self.blockfile_provider.clone(); - let operator = GetVectorsOperator::new(); - tracing::info!("get_vectors with search ids {:?}", self.search_user_ids); - let input = GetVectorsOperatorInput::new( - record_segment.clone(), - blockfile_provider, - log, - self.search_user_ids.clone(), - ); - - let task = wrap(operator, input, self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - terminate_with_error( - self.result_channel.take(), - Box::new(GetVectorsError::TaskSendError(e)), - ctx, - ); - } - } - } - - /// Run the orchestrator and return the result. - /// # Note - /// Use this over spawning the component directly. This method will start the component and - /// wait for it to finish before returning the result. 
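For reference, the `run` method below follows a common orchestrator pattern: create a oneshot channel, stash the sender on the component, start it, and await the receiver. A minimal sketch of that pattern, using plain `tokio::spawn` in place of the real component system; the `Orchestrator` type here is hypothetical:

```rust
use tokio::sync::oneshot;

// Hypothetical reduction of the orchestrator `run` pattern: the caller awaits
// a oneshot receiver while a spawned task drives the work and reports its
// result through the matching sender. The real code starts a component on the
// worker's system instead of using tokio::spawn directly.
struct Orchestrator {
    result_channel: Option<oneshot::Sender<Result<Vec<u32>, String>>>,
}

impl Orchestrator {
    async fn run(mut self) -> Result<Vec<u32>, String> {
        let (tx, rx) = oneshot::channel();
        self.result_channel = Some(tx);
        tokio::spawn(async move {
            // Stand-in for the real pipeline: pull logs, query, merge.
            if let Some(tx) = self.result_channel.take() {
                let _ = tx.send(Ok(vec![1, 2, 3]));
            }
        });
        rx.await.map_err(|e| e.to_string())?
    }
}

#[tokio::main]
async fn main() {
    let result = Orchestrator { result_channel: None }.run().await;
    assert_eq!(result, Ok(vec![1, 2, 3]));
}
```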
- pub(crate) async fn run(mut self) -> Result> { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = self.system.clone().start_component(self); - let result = rx.await; - handle.stop(); - result.unwrap() - } -} - -// ============== Component Implementation ============== - -#[async_trait] -impl Component for GetVectorsOrchestrator { - fn get_name() -> &'static str { - "GetVectorsOrchestrator" - } - - fn queue_size(&self) -> usize { - 1000 - } - - async fn on_start(&mut self, ctx: &ComponentContext) { - // Populate the orchestrator with the initial state - The HNSW Segment, The Record Segment and the Collection - let hnsw_segment = match get_hnsw_segment_by_id( - self.sysdb.clone(), - &self.hnsw_segment_id, - &self.collection_id, - ) - .await - { - Ok(segment) => segment, - Err(e) => { - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - let collection_id = &hnsw_segment.collection; - - let collection = match get_collection_by_id(self.sysdb.clone(), collection_id).await { - Ok(collection) => collection, - Err(e) => { - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - // If the collection version does not match the request version then we terminate with an error - if collection.version as u32 != self.collection_version { - terminate_with_error( - self.result_channel.take(), - Box::new(GetVectorsError::CollectionVersionMismatch), - ctx, - ); - return; - } - - let record_segment = - match get_record_segment_by_collection_id(self.sysdb.clone(), collection_id).await { - Ok(segment) => segment, - Err(e) => { - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - self.record_segment = Some(record_segment); - self.collection = Some(collection); - - self.pull_logs(ctx.receiver(), ctx).await; - } -} - -// ============== Handlers ============== - -#[async_trait] -impl Handler> for GetVectorsOrchestrator { - type Result = (); - - async fn handle( - &mut self, - message: TaskResult, - ctx: &ComponentContext, - ) { - let message = message.into_inner(); - match message { - Ok(output) => { - let logs = output.logs(); - self.get_vectors(ctx.receiver(), logs, ctx).await; - } - Err(e) => { - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - } - } - } -} - -#[async_trait] -impl Handler> - for GetVectorsOrchestrator -{ - type Result = (); - - async fn handle( - &mut self, - message: TaskResult, - ctx: &ComponentContext, - ) { - let message = message.into_inner(); - match message { - Ok(output) => { - let result = GetVectorsResult { - ids: output.ids, - vectors: output.vectors, - }; - let result_channel = self - .result_channel - .take() - .expect("Invariant violation. 
-                    .expect("Invariant violation. Result channel is not set.");
-                match result_channel.send(Ok(result)) {
-                    Ok(_) => (),
-                    Err(_e) => {
-                        // Log an error - this implies the listener was dropped
-                        trace!(
-                            "[GetVectorsOrchestrator] Result channel dropped before sending result"
-                        );
-                    }
-                }
-                // Cancel the orchestrator so it stops processing
-                ctx.cancellation_token.cancel();
-            }
-            Err(e) => {
-                terminate_with_error(self.result_channel.take(), Box::new(e), ctx);
-            }
-        }
-    }
-}
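Both deleted orchestrators exposed the same `run` pattern: stash the sender half of a oneshot channel on the component, start the component, and await the receiver. Here is a minimal, self-contained sketch of that pattern; `ToyOrchestrator` and its `work` loop are stand-ins, not Chroma's `System`/`Component` API:

```rust
use tokio::sync::oneshot;

struct ToyOrchestrator {
    result_channel: Option<oneshot::Sender<Result<Vec<String>, String>>>,
}

impl ToyOrchestrator {
    async fn run(mut self) -> Result<Vec<String>, String> {
        let (tx, rx) = oneshot::channel();
        self.result_channel = Some(tx);
        // Stand-in for `system.start_component(self)`: spawn the work loop.
        tokio::spawn(async move {
            let mut this = self;
            this.work().await;
        });
        // If the sender is dropped without sending, surface that as an error
        // (the deleted code unwrapped here instead).
        rx.await.unwrap_or_else(|_| Err("orchestrator dropped".into()))
    }

    async fn work(&mut self) {
        // ... dispatch operators, await their outputs ...
        if let Some(tx) = self.result_channel.take() {
            let _ = tx.send(Ok(vec!["id-1".to_string()]));
        }
    }
}

#[tokio::main]
async fn main() {
    let orchestrator = ToyOrchestrator { result_channel: None };
    println!("{:?}", orchestrator.run().await);
}
```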
diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs
deleted file mode 100644
index f7dba01e825..00000000000
--- a/rust/worker/src/execution/orchestration/hnsw.rs
+++ /dev/null
@@ -1,864 +0,0 @@
-use super::super::operator::wrap;
-use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator};
-use super::common::{
-    get_collection_by_id, get_hnsw_segment_by_id, get_record_segment_by_collection_id,
-    terminate_with_error,
-};
-use crate::execution::dispatcher::Dispatcher;
-use crate::execution::operator::TaskResult;
-use crate::execution::operators::brute_force_knn::{
-    BruteForceKnnOperator, BruteForceKnnOperatorError, BruteForceKnnOperatorInput,
-    BruteForceKnnOperatorOutput,
-};
-use crate::execution::operators::hnsw_knn::{
-    HnswKnnOperator, HnswKnnOperatorInput, HnswKnnOperatorOutput,
-};
-use crate::execution::operators::merge_knn_results::{
-    MergeKnnBruteForceResultInput, MergeKnnResultsOperator, MergeKnnResultsOperatorInput,
-    MergeKnnResultsOperatorOutput,
-};
-use crate::execution::operators::pull_log::PullLogsOutput;
-use crate::execution::operators::record_segment_prefetch::{
-    Keys, OffsetIdToDataKeys, OffsetIdToUserIdKeys, RecordSegmentPrefetchIoInput,
-    RecordSegmentPrefetchIoOperator, RecordSegmentPrefetchIoOperatorError,
-    RecordSegmentPrefetchIoOutput,
-};
-use crate::log::log::PullLogsError;
-use crate::segment::distributed_hnsw_segment::{
-    DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentReader,
-};
-use crate::segment::utils::distance_function_from_segment;
-use crate::sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb};
-use crate::system::{ComponentContext, ComponentHandle, System};
-use crate::{
-    log::log::Log,
-    system::{Component, Handler, ReceiverForMessage},
-};
-use async_trait::async_trait;
-use chroma_blockstore::provider::BlockfileProvider;
-use chroma_distance::normalize;
-use chroma_distance::DistanceFunction;
-use chroma_error::{ChromaError, ErrorCodes};
-use chroma_index::hnsw_provider::HnswIndexProvider;
-use chroma_index::IndexConfig;
-use chroma_types::{Chunk, Collection, CollectionUuid, LogRecord, Segment, VectorQueryResult};
-use std::collections::HashMap;
-use std::fmt::Debug;
-use std::sync::Arc;
-use std::time::{SystemTime, UNIX_EPOCH};
-use thiserror::Error;
-use tracing::{trace, Span};
-use uuid::Uuid;
-
-/** The state of the orchestrator.
-In chroma, we have a relatively fixed number of query plans that we can execute. Rather
-than a flexible state machine abstraction, we just manually define the states that we
-expect to encounter for a given query plan. This is a bit more rigid, but it's also simpler and easier to
-understand. We can always add more abstraction later if we need it.
-
-```plaintext
-
-                        ┌───► Brute Force ─────┐
-                        │                      │
- Pending ─► PullLogs ─► Group                  ├─► MergeResults ─► Finished
-                        │                      │
-                        └───► HNSW ────────────┘
-
-```
-*/
-#[derive(Debug)]
-#[allow(dead_code)]
-enum ExecutionState {
-    Pending,
-    PullLogs,
-    Partition,
-    QueryKnn, // This is both the Brute force and HNSW query state
-    MergeResults,
-    Finished,
-}
-
-#[derive(Error, Debug)]
-#[allow(dead_code)]
-enum HnswSegmentQueryError {
-    #[error(transparent)]
-    GetByIdError(#[from] super::common::GetHnswSegmentByIdError),
-    #[error("Get segments error: {0}")]
-    GetSegmentsError(#[from] GetSegmentsError),
-    #[error("Get collection error: {0}")]
-    GetCollectionError(#[from] GetCollectionsError),
-    #[error("Collection has no dimension set")]
-    CollectionHasNoDimension,
-    #[error("Collection version mismatch")]
-    CollectionVersionMismatch,
-}
-
-impl ChromaError for HnswSegmentQueryError {
-    fn code(&self) -> ErrorCodes {
-        match self {
-            HnswSegmentQueryError::GetByIdError(e) => e.code(),
-            HnswSegmentQueryError::GetSegmentsError(_) => ErrorCodes::Internal,
-            HnswSegmentQueryError::GetCollectionError(_) => ErrorCodes::Internal,
-            HnswSegmentQueryError::CollectionHasNoDimension => ErrorCodes::InvalidArgument,
-            HnswSegmentQueryError::CollectionVersionMismatch => ErrorCodes::VersionMismatch,
-        }
-    }
-}
-
-#[derive(Debug)]
-pub(crate) struct HnswQueryOrchestrator {
-    state: ExecutionState,
-    // Component Execution
-    system: System,
-    // Query state
-    query_vectors: Vec<Vec<f32>>,
-    k: i32,
-    allowed_ids: Arc<[String]>,
-    include_embeddings: bool,
-    hnsw_segment_id: Uuid,
-    collection_id: CollectionUuid,
-    // State fetched or created for query execution
-    hnsw_segment: Option<Segment>,
-    record_segment: Option<Segment>,
-    collection: Option<Collection>,
-    index_config: Option<IndexConfig>,
-    // query_vectors index to the result
-    hnsw_result_offset_ids: HashMap<usize, Vec<usize>>,
-    hnsw_result_distances: HashMap<usize, Vec<f32>>,
-    brute_force_results: HashMap<usize, BruteForceKnnOperatorOutput>,
-    // Task id to query_vectors index
-    hnsw_task_id_to_query_index: HashMap<Uuid, usize>,
-    brute_force_task_id_to_query_index: HashMap<Uuid, usize>,
-    merge_task_id_to_query_index: HashMap<Uuid, usize>,
-    // Result state
-    results: Option<Vec<Vec<VectorQueryResult>>>,
-    // State machine management
-    merge_dependency_count: u32,
-    finish_dependency_count: u32,
-    // Services
-    log: Box<Log>,
-    sysdb: Box<SysDb>,
-    dispatcher: ComponentHandle<Dispatcher>,
-    hnsw_index_provider: HnswIndexProvider,
-    blockfile_provider: BlockfileProvider,
-    // Result channel
-    #[allow(clippy::type_complexity)]
-    result_channel: Option<
-        tokio::sync::oneshot::Sender<Result<Vec<Vec<VectorQueryResult>>, Box<dyn ChromaError>>>,
-    >,
-    // Request version context
-    collection_version: u32,
-    log_position: u64,
-}
-
-#[allow(dead_code)]
-impl HnswQueryOrchestrator {
-    #[allow(clippy::too_many_arguments)]
-    pub(crate) fn new(
-        system: System,
-        query_vectors: Vec<Vec<f32>>,
-        k: i32,
-        allowed_ids: Vec<String>,
-        include_embeddings: bool,
-        segment_id: Uuid,
-        collection_id: CollectionUuid,
-        log: Box<Log>,
-        sysdb: Box<SysDb>,
-        hnsw_index_provider: HnswIndexProvider,
-        blockfile_provider: BlockfileProvider,
-        dispatcher: ComponentHandle<Dispatcher>,
-        collection_version: u32,
-        log_position: u64,
-    ) -> Self {
-        // Set the merge dependency count to the number of query vectors * 2:
-        // N for the HNSW query and N for the brute force query
-        let merge_dependency_count = (query_vectors.len() * 2) as u32;
-        // Set the finish dependency count to the number of query vectors,
-        // since each query vector will have a merge task
-        let finish_dependency_count = query_vectors.len() as u32;
-        // Pre-allocate the result vectors
-        let results = Some(Vec::with_capacity(query_vectors.len()));
-        tracing::info!(
-            "Performing KNN for k = {}, num allowed_ids = {:?}, num query vectors = {:?}",
-            k,
-            allowed_ids.len(),
-            query_vectors.len()
-        );
-
-        HnswQueryOrchestrator {
-            state: ExecutionState::Pending,
-            system,
-            merge_dependency_count,
-            finish_dependency_count,
-            query_vectors,
-            k,
-            allowed_ids: allowed_ids.into(),
-            include_embeddings,
-            hnsw_segment_id: segment_id,
-            collection_id,
-            hnsw_segment: None,
-            record_segment: None,
-            collection: None,
-            index_config: None,
-            hnsw_result_offset_ids: HashMap::new(),
-            hnsw_result_distances: HashMap::new(),
-            brute_force_results: HashMap::new(),
-            hnsw_task_id_to_query_index: HashMap::new(),
-            brute_force_task_id_to_query_index: HashMap::new(),
-            merge_task_id_to_query_index: HashMap::new(),
-            results,
-            log,
-            sysdb,
-            dispatcher,
-            hnsw_index_provider,
-            blockfile_provider,
-            result_channel: None,
-            collection_version,
-            log_position,
-        }
-    }
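The two counters initialized in `new` implement a simple fan-out/fan-in: 2N KNN tasks (one HNSW and one brute force per query vector) gate the dispatch of N merge tasks, which in turn gate sending the final result. A toy sketch of that counting, with all task execution simulated:

```rust
// Illustrative sketch (not Chroma's API) of the dependency counting set up in
// new(): the counters are decremented by the handlers as task results arrive.
fn main() {
    let num_queries = 3usize;
    let mut merge_dependency_count = (num_queries * 2) as u32; // N HNSW + N brute force
    let mut finish_dependency_count = num_queries as u32; // one merge task per query

    // Simulate each KNN task completing (in the real code, a handler runs per result).
    for _ in 0..num_queries * 2 {
        merge_dependency_count -= 1;
        if merge_dependency_count == 0 {
            println!("all KNN tasks done, dispatch {} merge tasks", num_queries);
        }
    }
    // Simulate each merge task completing.
    for _ in 0..num_queries {
        finish_dependency_count -= 1;
        if finish_dependency_count == 0 {
            println!("all merges done, send result to caller");
        }
    }
}
```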
-
-    async fn pull_logs(
-        &mut self,
-        self_address: Box<dyn ReceiverForMessage<TaskResult<PullLogsOutput, PullLogsError>>>,
-    ) {
-        self.state = ExecutionState::PullLogs;
-        let operator = PullLogsOperator::new(self.log.clone());
-        let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH);
-        let end_timestamp = match end_timestamp {
-            // TODO: change protobuf definition to use u64 instead of i64
-            Ok(end_timestamp) => end_timestamp.as_nanos() as i64,
-            Err(_) => {
-                // Log an error and reply to the caller, then return
-                return;
-            }
-        };
-
-        let collection = self
-            .collection
-            .as_ref()
-            .expect("State machine invariant violation. The collection is not set when pulling logs. This should never happen.");
-
-        let input = PullLogsInput::new(
-            collection.collection_id,
-            // The collection log position is inclusive, and we want to start from the next log.
-            // Note that we query using the incoming log position; this is critical for correctness.
-            // TODO: We should make all the log service code use u64 instead of i64
-            (self.log_position as i64) + 1,
-            100,
-            None,
-            Some(end_timestamp),
-        );
-        let task = wrap(operator, input, self_address);
-        // Wrap the task with the current span as the parent. The worker then executes it
-        // inside a child span with this parent.
-        match self.dispatcher.send(task, Some(Span::current())).await {
-            Ok(_) => (),
-            Err(e) => {
-                // TODO: log an error and reply to caller
-                tracing::error!("Error sending PullLogs task: {:?}", e);
-            }
-        }
-    }
-
-    async fn brute_force_query(
-        &mut self,
-        logs: Chunk<LogRecord>,
-        self_address: Box<
-            dyn ReceiverForMessage<
-                TaskResult<BruteForceKnnOperatorOutput, BruteForceKnnOperatorError>,
-            >,
-        >,
-    ) {
-        self.state = ExecutionState::QueryKnn;
-        let distance_function = &self
-            .index_config
-            .as_ref()
-            .expect("Invariant violation. Index config is not set")
-            .distance_function;
-
-        // TODO: We shouldn't have to clone query vectors here. We should be able to pass an Arc<[f32]>-like to the input
-        for (i, query_vector) in self.query_vectors.iter().enumerate() {
-            let bf_input = BruteForceKnnOperatorInput {
-                log: logs.clone(),
-                query: query_vector.clone(),
-                k: self.k as usize,
-                distance_metric: distance_function.clone(),
-                allowed_ids: self.allowed_ids.clone(),
-                record_segment_definition: self
-                    .record_segment
-                    .as_ref()
-                    .expect("Invariant violation. Record segment is not set")
-                    .clone(),
-                blockfile_provider: self.blockfile_provider.clone(),
-            };
-            let operator = Box::new(BruteForceKnnOperator {});
-            let task = wrap(operator, bf_input, self_address.clone());
-            self.brute_force_task_id_to_query_index.insert(task.id(), i);
-            match self.dispatcher.send(task, Some(Span::current())).await {
-                Ok(_) => (),
-                Err(e) => {
-                    // Log an error
-                    tracing::error!("Error sending Brute Force KNN task: {:?}", e);
-                }
-            }
-        }
-    }
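Conceptually, the brute force operator dispatched above scores every log record against the query and keeps the k nearest. A standalone sketch of that top-k selection using a max-heap; the types and the squared-Euclidean metric are illustrative, not the deleted operator's:

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

struct Candidate {
    distance: f32,
    id: usize,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.distance == other.distance
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Max-heap on distance, so the current worst candidate sits on top.
        self.distance.total_cmp(&other.distance)
    }
}

fn brute_force_knn(query: &[f32], vectors: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut heap = BinaryHeap::with_capacity(k + 1);
    for (id, v) in vectors.iter().enumerate() {
        // Squared Euclidean distance as a stand-in for the configured metric.
        let distance: f32 = query.iter().zip(v).map(|(a, b)| (a - b) * (a - b)).sum();
        heap.push(Candidate { distance, id });
        if heap.len() > k {
            heap.pop(); // evict the current worst
        }
    }
    let mut out: Vec<_> = heap.into_iter().map(|c| (c.id, c.distance)).collect();
    out.sort_by(|a, b| a.1.total_cmp(&b.1));
    out
}

fn main() {
    let vectors = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![0.5, 0.0]];
    println!("{:?}", brute_force_knn(&[0.1, 0.0], &vectors, 2));
}
```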
-
-    async fn hnsw_segment_query(&mut self, logs: Chunk<LogRecord>, ctx: &ComponentContext<Self>) {
-        self.state = ExecutionState::QueryKnn;
-
-        let hnsw_segment = self
-            .hnsw_segment
-            .as_ref()
-            .expect("Invariant violation. HNSW Segment is not set");
-        let dimensionality = self
-            .collection
-            .as_ref()
-            .expect("Invariant violation. Collection is not set")
-            .dimension
-            .expect("Invariant violation. Collection dimension is not set");
-
-        // Fetch the data needed for the duration of the query - The HNSW Segment, The record Segment and the Collection
-        let hnsw_segment_reader = match DistributedHNSWSegmentReader::from_segment(
-            // These unwraps are safe because we have already checked that the segments are set in the orchestrator on_start
-            hnsw_segment,
-            dimensionality as usize,
-            self.hnsw_index_provider.clone(),
-        )
-        .await
-        {
-            Ok(reader) => reader,
-            Err(e) => {
-                match *e {
-                    DistributedHNSWSegmentFromSegmentError::Uninitialized => {
-                        tracing::info!("[HnswQueryOperation]: Uninitialized reader {:?}", *e);
-                        // No task to dispatch: decrement the merge dependency count and return
-                        // with an empty result
-                        for (i, _) in self.query_vectors.iter().enumerate() {
-                            self.merge_dependency_count -= 1;
-                            self.hnsw_result_distances.insert(i, Vec::new());
-                            self.hnsw_result_offset_ids.insert(i, Vec::new());
-                        }
-                        return;
-                    }
-                    _ => {
-                        tracing::error!("[HnswQueryOperation]: Error creating distributed hnsw segment reader {:?}", *e);
-                        terminate_with_error(self.result_channel.take(), e, ctx);
-                        return;
-                    }
-                }
-            }
-        };
-
-        let record_segment = self
-            .record_segment
-            .as_ref()
-            .expect("Invariant violation. Record Segment is not set");
-
-        // Dispatch a query task per query vector
-        for (i, query_vector) in self.query_vectors.iter().enumerate() {
-            let operator = Box::new(HnswKnnOperator {});
-            let input = HnswKnnOperatorInput {
-                segment: hnsw_segment_reader.clone(),
-                query: query_vector.clone(),
-                k: self.k as usize,
-                record_segment: record_segment.clone(),
-                blockfile_provider: self.blockfile_provider.clone(),
-                allowed_ids: self.allowed_ids.clone(),
-                logs: logs.clone(),
-            };
-            let task = wrap(operator, input, ctx.receiver());
-            self.hnsw_task_id_to_query_index.insert(task.id(), i);
-            match self.dispatcher.send(task, Some(Span::current())).await {
-                Ok(_) => (),
-                Err(e) => {
-                    // Log an error
-                    tracing::error!("Error sending HNSW KNN task: {:?}", e);
-                }
-            }
-        }
-    }
-
-    async fn merge_results(&mut self, ctx: &ComponentContext<Self>) {
-        self.state = ExecutionState::MergeResults;
-        for i in 0..self.query_vectors.len() {
-            self.merge_results_for_index(ctx, i).await;
-        }
-    }
-
-    async fn prefetch_record_data(&mut self, ctx: &ComponentContext<Self>, offset_ids: Vec<u32>) {
-        let record_segment = self
-            .record_segment
-            .as_ref()
-            .expect("Invariant violation. Record Segment is not set");
-        // TODO: Divide this into multiple tasks based on some criteria.
-        let offsetid_to_data_keys =
-            Keys::OffsetIdToDataKeys(OffsetIdToDataKeys { keys: offset_ids });
-        let prefetch_input = RecordSegmentPrefetchIoInput {
-            keys: offsetid_to_data_keys,
-            segment: record_segment.clone(),
-            provider: self.blockfile_provider.clone(),
-        };
-        let operator = RecordSegmentPrefetchIoOperator::new();
-        let prefetch_task = wrap(operator, prefetch_input, ctx.receiver());
-        match self
-            .dispatcher
-            .send(prefetch_task, Some(Span::current()))
-            .await
-        {
-            Ok(_) => (),
-            Err(e) => {
-                // Log an error
-                tracing::error!("Error sending record segment Prefetch data task: {:?}", e);
-            }
-        }
-    }
-
-    async fn prefetch_user_ids(&mut self, ctx: &ComponentContext<Self>, offset_ids: Vec<u32>) {
-        let record_segment = self
-            .record_segment
-            .as_ref()
-            .expect("Invariant violation. Record Segment is not set");
-        // TODO: Divide this into multiple tasks based on some criteria.
-        let offsetid_to_userid_keys =
-            Keys::OffsetIdToUserIdKeys(OffsetIdToUserIdKeys { keys: offset_ids });
-        let prefetch_input = RecordSegmentPrefetchIoInput {
-            keys: offsetid_to_userid_keys,
-            segment: record_segment.clone(),
-            provider: self.blockfile_provider.clone(),
-        };
-        let operator = RecordSegmentPrefetchIoOperator::new();
-        let prefetch_task = wrap(operator, prefetch_input, ctx.receiver());
-        match self
-            .dispatcher
-            .send(prefetch_task, Some(Span::current()))
-            .await
-        {
-            Ok(_) => (),
-            Err(e) => {
-                // Log an error
-                tracing::error!("Error sending Prefetch data task: {:?}", e);
-            }
-        }
-    }
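The two prefetch helpers above are fire-and-forget: they warm the blockstore cache for blocks the merge step will read, without blocking the query path on the result. A toy sketch of that shape, with a `HashMap` standing in for the block cache and `load_block` standing in for a blockstore read (neither is Chroma's blockstore API):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

type BlockCache = Arc<Mutex<HashMap<u32, Vec<u8>>>>;

async fn load_block(offset_id: u32) -> Vec<u8> {
    // Stand-in for a blockstore read.
    vec![offset_id as u8]
}

fn prefetch(cache: BlockCache, offset_ids: Vec<u32>) {
    // Fire and forget: the spawned task warms the cache in the background.
    tokio::spawn(async move {
        for id in offset_ids {
            let block = load_block(id).await;
            cache.lock().await.insert(id, block);
        }
    });
}

#[tokio::main]
async fn main() {
    let cache: BlockCache = Arc::new(Mutex::new(HashMap::new()));
    prefetch(cache.clone(), vec![1, 2, 3]);
    // ... the query continues; later reads hit the warmed cache ...
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    println!("cached blocks: {}", cache.lock().await.len());
}
```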
-
-    async fn merge_results_for_index(
-        &mut self,
-        ctx: &ComponentContext<Self>,
-        query_vector_index: usize,
-    ) {
-        let hnsw_result_offset_ids = self
-            .hnsw_result_offset_ids
-            .remove(&query_vector_index)
-            .expect(
-                "Invariant violation. HNSW result offset ids are not set for query vector index",
-            );
-
-        if !hnsw_result_offset_ids.is_empty() {
-            // Eagerly dispatch prefetch tasks.
-            let offset_ids_to_prefetch: Vec<u32> =
-                hnsw_result_offset_ids.iter().map(|x| *x as u32).collect();
-            self.prefetch_record_data(ctx, offset_ids_to_prefetch.clone())
-                .await;
-            self.prefetch_user_ids(ctx, offset_ids_to_prefetch).await;
-        }
-
-        let record_segment = self
-            .record_segment
-            .as_ref()
-            .expect("Invariant violation. Record Segment is not set");
-
-        let hnsw_result_distances = self
-            .hnsw_result_distances
-            .remove(&query_vector_index)
-            .expect(
-                "Invariant violation. HNSW result distances are not set for query vector index",
-            );
-
-        let brute_force_result = self.brute_force_results.remove(&query_vector_index);
-
-        tracing::info!(
-            "[HnswQueryOperation]: Brute force {} user ids, hnsw {} offset ids, hnsw ids: {:?}...",
-            brute_force_result.as_ref().map_or(0, |x| x.user_ids.len()),
-            hnsw_result_offset_ids.len(),
-            &hnsw_result_offset_ids,
-        );
-
-        let operator = Box::new(MergeKnnResultsOperator {});
-        let input = MergeKnnResultsOperatorInput::new(
-            hnsw_result_offset_ids,
-            hnsw_result_distances,
-            brute_force_result.map(|r| MergeKnnBruteForceResultInput {
-                user_ids: r.user_ids,
-                distances: r.distances,
-                vectors: r.embeddings,
-            }),
-            self.include_embeddings,
-            self.k as usize,
-            record_segment.clone(),
-            self.blockfile_provider.clone(),
-        );
-
-        let task = wrap(operator, input, ctx.receiver());
-        self.merge_task_id_to_query_index
-            .insert(task.id(), query_vector_index);
-        match self.dispatcher.send(task, Some(Span::current())).await {
-            Ok(_) => (),
-            Err(e) => {
-                // Log an error
-                tracing::error!("Error sending Merge KNN task: {:?}", e);
-            }
-        }
-    }
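What `MergeKnnResultsOperator` does with the two candidate sets can be pictured as merge-and-truncate: combine the HNSW candidates (from the compacted index) with the brute-force candidates (from unflushed logs), dedupe, and keep the k closest. A conceptual sketch with illustrative types only; the real operator works on offset ids and resolves them through the record segment:

```rust
fn merge_top_k(
    mut hnsw: Vec<(String, f32)>,
    brute_force: Vec<(String, f32)>,
    k: usize,
) -> Vec<(String, f32)> {
    hnsw.extend(brute_force);
    // Sort all candidates by distance, closest first.
    hnsw.sort_by(|a, b| a.1.total_cmp(&b.1));
    // Dedupe by id, keeping the first (closest) occurrence.
    let mut seen = std::collections::HashSet::new();
    hnsw.retain(|(id, _)| seen.insert(id.clone()));
    hnsw.truncate(k);
    hnsw
}

fn main() {
    let hnsw = vec![("a".into(), 0.1), ("b".into(), 0.4)];
    let bf = vec![("c".into(), 0.2), ("a".into(), 0.1)];
    println!("{:?}", merge_top_k(hnsw, bf, 2)); // [("a", 0.1), ("c", 0.2)]
}
```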
-
-    fn terminate_with_empty_response(&mut self, ctx: &ComponentContext<Self>) {
-        let result_channel = self
-            .result_channel
-            .take()
-            .expect("Invariant violation. Result channel is not set.");
-        let mut empty_resp = vec![];
-        for _ in 0..self.query_vectors.len() {
-            empty_resp.push(vec![]);
-        }
-        match result_channel.send(Ok(empty_resp)) {
-            Ok(_) => (),
-            Err(_) => {
-                // Log an error - this implies the listener was dropped
-                tracing::error!(
-                    "[HnswQueryOrchestrator] Result channel dropped before sending empty response"
-                );
-            }
-        }
-        // Cancel the orchestrator so it stops processing
-        ctx.cancellation_token.cancel();
-    }
-
-    /// Run the orchestrator and return the result.
-    /// # Note
-    /// Use this over spawning the component directly. This method will start the component and
-    /// wait for it to finish before returning the result.
-    pub(crate) async fn run(mut self) -> Result<Vec<Vec<VectorQueryResult>>, Box<dyn ChromaError>> {
-        let (tx, rx) = tokio::sync::oneshot::channel();
-        self.result_channel = Some(tx);
-        let mut handle = self.system.clone().start_component(self);
-        let result = rx.await;
-        handle.stop();
-        result.unwrap()
-    }
-}
-
-// ============== Component Implementation ==============
-
-#[async_trait]
-impl Component for HnswQueryOrchestrator {
-    fn get_name() -> &'static str {
-        "HNSW Query orchestrator"
-    }
-
-    fn queue_size(&self) -> usize {
-        1000 // TODO: make configurable
-    }
-
-    async fn on_start(&mut self, ctx: &crate::system::ComponentContext<Self>) -> () {
-        // Populate the orchestrator with the initial state - The HNSW Segment, The Record Segment and the Collection
-        let hnsw_segment = match get_hnsw_segment_by_id(
-            self.sysdb.clone(),
-            &self.hnsw_segment_id,
-            &self.collection_id,
-        )
-        .await
-        {
-            Ok(segment) => segment,
-            Err(e) => {
-                terminate_with_error(self.result_channel.take(), e, ctx);
-                return;
-            }
-        };
-
-        let collection_id = &hnsw_segment.collection;
-
-        let collection = match get_collection_by_id(self.sysdb.clone(), collection_id).await {
-            Ok(collection) => collection,
-            Err(e) => {
-                terminate_with_error(self.result_channel.take(), e, ctx);
-                return;
-            }
-        };
-
-        // If the collection version does not match the request version, terminate with an error
-        if collection.version as u32 != self.collection_version {
-            terminate_with_error(
-                self.result_channel.take(),
-                Box::new(HnswSegmentQueryError::CollectionVersionMismatch),
-                ctx,
-            );
-            return;
-        }
-
-        // If the segment is uninitialized and the dimension is not set, we assume
-        // this is a query issued before any add, so return an empty response.
-        if hnsw_segment.file_path.is_empty() && collection.dimension.is_none() {
-            self.terminate_with_empty_response(ctx);
-            return;
-        }
-
-        // Validate that the collection has a dimension set. Downstream steps will rely on this
-        // so that they can unwrap the dimension without checking for None
-        if collection.dimension.is_none() {
-            terminate_with_error(
-                self.result_channel.take(),
-                Box::new(HnswSegmentQueryError::CollectionHasNoDimension),
-                ctx,
-            );
-            return;
-        };
-
-        let record_segment =
-            match get_record_segment_by_collection_id(self.sysdb.clone(), collection_id).await {
-                Ok(segment) => segment,
-                Err(e) => {
-                    terminate_with_error(self.result_channel.take(), e, ctx);
-                    return;
-                }
-            };
-
-        let distance_function = match distance_function_from_segment(&hnsw_segment) {
-            Ok(distance_function) => distance_function,
-            Err(e) => {
-                terminate_with_error(self.result_channel.take(), e, ctx);
-                return;
-            }
-        };
-        self.index_config = Some(IndexConfig::new(
-            collection.dimension.unwrap(),
-            distance_function,
-        ));
-        // Normalize the query vectors if we are using cosine similarity
-        if self.index_config.as_ref().unwrap().distance_function == DistanceFunction::Cosine {
-            for query_vector in self.query_vectors.iter_mut() {
-                *query_vector = normalize(query_vector);
-            }
-        }
-
-        self.record_segment = Some(record_segment);
-        self.hnsw_segment = Some(hnsw_segment);
-        self.collection = Some(collection);
-
-        self.pull_logs(ctx.receiver()).await;
-    }
-}
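The normalization step in `on_start` exists because, on unit-length vectors, cosine distance reduces to an inner-product computation. A sketch of what an L2 `normalize` helper typically does; the actual implementation lives in `chroma_distance`, so this free function is only illustrative:

```rust
fn normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm == 0.0 {
        return v.to_vec(); // avoid dividing by zero for the zero vector
    }
    v.iter().map(|x| x / norm).collect()
}

fn main() {
    let q = normalize(&[3.0, 4.0]);
    println!("{:?}", q); // [0.6, 0.8], unit length
}
```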
-
-// ============== Handlers ==============
-
-#[async_trait]
-impl Handler<TaskResult<PullLogsOutput, PullLogsError>> for HnswQueryOrchestrator {
-    type Result = ();
-
-    async fn handle(
-        &mut self,
-        message: TaskResult<PullLogsOutput, PullLogsError>,
-        ctx: &crate::system::ComponentContext<Self>,
-    ) {
-        let message = message.into_inner();
-        self.state = ExecutionState::Partition;
-
-        match message {
-            Ok(pull_logs_output) => {
-                let logs = pull_logs_output.logs();
-                if !logs.is_empty() {
-                    self.brute_force_query(logs.clone(), ctx.receiver()).await;
-                } else {
-                    // Skip running the brute force query if there are no logs
-                    self.merge_dependency_count -= self.query_vectors.len() as u32;
-                }
-
-                self.hnsw_segment_query(logs, ctx).await;
-            }
-            Err(e) => {
-                terminate_with_error(self.result_channel.take(), Box::new(e), ctx);
-            }
-        }
-    }
-}
-
-#[async_trait]
-impl Handler<TaskResult<BruteForceKnnOperatorOutput, BruteForceKnnOperatorError>>
-    for HnswQueryOrchestrator
-{
-    type Result = ();
-
-    async fn handle(
-        &mut self,
-        message: TaskResult<BruteForceKnnOperatorOutput, BruteForceKnnOperatorError>,
-        ctx: &crate::system::ComponentContext<Self>,
-    ) {
-        let task_id = message.id();
-        let message = message.into_inner();
-        let query_index = self
-            .brute_force_task_id_to_query_index
-            .remove(&task_id)
-            .expect("Invariant violation. Brute force task id is not set for query vector index");
-
-        match message {
-            Ok(output) => {
-                self.brute_force_results.insert(query_index, output);
-            }
-            Err(e) => {
-                terminate_with_error(self.result_channel.take(), e.boxed(), ctx);
-                return;
-            }
-        }
-
-        self.merge_dependency_count -= 1;
-
-        if self.merge_dependency_count == 0 {
-            self.merge_results(ctx).await;
-        }
-    }
-}
-
-#[async_trait]
-impl Handler<TaskResult<HnswKnnOperatorOutput, Box<dyn ChromaError>>> for HnswQueryOrchestrator {
-    type Result = ();
-
-    async fn handle(
-        &mut self,
-        message: TaskResult<HnswKnnOperatorOutput, Box<dyn ChromaError>>,
-        ctx: &ComponentContext<Self>,
-    ) {
-        let task_id = message.id();
-        let message = message.into_inner();
-        let query_index = self
-            .hnsw_task_id_to_query_index
-            .remove(&task_id)
-            .expect("Invariant violation. HNSW task id is not set for query vector index");
-        match message {
-            Ok(output) => {
-                self.hnsw_result_offset_ids
-                    .insert(query_index, output.offset_ids);
-                self.hnsw_result_distances
-                    .insert(query_index, output.distances);
-            }
-            Err(e) => {
-                terminate_with_error(self.result_channel.take(), e.boxed(), ctx);
-                return;
-            }
-        }
-
-        self.merge_dependency_count -= 1;
-
-        if self.merge_dependency_count == 0 {
-            self.merge_results(ctx).await;
-        }
-    }
-}
-
-#[async_trait]
-impl Handler<TaskResult<MergeKnnResultsOperatorOutput, Box<dyn ChromaError>>>
-    for HnswQueryOrchestrator
-{
-    type Result = ();
-
-    async fn handle(
-        &mut self,
-        message: TaskResult<MergeKnnResultsOperatorOutput, Box<dyn ChromaError>>,
-        ctx: &crate::system::ComponentContext<Self>,
-    ) {
-        let task_id = message.id();
-        let message = message.into_inner();
-        let query_index = self
-            .merge_task_id_to_query_index
-            .remove(&task_id)
-            .expect("Invariant violation. Merge task id is not set for query vector index");
-
-        self.state = ExecutionState::Finished;
-
-        let (mut output_ids, mut output_distances, output_vectors) = match message {
-            Ok(output) => (output.user_ids, output.distances, output.vectors),
-            Err(e) => {
-                terminate_with_error(self.result_channel.take(), e.boxed(), ctx);
-                return;
-            }
-        };
-
-        let mut query_results = Vec::new();
-        if self.include_embeddings {
-            for ((index, distance), vector) in
-                output_ids.drain(..).zip(output_distances.drain(..)).zip(
-                    output_vectors
-                        .expect("Embeddings are expected if include_embeddings is set")
-                        .drain(..),
-                )
-            {
-                let query_result = VectorQueryResult {
-                    id: index,
-                    distance,
-                    vector: Some(vector),
-                };
-                query_results.push(query_result);
-            }
-        } else {
-            for (index, distance) in output_ids.drain(..).zip(output_distances.drain(..)) {
-                let query_result = VectorQueryResult {
-                    id: index,
-                    distance,
-                    vector: None,
-                };
-                query_results.push(query_result);
-            }
-        }
-        trace!("Merged results: {:?}", query_results);
-
-        let results_slice = self
-            .results
-            .as_mut()
-            .expect("Invariant violation. Results are not set")
-            .spare_capacity_mut();
-        results_slice[query_index].write(query_results);
-        self.finish_dependency_count -= 1;
-
-        if self.finish_dependency_count == 0 {
-            let result_channel = match self.result_channel.take() {
-                Some(tx) => tx,
-                None => {
-                    // Log an error - this is an invariant violation; the result channel should always be set
-                    return;
-                }
-            };
-
-            unsafe {
-                // Safety: We have ensured that the results are set and the length is equal to the number of query vectors
-                // https://doc.rust-lang.org/stable/std/mem/union.MaybeUninit.html#out-pointers
-                self.results
-                    .as_mut()
-                    .expect("Invariant violation. Results are not set")
-                    .set_len(self.query_vectors.len());
-            }
-
-            match result_channel.send(Ok(self
-                .results
-                .take()
-                .expect("Invariant violation. Results are not set")))
-            {
-                Ok(_) => (),
-                Err(_) => {
-                    // Log an error - the listener was dropped before the result could be sent
-                }
-            }
-        }
-    }
-}
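The `spare_capacity_mut`/`set_len` dance in the merge handler above is the standard out-pointer pattern for collecting out-of-order task results into pre-allocated slots. A standalone sketch of just that pattern, with task completions simulated:

```rust
fn main() {
    let num_queries = 3;
    let mut results: Vec<Vec<u32>> = Vec::with_capacity(num_queries);

    // Pretend task completions arrive out of order: (query_index, result).
    let completions = vec![(2, vec![20]), (0, vec![0]), (1, vec![10])];
    let mut remaining = num_queries;

    for (query_index, result) in completions {
        // spare_capacity_mut exposes the uninitialized tail as &mut [MaybeUninit<T>].
        results.spare_capacity_mut()[query_index].write(result);
        remaining -= 1;
    }

    assert_eq!(remaining, 0);
    unsafe {
        // Safety: all `num_queries` slots were initialized above.
        results.set_len(num_queries);
    }
    println!("{:?}", results); // [[0], [10], [20]]
}
```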
-
-#[async_trait]
-impl Handler<TaskResult<RecordSegmentPrefetchIoOutput, RecordSegmentPrefetchIoOperatorError>>
-    for HnswQueryOrchestrator
-{
-    type Result = ();
-
-    async fn handle(
-        &mut self,
-        _message: TaskResult<RecordSegmentPrefetchIoOutput, RecordSegmentPrefetchIoOperatorError>,
-        _ctx: &ComponentContext<Self>,
-    ) {
-        // Nothing to do.
-    }
-}
diff --git a/rust/worker/src/execution/orchestration/mod.rs b/rust/worker/src/execution/orchestration/mod.rs
index 4c91d480d87..d9b83d6e48a 100644
--- a/rust/worker/src/execution/orchestration/mod.rs
+++ b/rust/worker/src/execution/orchestration/mod.rs
@@ -1,9 +1,6 @@
 mod common;
 mod compact;
 mod count;
-mod get_vectors;
-#[allow(dead_code)]
-mod hnsw;
 mod spann_knn;
 pub(crate) use compact::*;
 pub(crate) use count::*;
diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs
index b27f6361162..f69ad5e7f04 100644
--- a/rust/worker/src/segment/record_segment.rs
+++ b/rust/worker/src/segment/record_segment.rs
@@ -774,19 +774,6 @@ impl RecordSegmentReader<'_> {
         self.curr_max_offset_id.clone()
     }
 
-    pub(crate) async fn get_user_id_for_offset_id(
-        &self,
-        offset_id: u32,
-    ) -> Result<&str, Box<dyn ChromaError>> {
-        match self.id_to_user_id.get("", offset_id).await {
-            Ok(Some(user_id)) => Ok(user_id),
-            Ok(None) => Err(Box::new(
-                RecordSegmentReaderCreationError::UserRecordNotFound(offset_id.to_string()),
-            )),
-            Err(e) => Err(e),
-        }
-    }
-
     pub(crate) async fn get_offset_id_for_user_id(
         &self,
         user_id: &str,
@@ -939,13 +926,6 @@ impl RecordSegmentReader<'_> {
             .load_blocks_for_keys(&prefixes, &keys)
             .await
     }
-
-    pub(crate) async fn prefetch_id_to_user_id(&self, keys: &[u32]) {
-        let prefixes = vec![""; keys.len()];
-        self.id_to_user_id
-            .load_blocks_for_keys(&prefixes, keys)
-            .await
-    }
 }
 
 #[cfg(test)]