From b04233b45535b5ec2314007dbb947aadd2a95413 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Mon, 2 Oct 2023 10:07:00 -0700 Subject: [PATCH] Add parent join support for lucene knn Call DiversifyingChildren[Byte|Float]KnnVectorQuery for nested field so that k number of parent document can be returned in search result Signed-off-by: Heemin Kim --- CHANGELOG.md | 1 + .../opensearch/knn/common/KNNConstants.java | 8 + .../knn/index/query/KNNQueryFactory.java | 55 +++-- .../opensearch/knn/index/NestedSearchIT.java | 202 ++++++++++++++++++ .../knn/index/query/KNNQueryFactoryTests.java | 35 +++ 5 files changed, 272 insertions(+), 29 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/NestedSearchIT.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b29bbba47..fa707d05e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x) ### Features +* Add parent join support for lucene knn [#1181](https://github.com/opensearch-project/k-NN/pull/1181) ### Enhancements * Added support for ignore_unmapped in KNN queries. 
[#1071](https://github.com/opensearch-project/k-NN/pull/1071) ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index d7835a9c25..85654efd09 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -15,6 +15,14 @@ public class KNNConstants { public static final String NAME = "name"; public static final String PARAMETERS = "parameters"; public static final String METHOD_HNSW = "hnsw"; + public static final String TYPE = "type"; + public static final String TYPE_NESTED = "nested"; + public static final String PATH = "path"; + public static final String QUERY = "query"; + public static final String KNN = "knn"; + public static final String VECTOR = "vector"; + public static final String K = "k"; + public static final String TYPE_KNN_VECTOR = "knn_vector"; public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search"; public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction"; public static final String METHOD_PARAMETER_M = "m"; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index b05098f284..742cd9bdbc 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -14,6 +14,9 @@ import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; @@ -87,9 
+90,9 @@ public static Query create(CreateQueryRequest createQueryRequest) { } if (VectorDataType.BYTE == vectorDataType) { - return getKnnByteVectorQuery(indexName, fieldName, byteVector, k, filterQuery); + return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, createQueryRequest.context.getParentFilter()); } else if (VectorDataType.FLOAT == vectorDataType) { - return getKnnFloatVectorQuery(indexName, fieldName, vector, k, filterQuery); + return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, createQueryRequest.context.getParentFilter()); } else { throw new IllegalArgumentException( String.format( @@ -102,38 +105,30 @@ public static Query create(CreateQueryRequest createQueryRequest) { } } - private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) { - if (filterQuery != null) { - log.debug( - String.format( - Locale.ROOT, - "Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", - indexName, - fieldName, - k - ) - ); + /** + * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery} + * which will dedupe search result per parent so that we can get k parent results at the end. 
+ */ + private static Query getKnnByteVectorQuery(final String fieldName, final byte[] byteVector, final int k, final Query filterQuery, final BitSetProducer parentFilter) { + if (parentFilter == null) { return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery); } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnByteVectorQuery(fieldName, byteVector, k); + else { + return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter); + } } - private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) { - if (filterQuery != null) { - log.debug( - String.format( - Locale.ROOT, - "Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", - indexName, - fieldName, - k - ) - ); - return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery); + /** + * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery} + * which will dedupe search result per parent so that we can get k parent results at the end. 
+ */ + private static Query getKnnFloatVectorQuery(final String fieldName, final float[] floatVector, final int k, final Query filterQuery, final BitSetProducer parentFilter) { + if (parentFilter == null) { + return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery); + } + else { + return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter); } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnFloatVectorQuery(fieldName, floatVector, k); } private static Query getFilterQuery(CreateQueryRequest createQueryRequest) { @@ -181,6 +176,8 @@ static class CreateQueryRequest { @Getter private int k; // can be null in cases filter not passed with the knn query + @Getter + public BitSetProducer parentFilter; private QueryBuilder filter; // can be null in cases filter not passed with the knn query private QueryShardContext context; diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java new file mode 100644 index 0000000000..5a9160ed40 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -0,0 +1,202 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static
org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.K; +import static org.opensearch.knn.common.KNNConstants.KNN; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.PATH; +import static org.opensearch.knn.common.KNNConstants.QUERY; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED; +import static org.opensearch.knn.common.KNNConstants.VECTOR; + +public class NestedSearchIT extends KNNRestTestCase { + private static final String INDEX_NAME = "test-index-nested-search"; + private static final String FIELD_NAME_NESTED = "test-nested"; + private static final String FIELD_NAME_VECTOR = "test-vector"; + private static final String PROPERTIES_FIELD = "properties"; + private static final int EF_CONSTRUCTION = 128; + private static final int M = 16; + private static final SpaceType SPACE_TYPE = SpaceType.L2; + + @After + @SneakyThrows + public final void cleanUp() { + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() { + createKnnIndex(2, KNNEngine.LUCENE.getName()); + + String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[]{1f, 1f}, new Float[]{1f, 1f}) + .build(); + addNestedKnnDoc(INDEX_NAME, "1", doc1); + + 
String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[]{2f, 2f}, new Float[]{2f, 2f}) + .build(); + addNestedKnnDoc(INDEX_NAME, "2", doc2); + + Float[] queryVector = { 1f, 1f }; + Response response = queryNestedField(INDEX_NAME, 2, queryVector); + + List hits = (List) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + EntityUtils.toString(response.getEntity()) + ).map().get("hits")).get("hits"); + assertEquals(2, hits.size()); + } + + /** + * { + * "properties": { + * "test-nested": { + * "type": "nested", + * "properties": { + * "test-vector": { + * "type": "knn_vector", + * "dimension": 3, + * "method": { + * "name": "hnsw", + * "space_type": "l2", + * "engine": "lucene", + * "parameters": { + * "ef_construction": 128, + * "m": 16 + * } + * } + * } + * } + * } + * } + * } + */ + private void createKnnIndex(final int dimension, final String engine) + throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_NESTED) + .field(TYPE, TYPE_NESTED) + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_VECTOR) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, SPACE_TYPE) + .field(KNN_ENGINE, engine) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, M) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + createKnnIndex(INDEX_NAME, mapping); + } + + @SneakyThrows + private void ingestTestData() { + String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .add(FIELD_NAME_VECTOR, new Float[]{1f, 1f}, new Float[]{1f, 1f}) + .build(); + addNestedKnnDoc(INDEX_NAME, "1", doc1); + + String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) +
.add(FIELD_NAME_VECTOR, new Float[]{2f, 2f}, new Float[]{2f, 2f}) + .build(); + addNestedKnnDoc(INDEX_NAME, "2", doc2); + } + + private void addNestedKnnDoc(final String index, final String docId, final String document) throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + + request.setJsonEntity(document); + client().performRequest(request); + + request = new Request("POST", "/" + index + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR); + builder.field(VECTOR, vector); + builder.field(K, k); + builder.endObject().endObject().endObject().endObject().endObject().endObject(); + + Request request = new Request("POST", "/" + index + "/_search"); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + return response; + } + + private static class NestedKnnDocBuilder { + private XContentBuilder builder; + public NestedKnnDocBuilder(final String fieldName) throws IOException { + builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName); + } + + public static NestedKnnDocBuilder create(final String fieldName) throws IOException { + return new NestedKnnDocBuilder(fieldName); + } + + public NestedKnnDocBuilder add(final String fieldName, final Object[]... 
vectors) throws IOException { + for (Object[] vector : vectors) { + builder.startObject(); + builder.field(fieldName, vector); + builder.endObject(); + } + return this; + } + + public String build() throws IOException { + builder.endArray().endObject(); + return builder.toString(); + } + + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 4dccfd0877..cd969f372c 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -9,12 +9,16 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.mockito.Mockito; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.util.Arrays; @@ -33,6 +37,7 @@ public class KNNQueryFactoryTests extends KNNTestCase { private static final Query FILTER_QUERY = new TermQuery(new Term(FILTER_FILED_NAME, FILTER_FILED_VALUE)); private final int testQueryDimension = 17; private final float[] testQueryVector = new float[testQueryDimension]; + private final byte[] testByteQueryVector = new byte[testQueryDimension]; private final String testIndexName = "test-index"; private final String testFieldName = "test-field"; private final int testK = 10; @@ -120,4 +125,34 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { 
assertEquals(testK, ((KNNQuery) query).getK()); assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery()); } + + public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() { + validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, DiversifyingChildrenByteKnnVectorQuery.class); + validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class); + } + private void validateDiversifyingQueryWithParentFilter(final VectorDataType type, final Class expectedQueryClass) { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .byteVector(testByteQueryVector) + .vectorDataType(type) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query.getClass().isAssignableFrom(expectedQueryClass)); + } + } }