Skip to content

Commit

Permalink
Adding basic unit test, refactor query factory to use request dto
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 18, 2022
1 parent 3c375c4 commit 5b0bb27
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 64 deletions.
16 changes: 11 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
filter = parseInnerQueryBuilder(parser);
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
);
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
}

} else {
Expand Down Expand Up @@ -246,7 +243,16 @@ protected Query doToQuery(QueryShardContext context) {
}

String indexName = context.index().getName();
return KNNQueryFactory.create(knnEngine, indexName, this.fieldName, this.vector, this.k, this.filter, context);
KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(this.vector)
.k(this.k)
.knnQueryFilter(this.filter)
.context(context)
.build();
return KNNQueryFactory.create(createQueryRequest);
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
Expand Down
113 changes: 93 additions & 20 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@

package org.opensearch.knn.index.query;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
Expand All @@ -13,6 +18,7 @@
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Optional;

/**
* Creates the Lucene k-NN queries
Expand All @@ -30,31 +36,98 @@ public class KNNQueryFactory {
* @param k the number of nearest neighbors to return
* @return Lucene Query
*/
public static Query create(
KNNEngine knnEngine,
String indexName,
String fieldName,
float[] vector,
int k,
QueryBuilder knnQueryFilter,
QueryShardContext context
) {
public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) {
final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(fieldName)
.vector(vector)
.k(k)
.build();
return create(createQueryRequest);
}

/**
* Creates a Lucene query for a particular engine.
* @param createQueryRequest request object that has all required fields to construct the query
* @return Lucene Query
*/
public static Query create(CreateQueryRequest createQueryRequest) {
// Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. They need to
// use the custom query type created by the plugin
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) {
log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KNNQuery(fieldName, vector, k, indexName);
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
log.debug(
String.format(
"Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName(),
createQueryRequest.getK()
)
);
return new KNNQuery(
createQueryRequest.getFieldName(),
createQueryRequest.getVector(),
createQueryRequest.getK(),
createQueryRequest.getIndexName()
);
}

log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
if (knnQueryFilter == null) {
return new KnnVectorQuery(fieldName, vector, k);
log.debug(
String.format(
"Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName(),
createQueryRequest.getK()
)
);
if (createQueryRequest.getKnnQueryFilter().isPresent()) {
final QueryShardContext queryShardContext = createQueryRequest.getContext()
.orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
try {
final Query filterQuery = createQueryRequest.getKnnQueryFilter().get().toQuery(queryShardContext);
return new KnnVectorQuery(
createQueryRequest.getFieldName(),
createQueryRequest.getVector(),
createQueryRequest.getK(),
filterQuery
);
} catch (IOException e) {
throw new RuntimeException("Cannot create knn query with filter", e);
}
}
try {
Query filterQuery = knnQueryFilter.toQuery(context);
return new KnnVectorQuery(fieldName, vector, k, filterQuery);
} catch (IOException e) {
throw new RuntimeException("Cannot create knn query with filter", e);
return new KnnVectorQuery(createQueryRequest.getFieldName(), createQueryRequest.getVector(), createQueryRequest.getK());
}

/**
* DTO object to hold data required to create a Query instance.
*/
@AllArgsConstructor
@Builder
@Setter
static class CreateQueryRequest {
@Getter
@NonNull
private KNNEngine knnEngine;
@Getter
@NonNull
private String indexName;
@Getter
private String fieldName;
@Getter
private float[] vector;
@Getter
private int k;
// can be null in cases filter not passed with the knn query
private QueryBuilder knnQueryFilter;
// can be null in cases filter not passed with the knn query
private QueryShardContext context;

public Optional<QueryBuilder> getKnnQueryFilter() {
return Optional.ofNullable(knnQueryFilter);
}

public Optional<QueryShardContext> getContext() {
return Optional.ofNullable(context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -109,16 +108,7 @@ public void testKnnVectorIndex() throws Exception {
verify(knnVectorsFormat).getKnnVectorsFormatForField(anyString());

IndexSearcher searcher = new IndexSearcher(reader);
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
Query query = KNNQueryFactory.create(
KNNEngine.LUCENE,
"dummy",
fieldName,
new float[] { 1.0f, 0.0f, 0.0f },
1,
null,
mockQueryShardContext
);
Query query = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", fieldName, new float[] { 1.0f, 0.0f, 0.0f }, 1);

assertEquals(1, searcher.count(query));

Expand All @@ -145,15 +135,7 @@ public void testKnnVectorIndex() throws Exception {
verify(knnVectorsFormat, times(2)).getKnnVectorsFormatForField(anyString());

IndexSearcher searcher1 = new IndexSearcher(reader1);
Query query1 = KNNQueryFactory.create(
KNNEngine.LUCENE,
"dummy",
field1Name,
new float[] { 1.0f, 0.0f },
1,
null,
mockQueryShardContext
);
Query query1 = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", field1Name, new float[] { 1.0f, 0.0f }, 1);

assertEquals(1, searcher1.count(query1));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@

import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class KNNQueryFactoryTests extends KNNTestCase {
private final int testQueryDimension = 17;
Expand All @@ -26,16 +31,7 @@ public class KNNQueryFactoryTests extends KNNTestCase {

public void testCreateCustomKNNQuery() {
for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
Query query = KNNQueryFactory.create(
knnEngine,
testIndexName,
testFieldName,
testQueryVector,
testK,
null,
mockQueryShardContext
);
Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK);
assertTrue(query instanceof KNNQuery);

assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
Expand All @@ -46,20 +42,34 @@ public void testCreateCustomKNNQuery() {
}

public void testCreateLuceneDefaultQuery() {
List<KNNEngine> luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values())
.filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine))
.collect(Collectors.toList());
for (KNNEngine knnEngine : luceneDefaultQueryEngineList) {
Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK);
assertTrue(query instanceof KnnVectorQuery);
}
}

public void testCreateLuceneQueryWithFilter() {
List<KNNEngine> luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values())
.filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine))
.collect(Collectors.toList());
for (KNNEngine knnEngine : luceneDefaultQueryEngineList) {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
Query query = KNNQueryFactory.create(
knnEngine,
testIndexName,
testFieldName,
testQueryVector,
testK,
null,
mockQueryShardContext
);
MappedFieldType testMapper = mock(MappedFieldType.class);
when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
QueryBuilder filter = new TermQueryBuilder("foo", "fooval");
final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
.fieldName(testFieldName)
.vector(testQueryVector)
.k(testK)
.context(mockQueryShardContext)
.knnQueryFilter(filter)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
assertTrue(query instanceof KnnVectorQuery);
}
}
Expand Down

0 comments on commit 5b0bb27

Please sign in to comment.