Skip to content

Commit

Permalink
Add parent join support for lucene knn
Browse files Browse the repository at this point in the history
Call DiversifyingChildren[Byte|Float]KnnVectorQuery for nested fields so that k parent documents can be returned in the search result

Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Oct 2, 2023
1 parent 78aba55 commit b04233b
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x)
### Features
* Add parent join support for lucene knn [#1181](https://github.com/opensearch-project/k-NN/pull/1181)
### Enhancements
* Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071)
### Bug Fixes
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ public class KNNConstants {
public static final String NAME = "name";
public static final String PARAMETERS = "parameters";
public static final String METHOD_HNSW = "hnsw";
// Keys and type values for index mappings and search request bodies (e.g. a nested knn_vector field and a knn query on it).
public static final String TYPE = "type";
public static final String TYPE_NESTED = "nested";
public static final String PATH = "path";
public static final String QUERY = "query";
public static final String KNN = "knn";
public static final String VECTOR = "vector";
public static final String K = "k";
public static final String TYPE_KNN_VECTOR = "knn_vector";
public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search";
public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction";
public static final String METHOD_PARAMETER_M = "m";
Expand Down
55 changes: 26 additions & 29 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -87,9 +90,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
}

if (VectorDataType.BYTE == vectorDataType) {
return getKnnByteVectorQuery(indexName, fieldName, byteVector, k, filterQuery);
return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, createQueryRequest.context.getParentFilter());
} else if (VectorDataType.FLOAT == vectorDataType) {
return getKnnFloatVectorQuery(indexName, fieldName, vector, k, filterQuery);
return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, createQueryRequest.context.getParentFilter());
} else {
throw new IllegalArgumentException(
String.format(
Expand All @@ -102,38 +105,30 @@ public static Query create(CreateQueryRequest createQueryRequest) {
}
}

private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) {
if (filterQuery != null) {
log.debug(
String.format(
Locale.ROOT,
"Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d",
indexName,
fieldName,
k
)
);
/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnByteVectorQuery(final String fieldName, final byte[] byteVector, final int k, final Query filterQuery, final BitSetProducer parentFilter) {
if (parentFilter == null) {
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
}
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnByteVectorQuery(fieldName, byteVector, k);
else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
}
}

private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) {
if (filterQuery != null) {
log.debug(
String.format(
Locale.ROOT,
"Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d",
indexName,
fieldName,
k
)
);
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
/**
 * Builds the Lucene k-NN query for a float vector field.
 *
 * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery}
 * which will dedupe search result per parent so that we can get k parent results at the end. Otherwise a plain
 * {@link KnnFloatVectorQuery} is returned.
 *
 * @param fieldName    name of the vector field to search
 * @param floatVector  query vector
 * @param k            number of results to request
 * @param filterQuery  filter to apply to the search; may be null when no filter was passed with the query
 * @param parentFilter parent filter of the nested query, or null when the field is not nested
 * @return the Lucene query to execute
 */
private static Query getKnnFloatVectorQuery(
    final String fieldName,
    final float[] floatVector,
    final int k,
    final Query filterQuery,
    final BitSetProducer parentFilter
) {
    if (parentFilter == null) {
        // The filter must be forwarded here as well; otherwise it is silently ignored for non-nested fields.
        return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
    }
    // Note the argument order: the diversifying query takes the filter before k.
    return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
}

private static Query getFilterQuery(CreateQueryRequest createQueryRequest) {
Expand Down Expand Up @@ -181,6 +176,8 @@ static class CreateQueryRequest {
@Getter
private int k;
// can be null in cases filter not passed with the knn query
@Getter
public BitSetProducer parentFilter;
private QueryBuilder filter;
// can be null in cases filter not passed with the knn query
private QueryShardContext context;
Expand Down
202 changes: 202 additions & 0 deletions src/test/java/org/opensearch/knn/index/NestedSearchIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.After;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.K;
import static org.opensearch.knn.common.KNNConstants.KNN;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.PATH;
import static org.opensearch.knn.common.KNNConstants.QUERY;
import static org.opensearch.knn.common.KNNConstants.TYPE;
import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR;
import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED;
import static org.opensearch.knn.common.KNNConstants.VECTOR;

/**
 * Integration test for k-NN search on a {@code knn_vector} field that lives inside a nested field.
 */
public class NestedSearchIT extends KNNRestTestCase {
    private static final String INDEX_NAME = "test-index-nested-search";
    private static final String FIELD_NAME_NESTED = "test-nested";
    private static final String FIELD_NAME_VECTOR = "test-vector";
    private static final String PROPERTIES_FIELD = "properties";
    private static final int EF_CONSTRUCTION = 128;
    private static final int M = 16;
    private static final SpaceType SPACE_TYPE = SpaceType.L2;

    /** Deletes the test index after every test. */
    @After
    @SneakyThrows
    public final void cleanUp() {
        deleteKNNIndex(INDEX_NAME);
    }

    /**
     * Indexes two parent documents, each holding two nested vectors, and checks that a nested k-NN query with
     * k = 2 returns two hits.
     */
    @SneakyThrows
    @SuppressWarnings("unchecked") // the parsed response is an untyped map; the casts follow the search response layout
    public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() {
        createKnnIndex(2, KNNEngine.LUCENE.getName());
        ingestTestData();

        Float[] queryVector = { 1f, 1f };
        Response response = queryNestedField(INDEX_NAME, 2, queryVector);

        String responseBody = EntityUtils.toString(response.getEntity());
        Map<String, Object> hitsSection = (Map<String, Object>) createParser(
            MediaTypeRegistry.getDefaultMediaType().xContent(),
            responseBody
        ).map().get("hits");
        List<Object> hits = (List<Object>) hitsSection.get("hits");
        assertEquals(2, hits.size());
    }

    /**
     * Creates the test index with the mapping below; {@code dimension} and {@code engine} are taken from the
     * arguments, the remaining values from the constants of this class.
     *
     * {
     *   "properties": {
     *     "test-nested": {
     *       "type": "nested",
     *       "properties": {
     *         "test-vector": {
     *           "type": "knn_vector",
     *           "dimension": <dimension>,
     *           "method": {
     *             "name": "hnsw",
     *             "space_type": <SPACE_TYPE>,
     *             "engine": <engine>,
     *             "parameters": {
     *               "m": 16,
     *               "ef_construction": 128
     *             }
     *           }
     *         }
     *       }
     *     }
     *   }
     * }
     *
     * @param dimension dimension of the vector field
     * @param engine    name of the k-NN engine to use
     */
    private void createKnnIndex(final int dimension, final String engine) throws Exception {
        XContentBuilder builder = XContentFactory.jsonBuilder()
            .startObject()
            .startObject(PROPERTIES_FIELD)
            .startObject(FIELD_NAME_NESTED)
            .field(TYPE, TYPE_NESTED)
            .startObject(PROPERTIES_FIELD)
            .startObject(FIELD_NAME_VECTOR)
            .field(TYPE, TYPE_KNN_VECTOR)
            .field(DIMENSION, dimension)
            .startObject(KNN_METHOD)
            .field(NAME, METHOD_HNSW)
            .field(METHOD_PARAMETER_SPACE_TYPE, SPACE_TYPE)
            .field(KNN_ENGINE, engine)
            .startObject(PARAMETERS)
            .field(METHOD_PARAMETER_M, M)
            .field(METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION)
            .endObject()
            .endObject()
            .endObject()
            .endObject()
            .endObject()
            .endObject()
            .endObject();

        String mapping = builder.toString();
        createKnnIndex(INDEX_NAME, mapping);
    }

    /**
     * Indexes two parent documents ("1" and "2"). Each has two identical nested vectors, so a search that does not
     * dedupe per parent could return the same parent twice.
     */
    @SneakyThrows
    private void ingestTestData() {
        String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
            .add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f })
            .build();
        addNestedKnnDoc(INDEX_NAME, "1", doc1);

        String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
            .add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f })
            .build();
        addNestedKnnDoc(INDEX_NAME, "2", doc2);
    }

    /**
     * Indexes {@code document} under {@code docId} and then refreshes the index.
     *
     * @param index    index name
     * @param docId    document id
     * @param document JSON source of the document
     */
    private void addNestedKnnDoc(final String index, final String docId, final String document) throws IOException {
        Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");

        request.setJsonEntity(document);
        client().performRequest(request);

        request = new Request("POST", "/" + index + "/_refresh");
        Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Runs a nested query that wraps a knn query on the nested vector field and asserts the search returns 200.
     *
     * @param index  index name
     * @param k      value of the knn query's {@code k}
     * @param vector query vector
     * @return the search response
     */
    private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException {
        XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY);
        builder.startObject(TYPE_NESTED);
        builder.field(PATH, FIELD_NAME_NESTED);
        builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR);
        builder.field(VECTOR, vector);
        builder.field(K, k);
        builder.endObject().endObject().endObject().endObject().endObject().endObject();

        Request request = new Request("POST", "/" + index + "/_search");
        request.setJsonEntity(builder.toString());

        Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

        return response;
    }

    /**
     * Builds the JSON source of a document whose single field is an array of nested objects, each holding one vector.
     */
    private static class NestedKnnDocBuilder {
        private final XContentBuilder builder;

        private NestedKnnDocBuilder(final String fieldName) throws IOException {
            builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName);
        }

        /**
         * @param fieldName name of the nested (array) field
         * @return a new builder with the nested array opened
         */
        public static NestedKnnDocBuilder create(final String fieldName) throws IOException {
            return new NestedKnnDocBuilder(fieldName);
        }

        /**
         * Appends one nested object per vector, each of the form {@code { fieldName: vector }}.
         *
         * @param fieldName name of the vector field inside the nested object
         * @param vectors   vectors to add
         * @return this builder
         */
        public NestedKnnDocBuilder add(final String fieldName, final Object[]... vectors) throws IOException {
            for (Object[] vector : vectors) {
                builder.startObject();
                builder.field(fieldName, vector);
                builder.endObject();
            }
            return this;
        }

        /**
         * Closes the nested array and the document. Call it once, after all vectors have been added.
         *
         * @return the JSON source of the document
         */
        public String build() throws IOException {
            builder.endArray().endObject();
            return builder.toString();
        }

    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.mockito.Mockito;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
Expand All @@ -33,6 +37,7 @@ public class KNNQueryFactoryTests extends KNNTestCase {
private static final Query FILTER_QUERY = new TermQuery(new Term(FILTER_FILED_NAME, FILTER_FILED_VALUE));
private final int testQueryDimension = 17;
private final float[] testQueryVector = new float[testQueryDimension];
private final byte[] testByteQueryVector = new byte[testQueryDimension];
private final String testIndexName = "test-index";
private final String testFieldName = "test-field";
private final int testK = 10;
Expand Down Expand Up @@ -120,4 +125,34 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() {
assertEquals(testK, ((KNNQuery) query).getK());
assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery());
}

/**
 * Checks, for both byte and float vector data types, that the factory produces the matching
 * DiversifyingChildren*KnnVectorQuery when validated through validateDiversifyingQueryWithParentFilter.
 */
public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() {
    final Class<?> expectedByteQueryClass = DiversifyingChildrenByteKnnVectorQuery.class;
    final Class<?> expectedFloatQueryClass = DiversifyingChildrenFloatKnnVectorQuery.class;

    validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, expectedByteQueryClass);
    validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, expectedFloatQueryClass);
}
/**
 * For every engine that does not create custom segment files, builds a query through
 * {@link KNNQueryFactory#create} using a mocked shard context that returns a parent filter, and asserts that the
 * created query is an instance of {@code expectedQueryClass}.
 *
 * @param type               vector data type to put in the request
 * @param expectedQueryClass query class the factory is expected to produce
 */
