Skip to content

Commit

Permalink
Adding efficient filtering (#515)
Browse files Browse the repository at this point in the history
* Add initial support for filtering 

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Aug 22, 2022
1 parent 4f9c6cb commit 86cec30
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 7 deletions.
30 changes: 28 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.extern.log4j.Log4j2;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
Expand Down Expand Up @@ -38,6 +39,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {

public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -49,6 +51,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final String fieldName;
private final float[] vector;
private int k = 0;
private QueryBuilder filter;

/**
* Constructs a new knn query
Expand All @@ -58,6 +61,10 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
* @param k K nearest neighbours for the given vector
*/
public KNNQueryBuilder(String fieldName, float[] vector, int k) {
// Delegates to the four-argument constructor, passing null for the optional filter query.
this(fieldName, vector, k, null);
}

public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException("[" + NAME + "] requires fieldName");
}
Expand All @@ -77,6 +84,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) {
this.fieldName = fieldName;
this.vector = vector;
this.k = k;
this.filter = filter;
}

public static void initialize(ModelDao modelDao) {
Expand Down Expand Up @@ -111,6 +119,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
List<Object> vector = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
int k = 0;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
XContentParser.Token token;
Expand Down Expand Up @@ -139,6 +148,14 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
"[" + NAME + "] query does not support [" + currentFieldName + "]"
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
filter = parseInnerQueryBuilder(parser);
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
}

} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand All @@ -153,7 +170,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.queryName(queryName);
knnQueryBuilder.boost(boost);
return knnQueryBuilder;
Expand Down Expand Up @@ -226,7 +243,16 @@ protected Query doToQuery(QueryShardContext context) {
}

String indexName = context.index().getName();
return KNNQueryFactory.create(knnEngine, indexName, this.fieldName, this.vector, this.k);
KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(this.vector)
.k(this.k)
.filter(this.filter)
.context(context)
.build();
return KNNQueryFactory.create(createQueryRequest);
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
Expand Down
103 changes: 98 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,21 @@

package org.opensearch.knn.index.query;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Optional;

/**
* Creates the Lucene k-NN queries
*/
Expand All @@ -27,14 +37,97 @@ public class KNNQueryFactory {
* @return Lucene Query
*/
public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) {
    // Package the individual arguments into a request (no filter, no shard context) and
    // hand it straight to the request-based overload.
    return create(
        CreateQueryRequest.builder()
            .knnEngine(knnEngine)
            .indexName(indexName)
            .fieldName(fieldName)
            .vector(vector)
            .k(k)
            .build()
    );
}

/**
* Creates a Lucene query for a particular engine.
* @param createQueryRequest request object that has all required fields to construct the query
* @return Lucene Query
*/
public static Query create(CreateQueryRequest createQueryRequest) {
// Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. They need to
// use the custom query type created by the plugin
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) {
log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KNNQuery(fieldName, vector, k, indexName);
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
log.debug(
String.format(
"Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName(),
createQueryRequest.getK()
)
);
return new KNNQuery(
createQueryRequest.getFieldName(),
createQueryRequest.getVector(),
createQueryRequest.getK(),
createQueryRequest.getIndexName()
);
}

log.debug(
String.format(
"Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName(),
createQueryRequest.getK()
)
);
if (createQueryRequest.getFilter().isPresent()) {
final QueryShardContext queryShardContext = createQueryRequest.getContext()
.orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
try {
final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext);
return new KnnVectorQuery(
createQueryRequest.getFieldName(),
createQueryRequest.getVector(),
createQueryRequest.getK(),
filterQuery
);
} catch (IOException e) {
throw new RuntimeException("Cannot create knn query with filter", e);
}
}
return new KnnVectorQuery(createQueryRequest.getFieldName(), createQueryRequest.getVector(), createQueryRequest.getK());
}

/**
* DTO object to hold data required to create a Query instance.
*/
@AllArgsConstructor
@Builder
@Setter
static class CreateQueryRequest {
@Getter
@NonNull
private KNNEngine knnEngine;
@Getter
@NonNull
private String indexName;
@Getter
private String fieldName;
@Getter
private float[] vector;
@Getter
private int k;
// can be null in cases filter not passed with the knn query
private QueryBuilder filter;
// can be null in cases filter not passed with the knn query
private QueryShardContext context;

public Optional<QueryBuilder> getFilter() {
return Optional.ofNullable(filter);
}

log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnVectorQuery(fieldName, vector, k);
public Optional<QueryShardContext> getContext() {
return Optional.ofNullable(context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@

import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class KNNQueryFactoryTests extends KNNTestCase {
private final int testQueryDimension = 17;
private final float[] testQueryVector = new float[testQueryDimension];
Expand Down Expand Up @@ -42,4 +50,27 @@ public void testCreateLuceneDefaultQuery() {
assertTrue(query instanceof KnnVectorQuery);
}
}

public void testCreateLuceneQueryWithFilter() {
    // Only engines that do NOT create custom segment files take the Lucene query path,
    // so every other engine is skipped.
    for (KNNEngine engine : KNNEngine.values()) {
        if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(engine)) {
            continue;
        }

        // Shard context stub: any field lookup resolves to a mocked field type.
        MappedFieldType fieldTypeStub = mock(MappedFieldType.class);
        QueryShardContext shardContextStub = mock(QueryShardContext.class);
        when(shardContextStub.fieldMapper(any())).thenReturn(fieldTypeStub);

        QueryBuilder termFilter = new TermQueryBuilder("foo", "fooval");
        final KNNQueryFactory.CreateQueryRequest request = KNNQueryFactory.CreateQueryRequest.builder()
            .knnEngine(engine)
            .indexName(testIndexName)
            .fieldName(testFieldName)
            .vector(testQueryVector)
            .k(testK)
            .context(shardContextStub)
            .filter(termFilter)
            .build();

        Query created = KNNQueryFactory.create(request);
        assertTrue(created instanceof KnnVectorQuery);
    }
}
}

0 comments on commit 86cec30

Please sign in to comment.