From 955d3389c44e1c7598daa6a479f2fbed8bacc508 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Fri, 2 Jun 2023 08:34:20 -0700 Subject: [PATCH] Added support for Efficient Pre-filtering for Faiss Engine. The changes include * Enabled the efficient filtering support for Faiss Engine (#907) * Fixed the ef_search default value for faiss HNSW with filters and updated the perf-tool to include Faiss HNSW tests (#926) * Added exact search for cases when filteredIds < k to improve the recall for exact search (#928) * Improved Exact Search to return only K results and added client side latency metric for query Benchmarks (#933) * Added Integration Tests and Unit test for Efficient Filtering for Faiss Engine (#934) Signed-off-by: Navneet Verma --- CHANGELOG.md | 3 +- DEVELOPER_GUIDE.md | 4 +- benchmarks/perf-tool/README.md | 47 +++-- .../perf-tool/okpt/io/config/parsers/test.py | 5 + .../perf-tool/okpt/io/config/schemas/test.yml | 3 + benchmarks/perf-tool/okpt/test/steps/steps.py | 22 ++- .../filtering/relaxed-filter/index.json | 26 +++ .../relaxed-filter/relaxed-filter-spec.json | 42 ++++ .../relaxed-filter/relaxed-filter-test.yml | 34 ++++ .../filtering/restrictive-filter/index.json | 26 +++ .../restrictive-filter-spec.json | 44 +++++ .../restrictive-filter-test.yml | 37 ++++ .../release-configs/faiss-hnsw/index.json | 26 +++ .../release-configs/faiss-hnsw/test.yml | 32 +++ .../relaxed-filter/relaxed-filter-test.yml | 4 +- .../restrictive-filter-test.yml | 17 +- jni/CMakeLists.txt | 4 +- jni/external/faiss | 2 +- jni/include/faiss_wrapper.h | 6 + .../org_opensearch_knn_jni_FaissService.h | 8 + jni/src/faiss_wrapper.cpp | 165 +++++++++++++++- .../org_opensearch_knn_jni_FaissService.cpp | 12 ++ .../opensearch/knn/index/query/KNNQuery.java | 33 ++++ .../knn/index/query/KNNQueryBuilder.java | 2 +- .../knn/index/query/KNNQueryFactory.java | 36 +++- .../opensearch/knn/index/query/KNNScorer.java | 33 ++++ .../opensearch/knn/index/query/KNNWeight.java | 185 ++++++++++++++++-- 
.../opensearch/knn/index/util/KNNEngine.java | 5 + .../opensearch/knn/index/util/KNNLibrary.java | 2 +- .../org/opensearch/knn/jni/FaissService.java | 4 +- .../org/opensearch/knn/jni/JNIService.java | 18 +- .../knn/plugin/rest/RestGetModelHandler.java | 4 +- .../knn/plugin/rest/RestKNNStatsHandler.java | 6 +- .../knn/plugin/rest/RestKNNWarmupHandler.java | 4 +- .../plugin/transport/DeleteModelRequest.java | 3 +- .../TrainingJobRouterTransportAction.java | 3 +- .../opensearch/knn/training/TrainingJob.java | 3 +- .../org/opensearch/knn/index/FaissIT.java | 94 +++++++++ .../opensearch/knn/index/LuceneEngineIT.java | 19 -- .../knn/index/codec/KNNCodecTestUtil.java | 2 +- .../memory/NativeMemoryLoadStrategyTests.java | 2 +- .../knn/index/query/KNNQueryFactoryTests.java | 35 +++- .../knn/index/query/KNNWeightTests.java | 168 +++++++++++++++- .../opensearch/knn/jni/JNIServiceTests.java | 20 +- .../org/opensearch/knn/KNNRestTestCase.java | 16 ++ 45 files changed, 1146 insertions(+), 120 deletions(-) create mode 100644 benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/index.json create mode 100644 benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/relaxed-filter-spec.json create mode 100644 benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/relaxed-filter-test.yml create mode 100644 benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/index.json create mode 100644 benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/restrictive-filter-spec.json create mode 100644 benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/restrictive-filter-test.yml create mode 100644 benchmarks/perf-tool/release-configs/faiss-hnsw/index.json create mode 100644 benchmarks/perf-tool/release-configs/faiss-hnsw/test.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index ea5cc8e8e0..e57e634be6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,9 +15,10 @@ The format 
is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.8...2.x) ### Features +* Added efficient filtering support for Faiss Engine ([#936](https://github.com/opensearch-project/k-NN/pull/936)) ### Enhancements ### Bug Fixes ### Infrastructure ### Documentation ### Maintenance -### Refactoring \ No newline at end of file +### Refactoring diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 4a6f360b51..d8ef9e4130 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -56,11 +56,11 @@ In addition to this, the plugin has been tested with JDK 17, and this JDK versio #### CMake -The plugin requires that cmake >= 3.17.2 is installed in order to build the JNI libraries. +The plugin requires that cmake >= 3.23.1 is installed in order to build the JNI libraries. One easy way to install on mac or linux is to use pip: ```bash -pip install cmake==3.17.2 +pip install cmake==3.23.1 ``` #### Faiss Dependencies diff --git a/benchmarks/perf-tool/README.md b/benchmarks/perf-tool/README.md index 9c1c189182..f98227e27f 100644 --- a/benchmarks/perf-tool/README.md +++ b/benchmarks/perf-tool/README.md @@ -13,18 +13,36 @@ file. ## Install Prerequisites -### Python +### Setup -Python 3.7 or above is required. +K-NN perf requires Python 3.8 or greater to be installed. One of +the easier ways to do this is through Conda, a package and environment +management system for Python. -### Pip +First, follow the +[installation instructions](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) +to install Conda on your system. 
-Use pip to install the necessary requirements: +Next, create a Python 3.8 environment: +``` +conda create -n knn-perf python=3.8 +``` + +After the environment is created, activate it: +``` +source activate knn-perf +``` +Lastly, clone the k-NN repo and install all required python packages: ``` +git clone https://github.com/opensearch-project/k-NN.git +cd k-NN/benchmarks/perf-tool pip install -r requirements.txt ``` +After all of this completes, you should be ready to run your first performance benchmarks! + + ## Usage ### Quick Start @@ -72,16 +90,17 @@ The output will be the delta between the two metrics. ### Test Parameters -| Parameter Name | Description | Default | -| ----------- | ----------- | ----------- | -| endpoint | Endpoint OpenSearch cluster is running on | localhost | -| test_name | Name of test | No default | -| test_id | String ID of test | No default | -| num_runs | Number of runs to execute steps | 1 | -| show_runs | Whether to output each run in addition to the total summary | false | -| setup | List of steps to run once before metric collection starts | [] | -| steps | List of steps that make up one test run. Metrics will be collected on these steps. | No default | -| cleanup | List of steps to run after each test run | [] | +| Parameter Name | Description | Default | +|----------------|------------------------------------------------------------------------------------|------------| +| endpoint | Endpoint OpenSearch cluster is running on | localhost | +| port | Port on which OpenSearch Cluster is running on | 9200 | +| test_name | Name of test | No default | +| test_id | String ID of test | No default | +| num_runs | Number of runs to execute steps | 1 | +| show_runs | Whether to output each run in addition to the total summary | false | +| setup | List of steps to run once before metric collection starts | [] | +| steps | List of steps that make up one test run. Metrics will be collected on these steps. 
| No default | +| cleanup | List of steps to run after each test run | [] | ### Steps diff --git a/benchmarks/perf-tool/okpt/io/config/parsers/test.py b/benchmarks/perf-tool/okpt/io/config/parsers/test.py index 34b1752c72..d0ef4c02fe 100644 --- a/benchmarks/perf-tool/okpt/io/config/parsers/test.py +++ b/benchmarks/perf-tool/okpt/io/config/parsers/test.py @@ -23,6 +23,7 @@ class TestConfig: test_name: str test_id: str endpoint: str + port: int num_runs: int show_runs: bool setup: List[Step] @@ -48,6 +49,9 @@ def parse(self, file_obj: TextIOWrapper) -> TestConfig: if 'endpoint' in config_obj: implicit_step_config['endpoint'] = config_obj['endpoint'] + if 'port' in config_obj: + implicit_step_config['port'] = config_obj['port'] + # Each step should have its own parse - take the config object and check if its valid setup = [] if 'setup' in config_obj: @@ -62,6 +66,7 @@ def parse(self, file_obj: TextIOWrapper) -> TestConfig: test_config = TestConfig( endpoint=config_obj['endpoint'], + port=config_obj['port'], test_name=config_obj['test_name'], test_id=config_obj['test_id'], num_runs=config_obj['num_runs'], diff --git a/benchmarks/perf-tool/okpt/io/config/schemas/test.yml b/benchmarks/perf-tool/okpt/io/config/schemas/test.yml index 1939a8a311..06b880cc75 100644 --- a/benchmarks/perf-tool/okpt/io/config/schemas/test.yml +++ b/benchmarks/perf-tool/okpt/io/config/schemas/test.yml @@ -9,6 +9,9 @@ endpoint: type: string default: "localhost" +port: + type: integer + default: 9200 test_name: type: string test_id: diff --git a/benchmarks/perf-tool/okpt/test/steps/steps.py b/benchmarks/perf-tool/okpt/test/steps/steps.py index 0de61078fc..cc1773330b 100644 --- a/benchmarks/perf-tool/okpt/test/steps/steps.py +++ b/benchmarks/perf-tool/okpt/test/steps/steps.py @@ -5,7 +5,7 @@ # compatible open source license. """Provides steps for OpenSearch tests. 
-Some of the OpenSearch operations return a `took` field in the response body, +Some OpenSearch operations return a `took` field in the response body, so the profiling decorators aren't needed for some functions. """ import json @@ -454,8 +454,10 @@ def _action(self): results['took'] = [ float(query_response['took']) for query_response in query_responses ] - port = 9200 if self.endpoint == 'localhost' else 80 - results['memory_kb'] = get_cache_size_in_kb(self.endpoint, port) + results['client_time'] = [ + float(query_response['client_time']) for query_response in query_responses + ] + results['memory_kb'] = get_cache_size_in_kb(self.endpoint, self.port) if self.calculate_recall: ids = [[int(hit['_id']) @@ -473,7 +475,7 @@ def _action(self): return results def _get_measures(self) -> List[str]: - measures = ['took', 'memory_kb'] + measures = ['took', 'memory_kb', 'client_time'] if self.calculate_recall: measures.extend(['recall@K', f'recall@{str(self.r)}']) @@ -614,7 +616,6 @@ def _action(self): num_of_search_segments = 0; for shard_key in shards.keys(): for segment in shards[shard_key]: - num_of_committed_segments += segment["num_committed_segments"] num_of_search_segments += segment["num_search_segments"] @@ -689,12 +690,13 @@ def delete_model(endpoint, port, model_id): return response.json() -def get_opensearch_client(endpoint: str, port: int): +def get_opensearch_client(endpoint: str, port: int, timeout=60): """ Get an opensearch client from an endpoint and port Args: endpoint: Endpoint OpenSearch is running on port: Port OpenSearch is running on + timeout: timeout for OpenSearch client, default value 60 Returns: OpenSearch client @@ -708,7 +710,7 @@ def get_opensearch_client(endpoint: str, port: int): use_ssl=False, verify_certs=False, connection_class=RequestsHttpConnection, - timeout=60, + timeout=timeout, ) @@ -784,9 +786,13 @@ def get_cache_size_in_kb(endpoint, port): def query_index(opensearch: OpenSearch, index_name: str, body: dict, excluded_fields: 
list): - return opensearch.search(index=index_name, + start_time = round(time.time()*1000) + queryResponse = opensearch.search(index=index_name, body=body, _source_excludes=excluded_fields) + end_time = round(time.time() * 1000) + queryResponse['client_time'] = end_time - start_time + return queryResponse def bulk_index(opensearch: OpenSearch, index_name: str, body: List): diff --git a/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/index.json b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/index.json new file mode 100644 index 0000000000..b8f591176c --- /dev/null +++ b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/index.json @@ -0,0 +1,26 @@ +{ + "settings": { + "index": { + "knn": true, + "number_of_shards": 24, + "number_of_replicas": 1 + } + }, + "mappings": { + "properties": { + "target_field": { + "type": "knn_vector", + "dimension": 128, + "method": { + "name": "hnsw", + "space_type": "l2", + "engine": "faiss", + "parameters": { + "ef_construction": 256, + "m": 16 + } + } + } + } + } +} diff --git a/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/relaxed-filter-spec.json b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/relaxed-filter-spec.json new file mode 100644 index 0000000000..fecde03928 --- /dev/null +++ b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/relaxed-filter-spec.json @@ -0,0 +1,42 @@ +{ + "bool": + { + "should": + [ + { + "range": + { + "age": + { + "gte": 30, + "lte": 70 + } + } + }, + { + "term": + { + "color": "green" + } + }, + { + "term": + { + "color": "blue" + } + }, + { + "term": + { + "color": "yellow" + } + }, + { + "term": + { + "color": "sweet" + } + } + ] + } +} \ No newline at end of file diff --git a/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/relaxed-filter-test.yml 
b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/relaxed-filter-test.yml new file mode 100644 index 0000000000..61486b3b60 --- /dev/null +++ b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/relaxed-filter/relaxed-filter-test.yml @@ -0,0 +1,34 @@ +endpoint: [ENDPOINT] +test_name: "Faiss HNSW Relaxed Filter Test" +test_id: "Faiss HNSW Relaxed Filter Test" +num_runs: 10 +show_runs: false +steps: + - name: delete_index + index_name: target_index + - name: create_index + index_name: target_index + index_spec: [INDEX_SPEC_PATH]/relaxed-filter/index.json + - name: ingest_multi_field + index_name: target_index + field_name: target_field + bulk_size: 500 + dataset_format: hdf5 + dataset_path: [DATASET_PATH]/sift-128-euclidean-with-attr.hdf5 + attributes_dataset_name: attributes + attribute_spec: [ { name: 'color', type: 'str' }, { name: 'taste', type: 'str' }, { name: 'age', type: 'int' } ] + - name: refresh_index + index_name: target_index + - name: query_with_filter + k: 100 + r: 1 + calculate_recall: true + index_name: target_index + field_name: target_field + dataset_format: hdf5 + dataset_path: [DATASET_PATH]/sift-128-euclidean-with-attr.hdf5 + neighbors_format: hdf5 + neighbors_path: [DATASET_PATH]/sift-128-euclidean-with-filters.hdf5 + neighbors_dataset: neighbors_filter_5 + filter_spec: [INDEX_SPEC_PATH]/relaxed-filter-spec.json + filter_type: FILTER diff --git a/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/index.json b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/index.json new file mode 100644 index 0000000000..b8f591176c --- /dev/null +++ b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/index.json @@ -0,0 +1,26 @@ +{ + "settings": { + "index": { + "knn": true, + "number_of_shards": 24, + "number_of_replicas": 1 + } + }, + "mappings": { + "properties": { + "target_field": { + "type": "knn_vector", + "dimension": 128, + "method": { + 
"name": "hnsw", + "space_type": "l2", + "engine": "faiss", + "parameters": { + "ef_construction": 256, + "m": 16 + } + } + } + } + } +} diff --git a/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/restrictive-filter-spec.json b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/restrictive-filter-spec.json new file mode 100644 index 0000000000..9e6356f1c7 --- /dev/null +++ b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/restrictive-filter-spec.json @@ -0,0 +1,44 @@ +{ + "bool": + { + "must": + [ + { + "range": + { + "age": + { + "gte": 30, + "lte": 60 + } + } + }, + { + "term": + { + "taste": "bitter" + } + }, + { + "bool": + { + "should": + [ + { + "term": + { + "color": "blue" + } + }, + { + "term": + { + "color": "green" + } + } + ] + } + } + ] + } +} \ No newline at end of file diff --git a/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/restrictive-filter-test.yml b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/restrictive-filter-test.yml new file mode 100644 index 0000000000..bf02144ac5 --- /dev/null +++ b/benchmarks/perf-tool/release-configs/faiss-hnsw/filtering/restrictive-filter/restrictive-filter-test.yml @@ -0,0 +1,37 @@ +endpoint: [ENDPOINT] +test_name: "Faiss HNSW Restrictive Filter Test" +test_id: "Faiss HNSW Restrictive Filter Test" +num_runs: 10 +show_runs: false +steps: + - name: delete_index + index_name: target_index + - name: create_index + index_name: target_index + index_spec: [INDEX_SPEC_PATH]/index.json + - name: ingest_multi_field + index_name: target_index + field_name: target_field + bulk_size: 500 + dataset_format: hdf5 + dataset_path: [DATASET_PATH]/sift-128-euclidean-with-attr.hdf5 + attributes_dataset_name: attributes + attribute_spec: [ { name: 'color', type: 'str' }, { name: 'taste', type: 'str' }, { name: 'age', type: 'int' } ] + - name: refresh_index + index_name: target_index + - name: 
force_merge + index_name: target_index + max_num_segments: 1 + - name: query_with_filter + k: 100 + r: 1 + calculate_recall: true + index_name: target_index + field_name: target_field + dataset_format: hdf5 + dataset_path: [DATASET_PATH]/sift-128-euclidean-with-attr.hdf5 + neighbors_format: hdf5 + neighbors_path: [DATASET_PATH]/sift-128-euclidean-with-filters.hdf5 + neighbors_dataset: neighbors_filter_4 + filter_spec: [INDEX_SPEC_PATH]/restrictive-filter-spec.json + filter_type: FILTER diff --git a/benchmarks/perf-tool/release-configs/faiss-hnsw/index.json b/benchmarks/perf-tool/release-configs/faiss-hnsw/index.json new file mode 100644 index 0000000000..b8f591176c --- /dev/null +++ b/benchmarks/perf-tool/release-configs/faiss-hnsw/index.json @@ -0,0 +1,26 @@ +{ + "settings": { + "index": { + "knn": true, + "number_of_shards": 24, + "number_of_replicas": 1 + } + }, + "mappings": { + "properties": { + "target_field": { + "type": "knn_vector", + "dimension": 128, + "method": { + "name": "hnsw", + "space_type": "l2", + "engine": "faiss", + "parameters": { + "ef_construction": 256, + "m": 16 + } + } + } + } + } +} diff --git a/benchmarks/perf-tool/release-configs/faiss-hnsw/test.yml b/benchmarks/perf-tool/release-configs/faiss-hnsw/test.yml new file mode 100644 index 0000000000..f3e976cf3c --- /dev/null +++ b/benchmarks/perf-tool/release-configs/faiss-hnsw/test.yml @@ -0,0 +1,32 @@ +endpoint: localhost +test_name: "Faiss HNSW Test" +test_id: "Faiss HNSW Test" +num_runs: 10 +show_runs: false +steps: + - name: delete_index + index_name: target_index + - name: create_index + index_name: target_index + index_spec: /home/ec2-user/[PATH]/index.json + - name: ingest + index_name: target_index + field_name: target_field + bulk_size: 500 + dataset_format: hdf5 + dataset_path: [DATASET_PATH]/sift-128-euclidean.hdf5 + - name: refresh_index + index_name: target_index + - name: force_merge + index_name: target_index + max_num_segments: 1 + - name: query + k: 100 + r: 1 + 
calculate_recall: true + index_name: target_index + field_name: target_field + dataset_format: hdf5 + dataset_path: [DATASET_PATH]/sift-128-euclidean.hdf5 + neighbors_format: hdf5 + neighbors_path: [DATASET_PATH]/sift-128-euclidean.hdf5 diff --git a/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/relaxed-filter/relaxed-filter-test.yml b/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/relaxed-filter/relaxed-filter-test.yml index f20fba2031..44ed8e66e2 100644 --- a/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/relaxed-filter/relaxed-filter-test.yml +++ b/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/relaxed-filter/relaxed-filter-test.yml @@ -1,6 +1,6 @@ endpoint: [ENDPOINT] -test_name: "index-workflow" -test_id: "Index workflow" +test_name: "Lucene HNSW Relaxed Filter Test" +test_id: "Lucene HNSW Relaxed Filter Test" num_runs: 10 show_runs: false steps: diff --git a/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/restrictive-filter/restrictive-filter-test.yml b/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/restrictive-filter/restrictive-filter-test.yml index b1d7b60d7b..d7f451a48e 100644 --- a/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/restrictive-filter/restrictive-filter-test.yml +++ b/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/restrictive-filter/restrictive-filter-test.yml @@ -1,6 +1,6 @@ endpoint: [ENDPOINT] -test_name: "index-workflow" -test_id: "Index workflow" +test_name: "Lucene HNSW Restrictive Filter Test" +test_id: "Lucene HNSW Restrictive Filter Test" num_runs: 10 show_runs: false steps: @@ -8,17 +8,20 @@ steps: index_name: target_index - name: create_index index_name: target_index - index_spec: [INDEX_SPEC_PATH]/index.json + index_spec: /home/ec2-user/k-NN/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/restrictive-filter/index.json - name: ingest_multi_field index_name: target_index field_name: target_field bulk_size: 500 dataset_format: 
hdf5 - dataset_path: [DATASET_PATH]/sift-128-euclidean-with-attr.hdf5 + dataset_path: /home/ec2-user/k-NN/benchmarks/perf-tool/dataset/sift-128-euclidean-with-attr.hdf5 attributes_dataset_name: attributes attribute_spec: [ { name: 'color', type: 'str' }, { name: 'taste', type: 'str' }, { name: 'age', type: 'int' } ] - name: refresh_index index_name: target_index + - name: force_merge + index_name: target_index + max_num_segments: 1 - name: query_with_filter k: 100 r: 1 @@ -26,9 +29,9 @@ steps: index_name: target_index field_name: target_field dataset_format: hdf5 - dataset_path: [DATASET_PATH]/sift-128-euclidean-with-attr.hdf5 + dataset_path: /home/ec2-user/k-NN/benchmarks/perf-tool/dataset/sift-128-euclidean-with-attr.hdf5 neighbors_format: hdf5 - neighbors_path: [DATASET_PATH]/sift-128-euclidean-with-filters.hdf5 + neighbors_path: /home/ec2-user/k-NN/benchmarks/perf-tool/dataset/sift-128-euclidean-with-filters.hdf5 neighbors_dataset: neighbors_filter_4 - filter_spec: [INDEX_SPEC_PATH]/restrictive-filter-spec.json + filter_spec: /home/ec2-user/k-NN/benchmarks/perf-tool/release-configs/lucene-hnsw/filtering/restrictive-filter/restrictive-filter-spec.json filter_type: FILTER diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 668ce684d9..29a844ee07 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -cmake_minimum_required(VERSION 3.17) +cmake_minimum_required(VERSION 3.23.1) project(KNNPlugin_JNI) @@ -95,7 +95,7 @@ if (${CONFIG_NMSLIB} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES SUFFIX ${LIB_EXT}) set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES POSITION_INDEPENDENT_CODE ON) - if (WIN32) + if (NOT "${WIN32}" STREQUAL "") # Use RUNTIME_OUTPUT_DIRECTORY, to build the target library (opensearchknn_nmslib) in the specified directory at runtime. 
set_target_properties(${TARGET_LIB_NMSLIB} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/release) else() diff --git a/jni/external/faiss b/jni/external/faiss index 88eabe97f9..3219e3d12e 160000 --- a/jni/external/faiss +++ b/jni/external/faiss @@ -1 +1 @@ -Subproject commit 88eabe97f96d0c0964dfa075f74373c64d46da80 +Subproject commit 3219e3d12e6fc36dfdfe17d4cf238ef70bf89568 diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 6c8a861435..284214631f 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -40,6 +40,12 @@ namespace knn_jni { jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ); + // Execute a query against the index located in memory at indexPointerJ along with Filters + // + // Return an array of KNNQueryResults + jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ); + // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 1ab6c56817..a252643355 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -50,6 +50,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex (JNIEnv *, jclass, jlong, jfloatArray, jint); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: queryIndex_WithFilter + * Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult; + */ +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter + (JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray); + /* * Class: org_opensearch_knn_jni_FaissService * Method: free 
diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index e0fcc822b9..a1bbb96352 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -18,12 +18,18 @@ #include "faiss/IndexHNSW.h" #include "faiss/IndexIVFFlat.h" #include "faiss/MetaIndexes.h" +#include "faiss/Index.h" +#include "faiss/impl/IDSelector.h" #include #include #include #include +// Defines type of IDSelector +enum FilterIdsSelectorType{ + BITMAP, BATCH +}; // Translate space type to faiss metric faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType); @@ -33,7 +39,19 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, const std::unordered_map& parametersCpp, faiss::Index * index); // Train an index with data provided -void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x); +void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); + +// Create the SearchParams based on the Index Type +std::unique_ptr buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector); + +// Helps to choose the right FilterIdsSelectorType for Faiss +FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength); + +// Converts the int FilterIds to Faiss ids type array. 
+void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); + +// Concerts the FilterIds to BitMap +void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector); void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) { @@ -181,12 +199,17 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ) { + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr); +} + +jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); } - auto *indexReader = reinterpret_cast(indexPointerJ); + auto *indexReader = reinterpret_cast(indexPointerJ); if (indexReader == nullptr) { throw std::runtime_error("Invalid pointer to index"); @@ -195,14 +218,49 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU // The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from // the query point std::vector dis(kJ); - std::vector ids(kJ); + std::vector ids(kJ); float* rawQueryvector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr); - - try { - indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data()); - } catch (...) 
{ - jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); - throw; + // create the filterSearch params if the filterIdsJ is not a null pointer + if(filterIdsJ != nullptr) { + int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr); + int filterIdsLength = env->GetArrayLength(filterIdsJ); + std::unique_ptr idSelector; + FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength); + // start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data + // during the returns. We could have used pass by reference, but we choose pointers. Returning reference to local + // vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers. + // To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/ + std::vector convertedIds; + std::vector bitmap; + // Choose a selector which suits best + if(idSelectorType == BATCH) { + convertedIds.resize(filterIdsLength); + convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data()); + idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data())); + } else { + int maxIdValue = filteredIdsArray[filterIdsLength - 1]; + // >> 3 is equivalent to value / 8 + const int bitsetArraySize = (maxIdValue >> 3) + 1; + bitmap.resize(bitsetArraySize, 0); + buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data()); + idSelector.reset(new faiss::IDSelectorBitmap(filterIdsLength, bitmap.data())); + } + std::unique_ptr searchParameters = buildSearchParams(indexReader, idSelector.get()); + try { + indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters.get()); + } catch (...) 
{ + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + throw; + } + jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + } else { + try { + indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data()); + } catch (...) { + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + throw; + } } jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); @@ -344,7 +402,7 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, } } -void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x) { +void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { if (auto * indexIvf = dynamic_cast(index)) { if (indexIvf->quantizer_trains_alone == 2) { InternalTrainIndex(indexIvf->quantizer, n, x); @@ -356,3 +414,90 @@ void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float index->train(n, x); } } + +/** + * This function takes a call on what ID Selector to use: + * https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap + * + * class storage lookup construction(Opensearch + Faiss) + * IDSelectorArray O(k) O(k) O(2k) + * IDSelectorBatch O(k) O(1) O(2k) + * IDSelectorBitmap O(n/8) O(1) O(k) -> n is the max value of id in the index + * + * TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts: + * an OpenSearch Index can have max segment size as 5GB which, which on a vector with dimension of 128 boils down to + * 7.5M vectors. 
+ * Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation + * M = 16 + * Dimension = 128 + * (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB + * Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be + * 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease. + * + * With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids + * So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size + * as factor. We need to improve on this. + * + * TODO: Best way is to implement a SparseBitSet in C++. This can be done by extending the IDSelector Interface of Faiss. + * + * @param filterIds + * @param filterIdsLength + * @return std::string + */ +FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength) { + int maxIdValue = filterIds[filterIdsLength - 1]; + if(filterIdsLength * sizeof(faiss::idx_t) * 8 <= maxIdValue ) { + return BATCH; + } + return BITMAP; +} + +void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) { + for (int i = 0; i < filterIdsLength; i++) { + convertedFilterIds[i] = filterIds[i]; + } +} + +void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) { + /** + * Coming from Faiss IDSelectorBitmap::is_member function bitmap id will be selected + * iff id / 8 < n and bit number (i%8) of bitmap[floor(i / 8)] is 1. + */ + for(int i = 0 ; i < filterIdsLength ; i ++) { + int value = filterIds[i]; + // / , % are expensive operation. Hence, using BitShift operation as they are fast. 
+ int bitsetArrayIndex = value >> 3 ; // is equivalent to value / 8 + // (value & 7) equivalent to value % 8 + bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7)); + } +} + +/** + * Based on the type of the index reader we need to return the SearchParameters. The way we do this by dynamically + * casting the IndexReader. + * @param indexReader + * @param idSelector + * @return SearchParameters + */ +std::unique_ptr buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector) { + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader) { + // we need to make this variable unique_ptr so that the scope can be shared with caller function. + std::unique_ptr hnswParams(new faiss::SearchParametersHNSW); + // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16 which will then be used. + hnswParams->efSearch = hnswReader->hnsw.efSearch; + hnswParams->sel = idSelector; + return hnswParams; + } + + auto ivfReader = dynamic_cast(indexReader->index); + auto ivfFlatReader = dynamic_cast(indexReader->index); + if(ivfReader || ivfFlatReader) { + // we need to make this variable unique_ptr so that the scope can be shared with caller function. 
+ std::unique_ptr ivfParams(new faiss::SearchParametersIVF); + ivfParams->sel = idSelector; + return ivfParams; + } + throw std::runtime_error("Invalid Index Type supported for Filtered Search on Faiss"); +} diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 543ce8ec49..1b79d91143 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -88,6 +88,18 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd return nullptr; } +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter + (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ) { + + try { + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; + +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ) { try { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 9bf38008bb..5ac207c431 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -5,6 +5,11 @@ package org.opensearch.knn.index.query; +import lombok.Getter; +import lombok.Setter; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -25,6 +30,10 @@ public class KNNQuery extends Query { private final int k; private final String indexName; + @Getter + @Setter + private Query filterQuery; + public KNNQuery(String 
field, float[] queryVector, int k, String indexName) { this.field = field; this.queryVector = queryVector; @@ -32,6 +41,14 @@ public KNNQuery(String field, float[] queryVector, int k, String indexName) { this.indexName = indexName; } + public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) { + this.field = field; + this.queryVector = queryVector; + this.k = k; + this.indexName = indexName; + this.filterQuery = filterQuery; + } + public String getField() { return this.field; } @@ -61,9 +78,25 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo if (!KNNSettings.isKNNPluginEnabled()) { throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true"); } + final Weight filterWeight = getFilterWeight(searcher); + if (filterWeight != null) { + return new KNNWeight(this, boost, filterWeight); + } return new KNNWeight(this, boost); } + private Weight getFilterWeight(IndexSearcher searcher) throws IOException { + if (this.getFilterQuery() != null) { + // Run the filter query + final BooleanQuery booleanQuery = new BooleanQuery.Builder().add(this.getFilterQuery(), BooleanClause.Occur.FILTER) + .add(new FieldExistsQuery(this.getField()), BooleanClause.Occur.FILTER) + .build(); + final Query rewritten = searcher.rewrite(booleanQuery); + return searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); + } + return null; + } + @Override public void visit(QueryVisitor visitor) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 5e04f4afe1..3a8319a730 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -252,7 +252,7 @@ protected Query doToQuery(QueryShardContext context) { ); } - if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter 
!= null) { + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null && knnEngine != KNNEngine.FAISS) { throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine)); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 188bbc150a..20c456c4a5 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -59,27 +59,53 @@ public static Query create(CreateQueryRequest createQueryRequest) { final String fieldName = createQueryRequest.getFieldName(); final int k = createQueryRequest.getK(); final float[] vector = createQueryRequest.getVector(); + final Query filterQuery = getFilterQuery(createQueryRequest); if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { + log.debug( + String.format( + "Creating custom k-NN query with filters for index: %s \"\", field: %s \"\", " + "k: %d", + indexName, + fieldName, + k + ) + ); + return new KNNQuery(fieldName, vector, k, indexName, filterQuery); + } log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); return new KNNQuery(fieldName, vector, k, indexName); } + if (filterQuery != null) { + log.debug( + String.format("Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k) + ); + return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery); + } + log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + return new KnnFloatVectorQuery(fieldName, vector, k); + } + + private static Query 
getFilterQuery(CreateQueryRequest createQueryRequest) { if (createQueryRequest.getFilter().isPresent()) { final QueryShardContext queryShardContext = createQueryRequest.getContext() .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); log.debug( - String.format("Creating Lucene k-NN query with filter for index [%s], field [%s] and k [%d]", indexName, fieldName, k) + String.format( + "Creating k-NN query with filter for index [%s], field [%s] and k [%d]", + createQueryRequest.getIndexName(), + createQueryRequest.fieldName, + createQueryRequest.k + ) ); try { - final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); - return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery); + return createQueryRequest.getFilter().get().toQuery(queryShardContext); } catch (IOException e) { throw new RuntimeException("Cannot create knn query with filter", e); } } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnFloatVectorQuery(fieldName, vector, k); + return null; } /** diff --git a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java index 0005212bf6..3e5c8fff6f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java @@ -56,4 +56,37 @@ public float score() { public int docID() { return docIdsIter.docID(); } + + /** + * Returns the Empty Scorer implementation. We use this scorer to short circuit the actual search when it is not + * required. 
+ * @param knnWeight {@link KNNWeight} + * @return {@link KNNScorer} + */ + public static Scorer emptyScorer(KNNWeight knnWeight) { + return new Scorer(knnWeight) { + private final DocIdSetIterator docIdsIter = DocIdSetIterator.empty(); + + @Override + public DocIdSetIterator iterator() { + return docIdsIter; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return 0; + } + + @Override + public float score() throws IOException { + assert docID() != DocIdSetIterator.NO_MORE_DOCS; + return 0; + } + + @Override + public int docID() { + return docIdsIter.docID(); + } + }; + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 716aed412a..b8b88b4fea 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -5,16 +5,27 @@ package org.opensearch.knn.index.query; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.search.FilteredDocIdSetIterator; +import org.apache.lucene.search.HitQueue; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import 
org.opensearch.knn.index.util.KNNEngine; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -31,10 +42,12 @@ import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.plugin.stats.KNNCounter; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -49,20 +62,30 @@ /** * Calculate query weights and build query scorers. */ +@Log4j2 public class KNNWeight extends Weight { - private static Logger logger = LogManager.getLogger(KNNWeight.class); private static ModelDao modelDao; private final KNNQuery knnQuery; private final float boost; - private NativeMemoryCacheManager nativeMemoryCacheManager; + private final NativeMemoryCacheManager nativeMemoryCacheManager; + private final Weight filterWeight; public KNNWeight(KNNQuery query, float boost) { super(query); this.knnQuery = query; this.boost = boost; this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); + this.filterWeight = null; + } + + public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { + super(query); + this.knnQuery = query; + this.boost = boost; + this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); + this.filterWeight = filterWeight; } public static void initialize(ModelDao modelDao) { @@ -76,13 +99,91 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { + final int[] filterIdsArray = getFilterIdsArray(context); + // We don't need to go to JNI layer if no documents are found which satisfy the filters + // We should give this condition a deeper look that where it 
should be placed. For now I feel this is a good + // place. + if (filterWeight != null && filterIdsArray.length == 0) { + return KNNScorer.emptyScorer(this); + } + final Map docIdsToScoreMap = new HashMap<>(); + + /* + * The idea for this optimization is that to get K results we need to look at at least K vectors in the HNSW + * graph. Hence, if the filtered results are at most K and a filter query is present, we shift to exact search. + * This improves the recall. + */ + if (filterWeight != null && filterIdsArray.length <= knnQuery.getK()) { + docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray)); + } else { + final Map annResults = doANNSearch(context, filterIdsArray); + if (annResults == null) { + return null; + } + docIdsToScoreMap.putAll(annResults); + } + if (docIdsToScoreMap.isEmpty()) { + return KNNScorer.emptyScorer(this); + } + return convertSearchResponseToScorer(docIdsToScoreMap); + } + + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException { + final Bits liveDocs = ctx.reader().getLiveDocs(); + final int maxDoc = ctx.reader().maxDoc(); + + final Scorer scorer = filterWeight.scorer(ctx); + if (scorer == null) { + return new FixedBitSet(0); + } + + return createBitSet(scorer.iterator(), liveDocs, maxDoc); + } + + private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException { + if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) { + // If we already have a BitSet and no deletions, reuse the BitSet + return ((BitSetIterator) filteredDocIdsIterator).getBitSet(); + } + // Create a new BitSet from matching and live docs + FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) { + @Override + protected boolean match(int doc) { + return liveDocs == null || liveDocs.get(doc); + } + }; + return BitSet.of(filterIterator, maxDoc); + } + + private int[] 
getFilterIdsArray(final LeafReaderContext context) throws IOException { + if (filterWeight == null) { + return new int[0]; + } + final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight); + final int[] filteredIds = new int[filteredDocsBitSet.cardinality()]; + int filteredIdsIndex = 0; + int docId = 0; + while (docId < filteredDocsBitSet.length()) { + docId = filteredDocsBitSet.nextSetBit(docId); + if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) { + break; + } + log.debug("Docs in filtered docs id set is : {}", docId); + filteredIds[filteredIdsIndex] = docId; + filteredIdsIndex++; + docId++; + } + return filteredIds; + } + + private Map doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); if (fieldInfo == null) { - logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); + log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); return null; } @@ -121,7 +222,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { .collect(Collectors.toList()); if (engineFiles.isEmpty()) { - logger.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); + log.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); return null; } @@ -148,7 +249,6 @@ public Scorer scorer(LeafReaderContext context) throws IOException { // Now that we have the allocation, we need to readLock it indexAllocation.readLock(); - try { if (indexAllocation.isClosed()) { throw new RuntimeException("Index has already been closed"); @@ 
-158,8 +258,10 @@ public Scorer scorer(LeafReaderContext context) throws IOException { indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), - knnEngine.getName() + knnEngine.getName(), + filterIdsArray ); + } catch (Exception e) { GRAPH_QUERY_ERRORS.increment(); throw new RuntimeException(e); @@ -174,21 +276,70 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * neighbors we are inverting the scores. */ if (results.length == 0) { - logger.debug("[KNN] Query yielded 0 results"); + log.debug("[KNN] Query yielded 0 results"); return null; } - Map scores = Arrays.stream(results) + return Arrays.stream(results) .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); - int maxDoc = Collections.max(scores.keySet()) + 1; - DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); + } + + private Map doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) { + final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + float[] queryVector = this.knnQuery.getQueryVector(); + try { + final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); + final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE)); + // Creating min heap and init with MAX DocID and Score as -INF. 
+ final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); + ScoreDoc topDoc = queue.top(); + final Map docToScore = new HashMap<>(); + for (int filterId : filterIdsArray) { + int docId = values.advance(filterId); + final BytesRef value = values.binaryValue(); + final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + // Calculates a similarity score between the two vectors with a specified function. Higher similarity + // scores correspond to closer vectors. + float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector); + if (score > topDoc.score) { + topDoc.score = score; + topDoc.doc = docId; + // As the HitQueue is a min heap, updating top will bring the doc with -INF score or worst score we + // have seen till now on top. + topDoc = queue.updateTop(); + } + } + // If scores are negative we will remove them. + // This is done because there can be negative values in the Heap as we init the heap with Score as -INF. + // If filterIds < k, then some values in the heap can have a negative score. + while (queue.size() > 0 && queue.top().score < 0) { + queue.pop(); + } + + while (queue.size() > 0) { + final ScoreDoc doc = queue.pop(); + docToScore.put(doc.doc, doc.score); + } + + return docToScore; + } catch (Exception e) { + log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery, e); + } + return Collections.emptyMap(); + } + private Scorer convertSearchResponseToScorer(final Map docsToScore) throws IOException { + final int maxDoc = Collections.max(docsToScore.keySet()) + 1; + final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); // The docIdSetIterator will contain the docids of the returned results. 
So, before adding results to - // the builder, we can grow to results.length - DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(results.length); - Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); - DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); - return new KNNScorer(this, docIdSetIter, scores, boost); + // the builder, we can grow to docsToScore.size() + final DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(docsToScore.size()); + docsToScore.keySet().forEach(setAdder::add); + final DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); + return new KNNScorer(this, docIdSetIter, docsToScore, boost); } @Override diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index fe28de43eb..776ea53668 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -32,6 +32,7 @@ public enum KNNEngine implements KNNLibrary { public static final KNNEngine DEFAULT = NMSLIB; private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); + private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, @@ -105,6 +106,10 @@ public static Set getEnginesThatCreateCustomSegmentFiles() { return CUSTOM_SEGMENT_FILE_ENGINES; } + public static Set getEnginesThatSupportsFilters() { + return ENGINES_SUPPORTING_FILTERS; + } + /** * Return number of max allowed dimensions per single vector based on the knn engine * @param knnEngine knn engine to check max dimensions value diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index b990ce33b4..ba1d3ac840 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ 
b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -122,6 +122,6 @@ public interface KNNLibrary { * @return list of file extensions that will be read/write with mmap */ default List mmapFileExtensions() { - return Collections.EMPTY_LIST; + return Collections.emptyList(); } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index f1d869bd28..5dce15d6e0 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -24,7 +24,7 @@ * * In order to compile C++ header file, run: * javac -h jni/include src/main/java/org/opensearch/knn/jni/FaissService.java - * src/main/java/org/opensearch/knn/index/KNNQueryResult.java + * src/main/java/org/opensearch/knn/index/query/KNNQueryResult.java * src/main/java/org/opensearch/knn/common/KNNConstants.java */ class FaissService { @@ -83,6 +83,8 @@ public static native void createIndexFromTemplate( */ public static native KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k); + public static native KNNQueryResult[] queryIndexWithFilter(long indexPointer, float[] queryVector, int k, int[] filterIds); + /** * Free native memory pointer */ diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index e32880fff4..f45fb0c736 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -11,6 +11,7 @@ package org.opensearch.knn.jni; +import org.apache.commons.lang.ArrayUtils; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -94,20 +95,27 @@ public static long loadIndex(String indexPath, Map parameters, S * Query an index * * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param engineName 
name of engine to query index + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param engineName name of engine to query index + * @param filteredIds array of doc ids to which the search should be restricted; null or empty runs an unfiltered search. * @return KNNQueryResult array of k neighbors */ - public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName) { + public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds) { if (KNNEngine.NMSLIB.getName().equals(engineName)) { return NmslibService.queryIndex(indexPointer, queryVector, k); } if (KNNEngine.FAISS.getName().equals(engineName)) { + // This code assumes that when a filter is specified but filteredIds is null or empty, empty + // k-NN results have already been returned by the caller. So non-empty filteredIds is the filter case and + // we need to run the search with filteredIds. If filteredIds is empty, we run the Faiss search + // normally. 
+ if (ArrayUtils.isNotEmpty(filteredIds)) { + return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds); + } return FaissService.queryIndex(indexPointer, queryVector, k); } - throw new IllegalArgumentException("QueryIndex not supported for provided engine"); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java index a4afba5132..8b1f0676b7 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java @@ -12,8 +12,8 @@ package org.opensearch.knn.plugin.rest; import com.google.common.collect.ImmutableList; +import org.apache.commons.lang.StringUtils; import org.opensearch.client.node.NodeClient; -import org.opensearch.core.common.Strings; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.GetModelAction; import org.opensearch.knn.plugin.transport.GetModelRequest; @@ -50,7 +50,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String modelID = restRequest.param(MODEL_ID); - if (!Strings.hasText(modelID)) { + if (StringUtils.isBlank(modelID)) { throw new IllegalArgumentException("model ID cannot be empty"); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java index 7aa349c616..9049a83db7 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java @@ -6,12 +6,12 @@ package org.opensearch.knn.plugin.rest; import lombok.AllArgsConstructor; +import org.apache.commons.lang.StringUtils; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.KNNStatsAction; import 
org.opensearch.knn.plugin.transport.KNNStatsRequest; import com.google.common.collect.ImmutableList; import org.opensearch.client.node.NodeClient; -import org.opensearch.core.common.Strings; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestActions; @@ -83,7 +83,7 @@ private KNNStatsRequest getRequest(RestRequest request) { // parse the nodes the user wants to query String[] nodeIdsArr = null; String nodesIdsStr = request.param("nodeId"); - if (!Strings.isEmpty(nodesIdsStr)) { + if (StringUtils.isNotEmpty(nodesIdsStr)) { nodeIdsArr = nodesIdsStr.split(","); } @@ -93,7 +93,7 @@ private KNNStatsRequest getRequest(RestRequest request) { // parse the stats the customer wants to see Set statsSet = null; String statsStr = request.param("stat"); - if (!Strings.isEmpty(statsStr)) { + if (StringUtils.isNotEmpty(statsStr)) { statsSet = new HashSet<>(Arrays.asList(statsStr.split(","))); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java index 7b78998c57..a31c2f2974 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.rest; +import org.apache.commons.lang.StringUtils; import org.opensearch.knn.common.exception.KNNInvalidIndicesException; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.KNNWarmupAction; @@ -15,7 +16,6 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.index.Index; import org.opensearch.rest.BaseRestHandler; @@ -81,7 +81,7 @@ protected RestChannelConsumer 
prepareRequest(RestRequest request, NodeClient cli } private KNNWarmupRequest createKNNWarmupRequest(RestRequest request) { - String[] indexNames = Strings.splitStringByCommaToArray(request.param("index")); + String[] indexNames = StringUtils.split(request.param("index"), ","); Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), indexNames); List invalidIndexNames = new ArrayList<>(); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java index 4c034a282e..da6e2990e4 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java @@ -11,6 +11,7 @@ package org.opensearch.knn.plugin.transport; +import org.apache.commons.lang.StringUtils; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.StreamInput; @@ -43,7 +44,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { - if (Strings.hasText(modelID)) { + if (StringUtils.isNotBlank(modelID)) { return null; } return addValidationError("Model id cannot be empty ", null); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index da925097c2..cce7a4dc07 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -11,6 +11,7 @@ package org.opensearch.knn.plugin.transport; +import org.apache.commons.lang.StringUtils; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionListenerResponseHandler; import 
org.opensearch.action.search.SearchRequest; @@ -108,7 +109,7 @@ protected DiscoveryNode selectNode(String preferredNode, TrainingJobRouteDecisio if (response.getTrainingJobCount() < 1) { selectedNode = currentNode; // Return right away if the user didnt pass a preferred node or this is the preferred node - if (Strings.isEmpty(preferredNode) || selectedNode.getId().equals(preferredNode)) { + if (StringUtils.isEmpty(preferredNode) || selectedNode.getId().equals(preferredNode)) { return selectedNode; } } diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 8ae8dd49a5..3869e4994e 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -11,6 +11,7 @@ package org.opensearch.knn.training; +import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.UUIDs; @@ -68,7 +69,7 @@ public TrainingJob( String description ) { // Generate random base64 string if one is not provided - this.modelId = Strings.hasText(modelId) ? modelId : UUIDs.randomBase64UUID(); + this.modelId = StringUtils.isNotBlank(modelId) ? 
modelId : UUIDs.randomBase64UUID(); this.knnMethodContext = Objects.requireNonNull(knnMethodContext, "MethodContext cannot be null."); this.nativeMemoryCacheManager = Objects.requireNonNull(nativeMemoryCacheManager, "NativeMemoryCacheManager cannot be null."); this.trainingDataEntryContext = Objects.requireNonNull(trainingDataEntryContext, "TrainingDataEntryContext cannot be null."); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index ece41d01be..fcdbdcead2 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -12,11 +12,14 @@ package org.opensearch.knn.index; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Floats; +import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.junit.BeforeClass; import org.opensearch.client.Response; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.common.Strings; import org.opensearch.common.xcontent.XContentFactory; @@ -43,6 +46,11 @@ import static org.opensearch.knn.common.KNNConstants.PARAMETERS; public class FaissIT extends KNNRestTestCase { + private static final String DOC_ID_1 = "doc1"; + private static final String DOC_ID_2 = "doc2"; + private static final String DOC_ID_3 = "doc3"; + private static final String COLOR_FIELD_NAME = "color"; + private static final String TASTE_FIELD_NAME = "taste"; static TestUtils.TestData testData; @@ -280,4 +288,90 @@ public void testEndToEnd_fromModel() throws Exception { assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId())); } } + + @SneakyThrows + public void testQueryWithFilter_withDifferentCombination_thenSuccess() { + setupKNNIndexForFilterQuery(); + final float[] searchVector = { 6.0f, 
6.0f, 4.1f }; + // K > filteredResults + int kGreaterThanFilterResult = 5; + List expectedDocIds = Arrays.asList(DOC_ID_1, DOC_ID_3); + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIds.size(), knnResults.size()); + assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds)); + + // K Limits Filter results + int kLimitsFilterResult = 1; + List expectedDocIdsKLimitsFilterResult = List.of(DOC_ID_1); + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); + assertTrue( + knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + + // Empty filter docIds + int k = 10; + final Response emptyFilterResponse = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + searchVector, + kLimitsFilterResult, + QueryBuilders.termQuery(COLOR_FIELD_NAME, "color_not_present") + ), + k + ); + final String responseBodyForEmptyDocIds = EntityUtils.toString(emptyFilterResponse.getEntity()); + final List emptyKNNFilteredResultsFromResponse = parseSearchResponse(responseBodyForEmptyDocIds, FIELD_NAME); + + 
assertEquals(0, emptyKNNFilteredResultsFromResponse.size()); + } + + protected void setupKNNIndexForFilterQuery() throws Exception { + // Create Mappings + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", 3) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + final String mapping = Strings.toString(builder); + + createKnnIndex(INDEX_NAME, mapping); + + addKnnDocWithAttributes( + DOC_ID_1, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + refreshIndex(INDEX_NAME); + } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index d07efc3839..cb84698dbc 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -25,7 +25,6 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.rest.RestStatus; import java.io.IOException; import java.util.Arrays; @@ -45,8 +44,6 @@ public class LuceneEngineIT extends KNNRestTestCase { private static final String DOC_ID_2 = "doc2"; private static final String DOC_ID_3 = "doc3"; private static final int EF_CONSTRUCTION = 128; - private static final String 
INDEX_NAME = "test-index-1"; - private static final String FIELD_NAME = "test-field-1"; private static final String COLOR_FIELD_NAME = "color"; private static final String TASTE_FIELD_NAME = "taste"; private static final int M = 16; @@ -361,22 +358,6 @@ public void testIndexReopening() throws Exception { assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } - private void addKnnDocWithAttributes(String docId, float[] vector, Map fieldValues) throws IOException { - Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true"); - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector); - for (String fieldName : fieldValues.keySet()) { - builder.field(fieldName, fieldValues.get(fieldName)); - } - builder.endObject(); - request.setJsonEntity(Strings.toString(builder)); - client().performRequest(request); - - request = new Request("POST", "/" + INDEX_NAME + "/_refresh"); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - } - private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index b52a7238f1..1a9507e6a1 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -333,7 +333,7 @@ public static void assertLoadableByEngine( ); int k = 2; float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName()); + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, 
knnEngine.getName(), null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine.getName()); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 8d94b1afb9..ce08e0350f 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -74,7 +74,7 @@ public void testIndexLoadStrategy_load() throws IOException { // Confirm that the file was loaded by querying float[] query = new float[dimension]; Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName()); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null); assertTrue(results.length > 0); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 0f8f43bf21..674d1be39b 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -5,8 +5,11 @@ package org.opensearch.knn.index.query; +import org.apache.lucene.index.Term; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.mockito.Mockito; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; @@ -23,6 +26,10 @@ import static org.mockito.Mockito.when; public class KNNQueryFactoryTests extends KNNTestCase { + private static final String FILTER_FILED_NAME = "foo"; + private static final String FILTER_FILED_VALUE = "fooval"; + private static final 
QueryBuilder FILTER_QUERY_BUILDER = new TermQueryBuilder(FILTER_FILED_NAME, FILTER_FILED_VALUE); + private static final Query FILTER_QUERY = new TermQuery(new Term(FILTER_FILED_NAME, FILTER_FILED_VALUE)); private final int testQueryDimension = 17; private final float[] testQueryVector = new float[testQueryDimension]; private final String testIndexName = "test-index"; @@ -59,7 +66,6 @@ public void testCreateLuceneQueryWithFilter() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); MappedFieldType testMapper = mock(MappedFieldType.class); when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); - QueryBuilder filter = new TermQueryBuilder("foo", "fooval"); final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(testIndexName) @@ -67,10 +73,35 @@ public void testCreateLuceneQueryWithFilter() { .vector(testQueryVector) .k(testK) .context(mockQueryShardContext) - .filter(filter) + .filter(FILTER_QUERY_BUILDER) .build(); Query query = KNNQueryFactory.create(createQueryRequest); assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } } + + public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { + final KNNEngine knnEngine = KNNEngine.FAISS; + final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + final Query query = 
KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof KNNQuery); + + assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); + assertEquals(testFieldName, ((KNNQuery) query).getField()); + assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); + assertEquals(testK, ((KNNQuery) query).getK()); + assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery()); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 444f763a64..53d0330f0b 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -6,17 +6,24 @@ package org.opensearch.knn.index.query; import com.google.common.collect.Comparators; +import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentCommitInfo; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.Weight; import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; import org.junit.BeforeClass; @@ -28,6 +35,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; import 
org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.util.KNNEngine; @@ -70,6 +78,10 @@ public class KNNWeightTests extends KNNTestCase { private static final String CIRCUIT_BREAKER_LIMIT_100KB = "100Kb"; private static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); + private static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); + private static final Map EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.12048191f); + + private static final Query FILTER_QUERY = new TermQuery(new Term("foo", "fooValue")); private static MockedStatic nativeMemoryCacheManagerMockedStatic; private static MockedStatic jniServiceMockedStatic; @@ -133,7 +145,8 @@ public void testQueryScoreForFaissWithModel() throws IOException { SpaceType spaceType = SpaceType.L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString())).thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); @@ -272,7 +285,8 @@ public void testShardWithoutFiles() { @SneakyThrows public void testEmptyQueryResults() { final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {}; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString())).thenReturn(knnQueryResults); + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + .thenReturn(knnQueryResults); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); @@ -311,12 +325,152 @@ public void 
testEmptyQueryResults() { assertNull(knnScorer); } + @SneakyThrows + public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))) + .thenReturn(getFilteredKNNQueryResults()); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(reader.maxDoc()).thenReturn(K + 1); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + // Just to make sure that we are not hitting the exact search condition + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(K + 1)); + + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + true, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(SEGMENT_FILES_FAISS); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName()); + + 
when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { + float[] vector = new float[] { 0.1f, 0.3f }; + int filterDocId = 0; + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + // scorer will return 1 document (doc 0), fewer than K, so the exact search path is taken + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); + when(reader.maxDoc()).thenReturn(1); + + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), 
SPACE_TYPE, SpaceType.L2.name()); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); + when(fieldInfo.getName()).thenReturn(FIELD_NAME); + when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); + when(binaryDocValues.advance(filterDocId)).thenReturn(filterDocId); + BytesRef vectorByteRef = new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector)); + when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + 
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.empty()); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + + final Scorer knnScorer = knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(0, docIdSetIterator.cost()); + assertEquals(0, docIdSetIterator.cost()); + } + private void testQueryScore( final Function scoreTranslator, final Set segmentFiles, final Map fileAttributes ) throws IOException { - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString())).thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any())) + .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); @@ -381,4 +535,12 @@ private KNNQueryResult[] getKNNQueryResults() { .collect(Collectors.toList()) .toArray(new KNNQueryResult[0]); } + + private KNNQueryResult[] getFilteredKNNQueryResults() { + return FILTERED_DOC_ID_TO_SCORES.entrySet() + .stream() + .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()) + .toArray(new KNNQueryResult[0]); + } } diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index f4971e6fde..6d52e5544d 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ 
-583,12 +583,12 @@ public void testLoadIndex_faiss_valid() throws IOException { } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid-engine")); + expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null)); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName())); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null)); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { @@ -611,7 +611,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName())); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null)); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -637,7 +637,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName()); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null); assertEquals(k, results.length); } } @@ -645,7 +645,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { @@ -664,7 
+664,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null)); } public void testQueryIndex_faiss_valid() throws IOException { @@ -693,9 +693,15 @@ public void testQueryIndex_faiss_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null); assertEquals(k, results.length); } + + // Filter will result in no ids + for (float[] query : testData.queries) { + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }); + assertEquals(0, results.length); + } } } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 41d53d72c7..8c1ff63dbe 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1306,4 +1306,20 @@ protected void refreshIndex(final String index) throws IOException { Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + + protected void addKnnDocWithAttributes(String docId, float[] vector, Map fieldValues) throws IOException { + Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true"); + + final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector); + for (String fieldName : 
fieldValues.keySet()) { + builder.field(fieldName, fieldValues.get(fieldName)); + } + builder.endObject(); + request.setJsonEntity(Strings.toString(builder)); + client().performRequest(request); + + request = new Request("POST", "/" + INDEX_NAME + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } }