diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 8e72b2429..11f978151 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -153,10 +153,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (FILTER_FIELD.getPreferredName().equals(tokenName)) { filter = parseInnerQueryBuilder(parser); } else { - throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" - ); + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); } } else { @@ -246,7 +243,16 @@ protected Query doToQuery(QueryShardContext context) { } String indexName = context.index().getName(); - return KNNQueryFactory.create(knnEngine, indexName, this.fieldName, this.vector, this.k, this.filter, context); + KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(this.fieldName) + .vector(this.vector) + .k(this.k) + .knnQueryFilter(this.filter) + .context(context) + .build(); + return KNNQueryFactory.create(createQueryRequest); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 619242d5c..eb5236d26 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -5,6 +5,11 @@ package org.opensearch.knn.index.query; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; @@ -13,6 +18,7 @@ import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Optional; /** * Creates the Lucene k-NN queries @@ -30,31 +36,98 @@ public class KNNQueryFactory { * @param k the number of nearest neighbors to return * @return Lucene Query */ - public static Query create( - KNNEngine knnEngine, - String indexName, - String fieldName, - float[] vector, - int k, - QueryBuilder knnQueryFilter, - QueryShardContext context - ) { + public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) { + final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(fieldName) + .vector(vector) + .k(k) + .build(); + return create(createQueryRequest); + } + + /** + * Creates a Lucene query for a particular engine. + * @param createQueryRequest request object that has all required fields to construct the query + * @return Lucene Query + */ + public static Query create(CreateQueryRequest createQueryRequest) { // Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. They need to // use the custom query type created by the plugin - if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) { - log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KNNQuery(fieldName, vector, k, indexName); + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + log.debug( + String.format( + "Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", + createQueryRequest.getIndexName(), + createQueryRequest.getFieldName(), + createQueryRequest.getK() + ) + ); + return new KNNQuery( + createQueryRequest.getFieldName(), + createQueryRequest.getVector(), + createQueryRequest.getK(), + createQueryRequest.getIndexName() + ); } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - if (knnQueryFilter == null) { - return new KnnVectorQuery(fieldName, vector, k); + log.debug( + String.format( + "Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", + createQueryRequest.getIndexName(), + createQueryRequest.getFieldName(), + createQueryRequest.getK() + ) + ); + if (createQueryRequest.getKnnQueryFilter().isPresent()) { + final QueryShardContext queryShardContext = createQueryRequest.getContext() + .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + try { + final Query filterQuery = createQueryRequest.getKnnQueryFilter().get().toQuery(queryShardContext); + return new KnnVectorQuery( + createQueryRequest.getFieldName(), + createQueryRequest.getVector(), + createQueryRequest.getK(), + filterQuery + ); + } catch (IOException e) { + throw new RuntimeException("Cannot create knn query with filter", e); + } } - try { - Query filterQuery = knnQueryFilter.toQuery(context); - return new KnnVectorQuery(fieldName, vector, k, filterQuery); - } catch (IOException e) { - throw new RuntimeException("Cannot create knn query with filter", e); + return new KnnVectorQuery(createQueryRequest.getFieldName(), createQueryRequest.getVector(), createQueryRequest.getK()); + } + + /** + * DTO object to hold data required to create a Query instance. + */ + @AllArgsConstructor + @Builder + @Setter + static class CreateQueryRequest { + @Getter + @NonNull + private KNNEngine knnEngine; + @Getter + @NonNull + private String indexName; + @Getter + private String fieldName; + @Getter + private float[] vector; + @Getter + private int k; + // can be null in cases filter not passed with the knn query + private QueryBuilder knnQueryFilter; + // can be null in cases filter not passed with the knn query + private QueryShardContext context; + + public Optional getKnnQueryFilter() { + return Optional.ofNullable(knnQueryFilter); + } + + public Optional getContext() { + return Optional.ofNullable(context); } } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java index 1f7d6cc85..8a3233fba 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java @@ -17,7 +17,6 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.opensearch.index.mapper.MapperService; -import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -109,16 +108,7 @@ public void testKnnVectorIndex() throws Exception { verify(knnVectorsFormat).getKnnVectorsFormatForField(anyString()); IndexSearcher searcher = new IndexSearcher(reader); - QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - Query query = KNNQueryFactory.create( - KNNEngine.LUCENE, - "dummy", - fieldName, - new float[] { 1.0f, 0.0f, 0.0f }, - 1, - null, - mockQueryShardContext - ); + Query query = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", fieldName, new float[] { 1.0f, 0.0f, 0.0f }, 1); assertEquals(1, searcher.count(query)); @@ -145,15 +135,7 @@ public void testKnnVectorIndex() throws Exception { verify(knnVectorsFormat, times(2)).getKnnVectorsFormatForField(anyString()); IndexSearcher searcher1 = new IndexSearcher(reader1); - Query query1 = KNNQueryFactory.create( - KNNEngine.LUCENE, - "dummy", - field1Name, - new float[] { 1.0f, 0.0f }, - 1, - null, - mockQueryShardContext - ); + Query query1 = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", field1Name, new float[] { 1.0f, 0.0f }, 1); assertEquals(1, searcher1.count(query1)); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index d08f8de97..6d041687d 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -7,7 +7,10 @@ import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.util.KNNEngine; @@ -15,7 +18,9 @@ import java.util.List; import java.util.stream.Collectors; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class KNNQueryFactoryTests extends KNNTestCase { private final int testQueryDimension = 17; @@ -26,16 +31,7 @@ public class KNNQueryFactoryTests extends KNNTestCase { public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { - QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - Query query = KNNQueryFactory.create( - knnEngine, - testIndexName, - testFieldName, - testQueryVector, - testK, - null, - mockQueryShardContext - ); + Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); @@ -46,20 +42,34 @@ public void testCreateCustomKNNQuery() { } public void testCreateLuceneDefaultQuery() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK); + assertTrue(query instanceof KnnVectorQuery); + } + } + + public void testCreateLuceneQueryWithFilter() { List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) .collect(Collectors.toList()); for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - Query query = KNNQueryFactory.create( - knnEngine, - testIndexName, - testFieldName, - testQueryVector, - testK, - null, - mockQueryShardContext - ); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + QueryBuilder filter = new TermQueryBuilder("foo", "fooval"); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .knnQueryFilter(filter) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); assertTrue(query instanceof KnnVectorQuery); } }