Skip to content

Commit

Permalink
Adds in lazy execution for Lucene kNN queries
Browse files Browse the repository at this point in the history
Signed-off-by: Kunal Kotwani <[email protected]>
  • Loading branch information
kotwanikunal committed Dec 3, 2024
1 parent 0bbeac3 commit 341772b
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand Down Expand Up @@ -106,9 +107,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return new LuceneEngineKnnVectorQuery(getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter));
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return new LuceneEngineKnnVectorQuery(getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter));
default:
throw new IllegalArgumentException(
String.format(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.lucene;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;

@AllArgsConstructor
@Log4j2
public class LuceneEngineKnnVectorQuery extends Query {
private final Query luceneQuery;

@Override
public Query rewrite(IndexSearcher indexSearcher) {
return this;
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
Query rewrittenQuery = luceneQuery.rewrite(searcher);
return rewrittenQuery.createWeight(searcher, scoreMode, boost);
}

@Override
public String toString(String s) {
return luceneQuery.toString();
}

@Override
public void visit(QueryVisitor queryVisitor) {
queryVisitor.visitLeaf(this);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LuceneEngineKnnVectorQuery otherQuery = (LuceneEngineKnnVectorQuery) o;
return luceneQuery.equals(otherQuery.luceneQuery);
}

@Override
public int hashCode() {
return luceneQuery.hashCode();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.lucene;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.opensearch.test.OpenSearchTestCase;

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.MockitoAnnotations.openMocks;

public class LuceneEngineKnnVectorQueryTests extends OpenSearchTestCase {

@Mock
IndexSearcher indexSearcher;

@Mock
Query luceneQuery;

@Mock
Weight weight;

@Mock
QueryVisitor queryVisitor;

@Spy
@InjectMocks
LuceneEngineKnnVectorQuery objectUnderTest;

@Override
public void setUp() throws Exception {
super.setUp();
openMocks(this);
when(luceneQuery.rewrite(any(IndexSearcher.class))).thenReturn(luceneQuery);
when(luceneQuery.createWeight(any(IndexSearcher.class), any(ScoreMode.class), anyFloat())).thenReturn(weight);
}

public void testRewrite() {
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
verifyNoInteractions(luceneQuery);
verify(objectUnderTest, times(3)).rewrite(indexSearcher);
}

public void testCreateWeight() throws Exception {
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
verifyNoInteractions(luceneQuery);
Weight actualWeight = objectUnderTest.createWeight(indexSearcher, ScoreMode.TOP_DOCS, 1.0f);
verify(luceneQuery, times(1)).rewrite(indexSearcher);
verify(objectUnderTest, times(3)).rewrite(indexSearcher);
assertEquals(weight, actualWeight);
}

public void testVisit() {
objectUnderTest.visit(queryVisitor);
verify(queryVisitor).visitLeaf(objectUnderTest);
}

public void testEquals() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
LuceneEngineKnnVectorQuery otherQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery, otherQuery);
assertEquals(mainQuery, mainQuery);
assertNotEquals(mainQuery, null);
assertNotEquals(mainQuery, new Object());
LuceneEngineKnnVectorQuery otherQuery2 = new LuceneEngineKnnVectorQuery(null);
assertNotEquals(mainQuery, otherQuery2);
}

public void testHashCode() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery.hashCode(), luceneQuery.hashCode());
}

public void testToString() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery.toString(), luceneQuery.toString());
}
}

0 comments on commit 341772b

Please sign in to comment.