private void validateDiversifyingQueryWithParentFilter(final VectorDataType type, final Class<?> expectedQueryClass) {
    final List<KNNEngine> luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values())
        .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine))
        .collect(Collectors.toList());
    for (KNNEngine knnEngine : luceneDefaultQueryEngineList) {
        final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
        final MappedFieldType testMapper = mock(MappedFieldType.class);
        when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
        final BitSetProducer parentFilter = mock(BitSetProducer.class);
        when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter);
        final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
            .knnEngine(knnEngine)
            .indexName(testIndexName)
            .fieldName(testFieldName)
            .vector(testQueryVector)
            .byteVector(testByteQueryVector)
            .vectorDataType(type)
            .k(testK)
            .context(mockQueryShardContext)
            .filter(FILTER_QUERY_BUILDER)
            .build();
        final Query query = KNNQueryFactory.create(createQueryRequest);
        // isInstance rejects a query whose class is merely a superclass of the expected type, which
        // query.getClass().isAssignableFrom(expectedQueryClass) would accept.
        assertTrue(
            "Expected " + expectedQueryClass.getSimpleName() + " but got " + query.getClass().getSimpleName(),
            expectedQueryClass.isInstance(query)
        );
    }
}
}

0 comments on commit b04233b

Please sign in to comment.