Skip to content

Commit

Permalink
Better abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Nov 11, 2024
1 parent 74dac60 commit 861d9c0
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 99 deletions.
67 changes: 62 additions & 5 deletions rust/index/src/spann/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl SpannIndexWriter {
}
}

pub async fn hnsw_index_from_id(
async fn hnsw_index_from_id(
hnsw_provider: &HnswIndexProvider,
id: &Uuid,
collection_id: &Uuid,
Expand All @@ -73,7 +73,7 @@ impl SpannIndexWriter {
}
}

pub async fn create_hnsw_index(
async fn create_hnsw_index(
hnsw_provider: &HnswIndexProvider,
collection_id: &Uuid,
distance_function: DistanceFunction,
Expand All @@ -96,7 +96,7 @@ impl SpannIndexWriter {
}
}

pub async fn load_versions_map(
async fn load_versions_map(
blockfile_id: &Uuid,
blockfile_provider: &BlockfileProvider,
) -> Result<HashMap<u32, u32>, SpannIndexWriterConstructionError> {
Expand All @@ -116,7 +116,7 @@ impl SpannIndexWriter {
Ok(versions_map)
}

pub async fn fork_postings_list(
async fn fork_postings_list(
blockfile_id: &Uuid,
blockfile_provider: &BlockfileProvider,
) -> Result<BlockfileWriter, SpannIndexWriterConstructionError> {
Expand All @@ -129,12 +129,69 @@ impl SpannIndexWriter {
}
}

pub async fn create_posting_list(
async fn create_posting_list(
blockfile_provider: &BlockfileProvider,
) -> Result<BlockfileWriter, SpannIndexWriterConstructionError> {
match blockfile_provider.create::<u32, &SpannPostingList<'_>>() {
Ok(writer) => Ok(writer),
Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError),
}
}

#[allow(clippy::too_many_arguments)]
pub async fn from_id(
hnsw_provider: &HnswIndexProvider,
hnsw_id: Option<&Uuid>,
versions_map_id: Option<&Uuid>,
posting_list_id: Option<&Uuid>,
hnsw_params: Option<HnswIndexParams>,
collection_id: &Uuid,
distance_function: DistanceFunction,
dimensionality: usize,
blockfile_provider: &BlockfileProvider,
) -> Result<Self, SpannIndexWriterConstructionError> {
// Create the HNSW index.
let hnsw_index = match hnsw_id {
Some(hnsw_id) => {
Self::hnsw_index_from_id(
hnsw_provider,
hnsw_id,
collection_id,
distance_function,
dimensionality,
)
.await?
}
None => {
Self::create_hnsw_index(
hnsw_provider,
collection_id,
distance_function,
dimensionality,
hnsw_params.unwrap(), // Safe since caller should always provide this.
)
.await?
}
};
// Load the versions map.
let versions_map = match versions_map_id {
Some(versions_map_id) => {
Self::load_versions_map(versions_map_id, blockfile_provider).await?
}
None => HashMap::new(),
};
// Fork the posting list writer.
let posting_list_writer = match posting_list_id {
Some(posting_list_id) => {
Self::fork_postings_list(posting_list_id, blockfile_provider).await?
}
None => Self::create_posting_list(blockfile_provider).await?,
};
Ok(Self::new(
hnsw_index,
hnsw_provider.clone(),
posting_list_writer,
versions_map,
))
}
}
134 changes: 40 additions & 94 deletions rust/worker/src/segment/spann_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,26 @@ pub enum SpannSegmentWriterError {
DistanceFunctionNotFound,
#[error("Hnsw index id parsing error")]
IndexIdParsingError,
#[error("HNSW index creation error")]
HnswIndexCreationError,
#[error("Hnsw Invalid file path")]
HnswInvalidFilePath,
#[error("Version map Invalid file path")]
VersionMapInvalidFilePath,
#[error("Failure in loading the versions map")]
VersionMapLoadError,
#[error("Failure in forking the posting list")]
PostingListForkError,
#[error("Postings list invalid file path")]
PostingListInvalidFilePath,
#[error("Posting list creation error")]
PostingListCreationError,
#[error("Spann index creation error")]
SpannIndexWriterConstructionError,
}

impl ChromaError for SpannSegmentWriterError {
fn code(&self) -> ErrorCodes {
match self {
Self::InvalidArgument => ErrorCodes::InvalidArgument,
Self::IndexIdParsingError => ErrorCodes::Internal,
Self::HnswIndexCreationError => ErrorCodes::Internal,
Self::DistanceFunctionNotFound => ErrorCodes::Internal,
Self::HnswInvalidFilePath => ErrorCodes::Internal,
Self::VersionMapInvalidFilePath => ErrorCodes::Internal,
Self::VersionMapLoadError => ErrorCodes::Internal,
Self::PostingListForkError => ErrorCodes::Internal,
Self::PostingListInvalidFilePath => ErrorCodes::Internal,
Self::PostingListCreationError => ErrorCodes::Internal,
Self::SpannIndexWriterConstructionError => ErrorCodes::Internal,
}
}
}
Expand All @@ -64,14 +55,20 @@ impl SpannSegmentWriter {
pub async fn from_segment(
segment: &Segment,
blockfile_provider: &BlockfileProvider,
hnsw_provider: HnswIndexProvider,
hnsw_provider: &HnswIndexProvider,
dimensionality: usize,
) -> Result<SpannSegmentWriter, SpannSegmentWriterError> {
if segment.r#type != SegmentType::Spann || segment.scope != SegmentScope::VECTOR {
return Err(SpannSegmentWriterError::InvalidArgument);
}
// Load HNSW index.
let hnsw_index = match segment.file_path.get(HNSW_PATH) {
let distance_function = match distance_function_from_segment(segment) {
Ok(distance_function) => distance_function,
Err(e) => {
return Err(SpannSegmentWriterError::DistanceFunctionNotFound);
}
};
let (hnsw_id, hnsw_params) = match segment.file_path.get(HNSW_PATH) {
Some(hnsw_path) => match hnsw_path.first() {
Some(index_id) => {
let index_uuid = match Uuid::parse_str(index_id) {
Expand All @@ -80,85 +77,33 @@ impl SpannSegmentWriter {
return Err(SpannSegmentWriterError::IndexIdParsingError);
}
};
let distance_function = match distance_function_from_segment(segment) {
Ok(distance_function) => distance_function,
Err(e) => {
return Err(SpannSegmentWriterError::DistanceFunctionNotFound);
}
};
match SpannIndexWriter::hnsw_index_from_id(
&hnsw_provider,
&index_uuid,
&segment.collection,
distance_function,
dimensionality,
)
.await
{
Ok(index) => index,
Err(_) => {
return Err(SpannSegmentWriterError::HnswIndexCreationError);
}
}
(Some(index_uuid), Some(hnsw_params_from_segment(segment)))
}
None => {
return Err(SpannSegmentWriterError::HnswInvalidFilePath);
}
},
// Create a new index.
None => {
let hnsw_params = hnsw_params_from_segment(segment);

let distance_function = match distance_function_from_segment(segment) {
Ok(distance_function) => distance_function,
Err(e) => {
return Err(SpannSegmentWriterError::DistanceFunctionNotFound);
}
};

match SpannIndexWriter::create_hnsw_index(
&hnsw_provider,
&segment.collection,
distance_function,
dimensionality,
hnsw_params,
)
.await
{
Ok(index) => index,
Err(_) => {
return Err(SpannSegmentWriterError::HnswIndexCreationError);
}
}
}
None => (None, None),
};
// Load version map. Empty if file path is not set.
let mut version_map = HashMap::new();
if let Some(version_map_path) = segment.file_path.get(VERSION_MAP_PATH) {
version_map = match version_map_path.first() {
let versions_map_id = match segment.file_path.get(VERSION_MAP_PATH) {
Some(version_map_path) => match version_map_path.first() {
Some(version_map_id) => {
let version_map_uuid = match Uuid::parse_str(version_map_id) {
Ok(uuid) => uuid,
Err(_) => {
return Err(SpannSegmentWriterError::IndexIdParsingError);
}
};
match SpannIndexWriter::load_versions_map(&version_map_uuid, blockfile_provider)
.await
{
Ok(index) => index,
Err(_) => {
return Err(SpannSegmentWriterError::VersionMapLoadError);
}
}
Some(version_map_uuid)
}
None => {
return Err(SpannSegmentWriterError::VersionMapInvalidFilePath);
}
}
}
},
None => None,
};
// Fork the posting list map.
let posting_list_writer = match segment.file_path.get(POSTING_LIST_PATH) {
let posting_list_id = match segment.file_path.get(POSTING_LIST_PATH) {
Some(posting_list_path) => match posting_list_path.first() {
Some(posting_list_id) => {
let posting_list_uuid = match Uuid::parse_str(posting_list_id) {
Expand All @@ -167,33 +112,34 @@ impl SpannSegmentWriter {
return Err(SpannSegmentWriterError::IndexIdParsingError);
}
};
match SpannIndexWriter::fork_postings_list(
&posting_list_uuid,
blockfile_provider,
)
.await
{
Ok(writer) => writer,
Err(_) => {
return Err(SpannSegmentWriterError::PostingListForkError);
}
}
Some(posting_list_uuid)
}
None => {
return Err(SpannSegmentWriterError::PostingListInvalidFilePath);
}
},
// Create a new index.
None => match SpannIndexWriter::create_posting_list(blockfile_provider).await {
Ok(writer) => writer,
Err(_) => {
return Err(SpannSegmentWriterError::PostingListCreationError);
}
},
None => None,
};

let index_writer =
SpannIndexWriter::new(hnsw_index, hnsw_provider, posting_list_writer, version_map);
let index_writer = match SpannIndexWriter::from_id(
hnsw_provider,
hnsw_id.as_ref(),
versions_map_id.as_ref(),
posting_list_id.as_ref(),
hnsw_params,
&segment.collection,
distance_function,
dimensionality,
blockfile_provider,
)
.await
{
Ok(index_writer) => index_writer,
Err(_) => {
return Err(SpannSegmentWriterError::SpannIndexWriterConstructionError);
}
};

Ok(SpannSegmentWriter {
index: index_writer,
Expand Down

0 comments on commit 861d9c0

Please sign in to comment.