Skip to content

Commit

Permalink
Address Vijay offline feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Finn Roblin <[email protected]>
  • Loading branch information
finnroblin committed Aug 15, 2024
1 parent 7d99d80 commit e3821d9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 66 deletions.
8 changes: 8 additions & 0 deletions osbenchmark/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Context(Enum):
INDEX = 1
QUERY = 2
NEIGHBORS = 3
MAX_DISTANCE_NEIGHBORS = 4
MIN_SCORE_NEIGHBORS = 5
ATTRIBUTES = 7


Expand Down Expand Up @@ -142,6 +144,12 @@ def parse_context(context: Context) -> str:
if context == Context.QUERY:
return "test"

if context == Context.MAX_DISTANCE_NEIGHBORS:
return "max_distance_neighbors"

if context == Context.MIN_SCORE_NEIGHBORS:
return "min_score_neighbors"

if context == Context.ATTRIBUTES:
return "attributes"

Expand Down
20 changes: 11 additions & 9 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from osbenchmark import exceptions
from osbenchmark.utils import io
from osbenchmark.utils.dataset import DataSet, get_data_set, Context
from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter, parse_bool_parameter
from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter
from osbenchmark.workload import workload

__PARAM_SOURCES_BY_OP = {}
Expand Down Expand Up @@ -1127,9 +1127,9 @@ def _update_body_params(self, vector):
"[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)

self.logger.info("Here, we have query_params: %s ", self.query_params)
efficient_filter=self.query_params.get(self.PARAMS_NAME_FILTER)
filter_type=self.query_params.get(self.PARAMS_NAME_FILTER_TYPE)
filter_body=self.query_params.get(self.PARAMS_NAME_FILTER_BODY)
efficient_filter = filter_body if filter_type == "efficient" else None

# override query params with vector search query
body_params[self.PARAMS_NAME_QUERY] = self._build_vector_search_query_body(vector, efficient_filter, filter_type, filter_body)
Expand Down Expand Up @@ -1262,7 +1262,7 @@ def __init__(self, workload, params, **kwargs):
self.id_field_name: str = parse_string_parameter(
self.PARAMS_NAME_ID_FIELD_NAME, params, self.DEFAULT_ID_FIELD_NAME
)
self.has_attributes = parse_bool_parameter("has_attributes", params, False)
self.filter_attributes: List[Any] = params.get("filter_attributes", [])

self.action_buffer = None
self.num_nested_vectors = 10
Expand Down Expand Up @@ -1294,7 +1294,7 @@ def partition(self, partition_index, total_partitions):
)
partition.parent_data_set.seek(partition.offset)

if self.has_attributes:
if self.filter_attributes:
partition.attributes_data_set = get_data_set(
self.parent_data_set_format, self.parent_data_set_path, Context.ATTRIBUTES
)
Expand All @@ -1317,8 +1317,10 @@ def bulk_transform_add_attributes(self, partition: np.ndarray, action, attribute
partition.tolist(), attributes.tolist(), range(self.current, self.current + len(partition))
):
row = {self.field_name: vec}
for idx, attribute_name, attribute_type in zip(range(3), ["taste", "color", "age"], [str, str, int]):
row.update({attribute_name : attribute_type(attribute_list[idx])})
for idx, attribute_name in zip(range(len(self.filter_attributes)), self.filter_attributes):
attribute = attribute_list[idx].decode()
if attribute != "None":
row.update({attribute_name : attribute})
if add_id_field_to_body:
row.update({self.id_field_name: identifier})
bulk_contents.append(row)
Expand Down Expand Up @@ -1369,11 +1371,11 @@ def bulk_transform(
An array of transformed vectors in bulk format.
"""

if not self.is_nested and not self.has_attributes:
if not self.is_nested and not self.filter_attributes:
return self.bulk_transform_non_nested(partition, action)

# TODO: Assumption: attributes are never combined with nested vectors, so the attributes path below takes precedence over is_nested.
if self.has_attributes:
if self.filter_attributes:
return self.bulk_transform_add_attributes(partition, action, attributes)
actions = []

Expand Down Expand Up @@ -1457,7 +1459,7 @@ def action(id_field_name, doc_id):
else:
parent_ids = None

if self.has_attributes:
if self.filter_attributes:
attributes = self.attributes_data_set.read(bulk_size)
else:
attributes = None
Expand Down
58 changes: 1 addition & 57 deletions tests/workload/params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2900,62 +2900,6 @@ def test_params_default(self):
with self.assertRaises(StopIteration):
query_param_source_partition.params()

def test_params_custom_body(self):
# Exercises VectorSearchPartitionParamSource when the param-source params
# carry a "filter" entry and the operation params carry a custom "body".
# Every params() result is handed to self._check_params together with the
# field name, dimension, k, the value 100 and filter_body; what that helper
# actually asserts is not visible from this block.
# Create the query data set and the neighbors data set on disk
k = 12
data_set_path = create_data_set(
self.DEFAULT_NUM_VECTORS,
self.DEFAULT_DIMENSION,
self.DEFAULT_TYPE,
Context.QUERY,
self.data_set_dir
)
neighbors_data_set_path = create_data_set(
self.DEFAULT_NUM_VECTORS,
self.DEFAULT_DIMENSION,
self.DEFAULT_TYPE,
Context.NEIGHBORS,
self.data_set_dir
)
# Arbitrary placeholder filter; passed in under the "filter" key below
filter_body = {
"key": "value"
}

# Create a VectorSearchPartitionParamSource with the relevant params
test_param_source_params = {
"field": self.DEFAULT_FIELD_NAME,
"data_set_format": self.DEFAULT_TYPE,
"data_set_path": data_set_path,
"neighbors_data_set_path": neighbors_data_set_path,
"k": k,
"filter": filter_body,
}
# Third argument is the operation-level params: target index, empty
# request-params and a custom body whose "size" is 100
query_param_source = VectorSearchPartitionParamSource(
workload.Workload(name="unit-test"),
test_param_source_params, {
"index": self.DEFAULT_INDEX_NAME,
"request-params": {},
"body": {
"size": 100,
}
}
)
# Request a single partition (called with arguments 0 and 1)
query_param_source_partition = query_param_source.partition(0, 1)

# Check the params produced on each call
# NOTE(review): the loop bound is the bare name DEFAULT_NUM_VECTORS while the
# data sets above use self.DEFAULT_NUM_VECTORS - confirm a module-level name
# exists and that both hold the same value.
for _ in range(DEFAULT_NUM_VECTORS):
self._check_params(
query_param_source_partition.params(),
self.DEFAULT_FIELD_NAME,
self.DEFAULT_DIMENSION,
k,
# same literal as the "size" set in the body above
100,
filter_body,
)

# After the loop, one more params() call must raise StopIteration
with self.assertRaises(StopIteration):
query_param_source_partition.params()
def test_post_filter(self):
# Create a data set
k = 12
Expand Down Expand Up @@ -3434,7 +3378,7 @@ def test_params_efficient_filter(
"data_set_path": data_set_path,
"bulk_size": bulk_size,
"id-field-name": self.DEFAULT_ID_FIELD_NAME,
"has_attributes": True
"filter_attributes": self.ATTRIBUTES_LIST
}
bulk_param_source = BulkVectorsFromDataSetParamSource(
workload.Workload(name="unit-test"), test_param_source_params
Expand Down

0 comments on commit e3821d9

Please sign in to comment.