Skip to content

Commit

Permalink
Add _validate_query_type_parameters function
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jul 23, 2024
1 parent 2f3c9e2 commit b35e63a
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,7 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource):
MIN_SCORE_QUERY_TYPE = "min_score"
MAX_DISTANCE_QUERY_TYPE = "max_distance"
KNN_QUERY_TYPE = "knn"
RADIAL_SEARCH_QUERY_RESULT_SIZE = 10000

def __init__(self, workloads, params, query_params, **kwargs):
super().__init__(workloads, params, Context.QUERY, **kwargs)
Expand All @@ -1062,6 +1063,7 @@ def __init__(self, workloads, params, query_params, **kwargs):
if self.PARAMS_NAME_MIN_SCORE in params:
self.score = parse_float_parameter(self.PARAMS_NAME_MIN_SCORE, params)
self.query_type = self.MIN_SCORE_QUERY_TYPE
self._validate_query_type_parameters()
self.logger.info("query type is set up to %s", self.query_type)
self.repetitions = parse_int_parameter(self.PARAMS_NAME_REPETITIONS, params, 1)
self.current_rep = 1
Expand Down Expand Up @@ -1100,6 +1102,11 @@ def __init__(self, workloads, params, query_params, **kwargs):
neighbors_corpora = self.extract_corpora(self.neighbors_data_set_corpus, self.neighbors_data_set_format)
self.corpora.extend(corpora for corpora in neighbors_corpora if corpora not in self.corpora)

def _validate_query_type_parameters(self):
    """Reject configurations that set more than one query-type parameter.

    Reads ``self.k``, ``self.distance`` and ``self.score``; a value of
    ``None`` is treated as "not specified". Zero or one specified value
    passes silently.

    Raises:
        ValueError: if more than one of the three attributes is not ``None``.
    """
    candidates = (self.k, self.distance, self.score)
    specified = [value for value in candidates if value is not None]
    # Guard clause: at most one selector is an acceptable configuration.
    if len(specified) <= 1:
        return
    raise ValueError("Only one of k, max_distance, or min_score can be specified in vector search.")

@staticmethod
def _validate_neighbors_data_set(file_path, corpus):
if file_path and corpus:
Expand Down Expand Up @@ -1131,8 +1138,7 @@ def _update_body_params(self, vector):
if self.query_type == self.KNN_QUERY_TYPE:
body_params[self.PARAMS_NAME_SIZE] = self.k
else:
# if distance is set, set size to 10000, which is the maximum number results returned by default
body_params[self.PARAMS_NAME_SIZE] = 10000
body_params[self.PARAMS_NAME_SIZE] = self.RADIAL_SEARCH_QUERY_RESULT_SIZE
if self.PARAMS_NAME_QUERY in body_params:
self.logger.warning(
"[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)
Expand Down

0 comments on commit b35e63a

Please sign in to comment.