From fcbfef13668e773befd94579f41a8a82943ffda0 Mon Sep 17 00:00:00 2001 From: panguixin Date: Wed, 24 Jan 2024 01:51:33 +0800 Subject: [PATCH] Fix KNNScorer to apply boost (#1403) * apply boost Signed-off-by: panguixin * add change log Signed-off-by: panguixin --------- Signed-off-by: panguixin --- CHANGELOG.md | 1 + .../opensearch/knn/index/query/KNNScorer.java | 2 +- .../knn/index/query/KNNWeightTests.java | 35 +++++++++++-------- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 961ab62af..f57bc71fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318) * Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317) * Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367) +* Fix KNNScorer to apply boost [#1403](https://github.com/opensearch-project/k-NN/pull/1403) ### Infrastructure * Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289) * Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java index 3e5c8fff6..02dc86e80 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java @@ -49,7 +49,7 @@ public float score() { assert docID() != DocIdSetIterator.NO_MORE_DOCS; Float score = scores.get(docID()); if (score == null) throw new RuntimeException("Null score for the docID: " + docID()); - return score; + return score * boost; } @Override diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 2ed00eee9..f93ed51c1 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -169,7 +169,8 @@ public void testQueryScoreForFaissWithModel() { when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata); KNNWeight.initialize(modelDao); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -214,7 +215,7 @@ public void testQueryScoreForFaissWithModel() { final Map translatedScores = getTranslatedScores(scoreTranslator); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); - assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f); + assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); @@ -364,7 +365,8 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { // Just to make sure that we are not hitting the exact search condition when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final FSDirectory directory = mock(FSDirectory.class); when(reader.directory()).thenReturn(directory); @@ -408,7 +410,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); - assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f); + assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); @@ -433,7 +435,8 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -457,7 +460,7 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { final List actualDocIds = new ArrayList<>(); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); - assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); @@ -483,7 +486,8 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -507,7 +511,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS final List actualDocIds = new ArrayList<>(); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); - assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); @@ -543,7 +547,8 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -567,7 +572,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces final List actualDocIds = new ArrayList<>(); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); - assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); @@ -631,7 +636,8 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, parentFilter); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); // Execute final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); @@ -642,7 +648,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { .collect(Collectors.toList()); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertEquals(1, docIdSetIterator.nextDoc()); - assertEquals(expectedScores.get(1), knnScorer.score(), 0.01f); + assertEquals(expectedScores.get(1) * boost, knnScorer.score(), 0.01f); assertEquals(NO_MORE_DOCS, docIdSetIterator.nextDoc()); } @@ -733,7 +739,8 @@ private void testQueryScore( .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -777,7 +784,7 @@ private void testQueryScore( final Map translatedScores = getTranslatedScores(scoreTranslator); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); - assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f); + assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));