From c3d1a24301ebb71a6769ed13d9a27fcba63fb7d8 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Mon, 17 Jul 2023 16:30:52 -0700 Subject: [PATCH] Enabled the IVF algorithm to work with Filters of K-NN Query. Signed-off-by: Navneet Verma --- CHANGELOG.md | 1 + .../opensearch/knn/index/query/KNNWeight.java | 20 ++++++++++- .../org/opensearch/knn/index/FaissIT.java | 35 ++++++++++++++++--- .../org/opensearch/knn/KNNRestTestCase.java | 22 ++++++++++++ 4 files changed, 73 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb2b04796e..f534d13be5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.9...2.x) ### Features ### Enhancements +1. Enabled the IVF algorithm to work with Filters of K-NN Query. [#1013](https://github.com/opensearch-project/k-NN/pull/1013) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index b8b88b4fea..c41487a8f4 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.StringUtils; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; import org.apache.lucene.search.FilteredDocIdSetIterator; @@ -49,6 +50,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -290,7 +292,7 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont float[] queryVector = this.knnQuery.getQueryVector(); try { final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); - final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE)); + final SpaceType spaceType = getSpaceType(fieldInfo); // Creating min heap and init with MAX DocID and Score as -INF. final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); ScoreDoc topDoc = queue.top(); @@ -351,4 +353,20 @@ public static float normalizeScore(float score) { if (score >= 0) return 1 / (1 + score); return -score + 1; } + + private SpaceType getSpaceType(final FieldInfo fieldInfo) { + final String spaceTypeString = fieldInfo.getAttribute(SPACE_TYPE); + if (StringUtils.isNotEmpty(spaceTypeString)) { + return SpaceType.getSpace(spaceTypeString); + } + + final String modelId = fieldInfo.getAttribute(MODEL_ID); + if (StringUtils.isNotEmpty(modelId)) { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + return modelMetadata.getSpaceType(); + } + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unable to find the Model Id from Field Info attributes for field %s", fieldInfo.getName()) + ); + } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 8eff19da5a..a579fb3fde 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -221,7 +221,7 @@ public void testDocDeletion() throws IOException { deleteKnnDoc(INDEX_NAME, "1"); } - public void testEndToEnd_fromModel() throws IOException, InterruptedException { + public void testKNNQuery_withModelDifferentCombination_thenSuccess() throws IOException, InterruptedException { String modelId = "test-model"; int dimension = 128; @@ -270,10 +270,9 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException { // Index some documents int numDocs = 100; for (int i = 0; i < numDocs; i++) { - Float[] indexVector = new Float[dimension]; + float[] indexVector = new float[dimension]; Arrays.fill(indexVector, (float) i); - - addKnnDoc(indexName, Integer.toString(i), fieldName, indexVector); + addKnnDocWithAttributes(indexName, Integer.toString(i), fieldName, indexVector, ImmutableMap.of("rating", String.valueOf(i))); } // Run search and ensure that the values returned are expected @@ -287,6 +286,34 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException { for (int i = 0; i < k; i++) { assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId())); } + + // doing exact search with filters + Response exactSearchFilteredResponse = searchKNNIndex( + indexName, + new KNNQueryBuilder(fieldName, queryVector, k, QueryBuilders.rangeQuery("rating").gte("90").lte("99")), + k + ); + List exactSearchFilteredResults = parseSearchResponse( + EntityUtils.toString(exactSearchFilteredResponse.getEntity()), + fieldName + ); + for (int i = 0; i < k; i++) { + assertEquals(numDocs - i - 1, Integer.parseInt(exactSearchFilteredResults.get(i).getDocId())); + } + + // doing exact search with filters + Response aNNSearchFilteredResponse = searchKNNIndex( + indexName, + new KNNQueryBuilder(fieldName, queryVector, k, QueryBuilders.rangeQuery("rating").gte("80").lte("99")), + k + ); + List aNNSearchFilteredResults = parseSearchResponse( + EntityUtils.toString(aNNSearchFilteredResponse.getEntity()), + fieldName + ); + for (int i = 0; i < k; i++) { + assertEquals(numDocs - i - 1, Integer.parseInt(aNNSearchFilteredResults.get(i).getDocId())); + } } @SneakyThrows diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index b6634a6121..1efde29c12 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1322,4 +1322,26 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map fieldValues + ) throws IOException { + Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true"); + + final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(vectorFieldName, vector); + for (String fieldName : fieldValues.keySet()) { + builder.field(fieldName, fieldValues.get(fieldName)); + } + builder.endObject(); + request.setJsonEntity(Strings.toString(builder)); + client().performRequest(request); + + request = new Request("POST", "/" + indexName + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } }