From e3821d90ac44667f3df507b14c4f556246a1eddf Mon Sep 17 00:00:00 2001 From: Finn Roblin Date: Tue, 13 Aug 2024 14:02:42 -0700 Subject: [PATCH] Address Vijay offline feedback Signed-off-by: Finn Roblin --- osbenchmark/utils/dataset.py | 8 +++++ osbenchmark/workload/params.py | 20 ++++++------ tests/workload/params_test.py | 58 +--------------------------------- 3 files changed, 20 insertions(+), 66 deletions(-) diff --git a/osbenchmark/utils/dataset.py b/osbenchmark/utils/dataset.py index 1a46e3d86..974f62421 100644 --- a/osbenchmark/utils/dataset.py +++ b/osbenchmark/utils/dataset.py @@ -24,6 +24,8 @@ class Context(Enum): INDEX = 1 QUERY = 2 NEIGHBORS = 3 + MAX_DISTANCE_NEIGHBORS = 4 + MIN_SCORE_NEIGHBORS = 5 ATTRIBUTES = 7 @@ -142,6 +144,12 @@ def parse_context(context: Context) -> str: if context == Context.QUERY: return "test" + if context == Context.MAX_DISTANCE_NEIGHBORS: + return "max_distance_neighbors" + + if context == Context.MIN_SCORE_NEIGHBORS: + return "min_score_neighbors" + if context == Context.ATTRIBUTES: return "attributes" diff --git a/osbenchmark/workload/params.py b/osbenchmark/workload/params.py index b511c0023..c5341a46a 100644 --- a/osbenchmark/workload/params.py +++ b/osbenchmark/workload/params.py @@ -40,7 +40,7 @@ from osbenchmark import exceptions from osbenchmark.utils import io from osbenchmark.utils.dataset import DataSet, get_data_set, Context -from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter, parse_bool_parameter +from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter from osbenchmark.workload import workload __PARAM_SOURCES_BY_OP = {} @@ -1127,9 +1127,9 @@ def _update_body_params(self, vector): "[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY) self.logger.info("Here, we have query_params: %s ", self.query_params) - efficient_filter=self.query_params.get(self.PARAMS_NAME_FILTER) filter_type=self.query_params.get(self.PARAMS_NAME_FILTER_TYPE) filter_body=self.query_params.get(self.PARAMS_NAME_FILTER_BODY) + efficient_filter = filter_body if filter_type == "efficient" else None # override query params with vector search query body_params[self.PARAMS_NAME_QUERY] = self._build_vector_search_query_body(vector, efficient_filter, filter_type, filter_body) @@ -1262,7 +1262,7 @@ def __init__(self, workload, params, **kwargs): self.id_field_name: str = parse_string_parameter( self.PARAMS_NAME_ID_FIELD_NAME, params, self.DEFAULT_ID_FIELD_NAME ) - self.has_attributes = parse_bool_parameter("has_attributes", params, False) + self.filter_attributes: List[Any] = params.get("filter_attributes", []) self.action_buffer = None self.num_nested_vectors = 10 @@ -1294,7 +1294,7 @@ def partition(self, partition_index, total_partitions): ) partition.parent_data_set.seek(partition.offset) - if self.has_attributes: + if self.filter_attributes: partition.attributes_data_set = get_data_set( self.parent_data_set_format, self.parent_data_set_path, Context.ATTRIBUTES ) @@ -1317,8 +1317,10 @@ def bulk_transform_add_attributes(self, partition: np.ndarray, action, attribute partition.tolist(), attributes.tolist(), range(self.current, self.current + len(partition)) ): row = {self.field_name: vec} - for idx, attribute_name, attribute_type in zip(range(3), ["taste", "color", "age"], [str, str, int]): - row.update({attribute_name : attribute_type(attribute_list[idx])}) + for idx, attribute_name in zip(range(len(self.filter_attributes)), self.filter_attributes): + attribute = attribute_list[idx].decode() + if attribute != "None": + row.update({attribute_name : attribute}) if add_id_field_to_body: row.update({self.id_field_name: identifier}) bulk_contents.append(row) @@ -1369,11 +1371,11 @@ def bulk_transform( An array of transformed vectors in bulk format. """ - if not self.is_nested and not self.has_attributes: + if not self.is_nested and not self.filter_attributes: return self.bulk_transform_non_nested(partition, action) # TODO: Assumption: we won't add attributes if we're also doing a nested query. - if self.has_attributes: + if self.filter_attributes: return self.bulk_transform_add_attributes(partition, action, attributes) actions = [] @@ -1457,7 +1459,7 @@ def action(id_field_name, doc_id): else: parent_ids = None - if self.has_attributes: + if self.filter_attributes: attributes = self.attributes_data_set.read(bulk_size) else: attributes = None diff --git a/tests/workload/params_test.py b/tests/workload/params_test.py index 43b028504..5e91d475d 100644 --- a/tests/workload/params_test.py +++ b/tests/workload/params_test.py @@ -2900,62 +2900,6 @@ def test_params_default(self): with self.assertRaises(StopIteration): query_param_source_partition.params() - def test_params_custom_body(self): - # Create a data set - k = 12 - data_set_path = create_data_set( - self.DEFAULT_NUM_VECTORS, - self.DEFAULT_DIMENSION, - self.DEFAULT_TYPE, - Context.QUERY, - self.data_set_dir - ) - neighbors_data_set_path = create_data_set( - self.DEFAULT_NUM_VECTORS, - self.DEFAULT_DIMENSION, - self.DEFAULT_TYPE, - Context.NEIGHBORS, - self.data_set_dir - ) - filter_body = { - "key": "value" - } - - # Create a QueryVectorsFromDataSetParamSource with relevant params - test_param_source_params = { - "field": self.DEFAULT_FIELD_NAME, - "data_set_format": self.DEFAULT_TYPE, - "data_set_path": data_set_path, - "neighbors_data_set_path": neighbors_data_set_path, - "k": k, - "filter": filter_body, - } - query_param_source = VectorSearchPartitionParamSource( - workload.Workload(name="unit-test"), - test_param_source_params, { - "index": self.DEFAULT_INDEX_NAME, - "request-params": {}, - "body": { - "size": 100, - } - } - ) - query_param_source_partition = query_param_source.partition(0, 1) - - # Check each - for _ in range(DEFAULT_NUM_VECTORS): - self._check_params( - query_param_source_partition.params(), - self.DEFAULT_FIELD_NAME, - self.DEFAULT_DIMENSION, - k, - 100, - filter_body, - ) - - # Assert last call creates stop iteration - with self.assertRaises(StopIteration): - query_param_source_partition.params() def test_post_filter(self): # Create a data set k = 12 @@ -3434,7 +3378,7 @@ def test_params_efficient_filter( "data_set_path": data_set_path, "bulk_size": bulk_size, "id-field-name": self.DEFAULT_ID_FIELD_NAME, - "has_attributes": True + "filter_attributes": self.ATTRIBUTES_LIST } bulk_param_source = BulkVectorsFromDataSetParamSource( workload.Workload(name="unit-test"), test_param_source_params