From 4c8cf9350959edf9b8adab252bd3f1d3f96e8ca9 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 22 Aug 2022 10:45:01 -0700 Subject: [PATCH 1/8] Adding efficient filtering (#515) * Add initial support for filtering Signed-off-by: Martin Gaievski --- .../knn/index/query/KNNQueryBuilder.java | 30 ++++- .../knn/index/query/KNNQueryFactory.java | 103 +++++++++++++++++- .../knn/index/query/KNNQueryFactoryTests.java | 31 ++++++ 3 files changed, 157 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 1defe45e8..59b508d8f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -7,6 +7,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; @@ -38,6 +39,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField VECTOR_FIELD = new ParseField("vector"); public static final ParseField K_FIELD = new ParseField("k"); + public static final ParseField FILTER_FIELD = new ParseField("filter"); public static int K_MAX = 10000; /** * The name for the knn query @@ -49,6 +51,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; + private QueryBuilder filter; /** * Constructs a new knn query @@ -58,6 +61,10 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { * @param k K nearest neighbours for the given vector */ public KNNQueryBuilder(String fieldName, float[] vector, int k) { + this(fieldName, vector, k, null); + } + + public KNNQueryBuilder(String fieldName, float[] vector, 
int k, QueryBuilder filter) { if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); } @@ -77,6 +84,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) { this.fieldName = fieldName; this.vector = vector; this.k = k; + this.filter = filter; } public static void initialize(ModelDao modelDao) { @@ -111,6 +119,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep List vector = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; int k = 0; + QueryBuilder filter = null; String queryName = null; String currentFieldName = null; XContentParser.Token token; @@ -139,6 +148,14 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep "[" + NAME + "] query does not support [" + currentFieldName + "]" ); } + } else if (token == XContentParser.Token.START_OBJECT) { + String tokenName = parser.currentName(); + if (FILTER_FIELD.getPreferredName().equals(tokenName)) { + filter = parseInnerQueryBuilder(parser); + } else { + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); + } + } else { throw new ParsingException( parser.getTokenLocation(), @@ -153,7 +170,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter); knnQueryBuilder.queryName(queryName); knnQueryBuilder.boost(boost); return knnQueryBuilder; @@ -226,7 +243,16 @@ protected Query doToQuery(QueryShardContext context) { } String indexName = context.index().getName(); - return KNNQueryFactory.create(knnEngine, indexName, this.fieldName, this.vector, this.k); + KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + 
.fieldName(this.fieldName) + .vector(this.vector) + .k(this.k) + .filter(this.filter) + .context(context) + .build(); + return KNNQueryFactory.create(createQueryRequest); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index cbdb03ea8..10c8edd1a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -5,11 +5,21 @@ package org.opensearch.knn.index.query; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.util.KNNEngine; +import java.io.IOException; +import java.util.Optional; + /** * Creates the Lucene k-NN queries */ @@ -27,14 +37,97 @@ public class KNNQueryFactory { * @return Lucene Query */ public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) { + final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(fieldName) + .vector(vector) + .k(k) + .build(); + return create(createQueryRequest); + } + + /** + * Creates a Lucene query for a particular engine. + * @param createQueryRequest request object that has all required fields to construct the query + * @return Lucene Query + */ + public static Query create(CreateQueryRequest createQueryRequest) { // Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. 
They need to // use the custom query type created by the plugin - if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) { - log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KNNQuery(fieldName, vector, k, indexName); + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + log.debug( + String.format( + "Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", + createQueryRequest.getIndexName(), + createQueryRequest.getFieldName(), + createQueryRequest.getK() + ) + ); + return new KNNQuery( + createQueryRequest.getFieldName(), + createQueryRequest.getVector(), + createQueryRequest.getK(), + createQueryRequest.getIndexName() + ); + } + + log.debug( + String.format( + "Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", + createQueryRequest.getIndexName(), + createQueryRequest.getFieldName(), + createQueryRequest.getK() + ) + ); + if (createQueryRequest.getFilter().isPresent()) { + final QueryShardContext queryShardContext = createQueryRequest.getContext() + .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + try { + final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); + return new KnnVectorQuery( + createQueryRequest.getFieldName(), + createQueryRequest.getVector(), + createQueryRequest.getK(), + filterQuery + ); + } catch (IOException e) { + throw new RuntimeException("Cannot create knn query with filter", e); + } + } + return new KnnVectorQuery(createQueryRequest.getFieldName(), createQueryRequest.getVector(), createQueryRequest.getK()); + } + + /** + * DTO object to hold data required to create a Query instance. 
+ */ + @AllArgsConstructor + @Builder + @Setter + static class CreateQueryRequest { + @Getter + @NonNull + private KNNEngine knnEngine; + @Getter + @NonNull + private String indexName; + @Getter + private String fieldName; + @Getter + private float[] vector; + @Getter + private int k; + // can be null in cases filter not passed with the knn query + private QueryBuilder filter; + // can be null in cases filter not passed with the knn query + private QueryShardContext context; + + public Optional getFilter() { + return Optional.ofNullable(filter); } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnVectorQuery(fieldName, vector, k); + public Optional getContext() { + return Optional.ofNullable(context); + } } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 06b0ce6ca..908ea1021 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -7,6 +7,10 @@ import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.util.KNNEngine; @@ -14,6 +18,10 @@ import java.util.List; import java.util.stream.Collectors; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class KNNQueryFactoryTests extends KNNTestCase { private final int testQueryDimension = 17; private final float[] testQueryVector = new float[testQueryDimension]; @@ -42,4 +50,27 @@ public void testCreateLuceneDefaultQuery() { 
assertTrue(query instanceof KnnVectorQuery); } } + + public void testCreateLuceneQueryWithFilter() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + QueryBuilder filter = new TermQueryBuilder("foo", "fooval"); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(filter) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof KnnVectorQuery); + } + } } From 47b9ad4d52d4c772d5719ba22dc3d7530a6381aa Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 7 Sep 2022 16:57:28 -0700 Subject: [PATCH 2/8] Adding more tests and logs (#538) Signed-off-by: Martin Gaievski --- .../knn/index/query/KNNQueryBuilder.java | 13 ++- .../knn/index/query/KNNQueryFactory.java | 42 +++---- .../opensearch/knn/index/LuceneEngineIT.java | 108 ++++++++++++++++++ .../knn/index/query/KNNQueryBuilderTests.java | 61 ++++++++++ 4 files changed, 194 insertions(+), 30 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 59b508d8f..aeefdbff4 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -151,11 +151,11 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } else if (token == 
XContentParser.Token.START_OBJECT) { String tokenName = parser.currentName(); if (FILTER_FIELD.getPreferredName().equals(tokenName)) { + log.debug(String.format("Start parsing filter for field [%s]", fieldName)); filter = parseInnerQueryBuilder(parser); } else { throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); } - } else { throw new ParsingException( parser.getTokenLocation(), @@ -201,6 +201,10 @@ public int getK() { return this.k; } + public QueryBuilder getFilter() { + return this.filter; + } + @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); @@ -208,6 +212,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio builder.field(VECTOR_FIELD.getPreferredName(), vector); builder.field(K_FIELD.getPreferredName(), k); + if (filter != null) { + builder.field(FILTER_FIELD.getPreferredName(), filter); + } printBoostAndQueryName(builder); builder.endObject(); builder.endObject(); @@ -242,6 +249,10 @@ protected Query doToQuery(QueryShardContext context) { ); } + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null) { + throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine)); + } + String indexName = context.index().getName(); KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 10c8edd1a..c68ce9502 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -55,47 +55,31 @@ public static Query create(KNNEngine knnEngine, String indexName, String fieldNa public static Query create(CreateQueryRequest createQueryRequest) { // 
Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. They need to // use the custom query type created by the plugin + final String indexName = createQueryRequest.getIndexName(); + final String fieldName = createQueryRequest.getFieldName(); + final int k = createQueryRequest.getK(); + final float[] vector = createQueryRequest.getVector(); + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { - log.debug( - String.format( - "Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", - createQueryRequest.getIndexName(), - createQueryRequest.getFieldName(), - createQueryRequest.getK() - ) - ); - return new KNNQuery( - createQueryRequest.getFieldName(), - createQueryRequest.getVector(), - createQueryRequest.getK(), - createQueryRequest.getIndexName() - ); + log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + return new KNNQuery(fieldName, vector, k, indexName); } - log.debug( - String.format( - "Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", - createQueryRequest.getIndexName(), - createQueryRequest.getFieldName(), - createQueryRequest.getK() - ) - ); if (createQueryRequest.getFilter().isPresent()) { final QueryShardContext queryShardContext = createQueryRequest.getContext() .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + log.debug( + String.format("Creating Lucene k-NN query with filter for index [%s], field [%s] and k [%d]", indexName, fieldName, k) + ); try { final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); - return new KnnVectorQuery( - createQueryRequest.getFieldName(), - createQueryRequest.getVector(), - createQueryRequest.getK(), - filterQuery - ); + return new KnnVectorQuery(fieldName, vector, k, filterQuery); } catch (IOException e) { throw new RuntimeException("Cannot create knn query with filter", 
e); } } - return new KnnVectorQuery(createQueryRequest.getFieldName(), createQueryRequest.getVector(), createQueryRequest.getK()); + log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + return new KnnVectorQuery(fieldName, vector, k); } /** diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index fb26b893b..d0e3bae20 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -17,12 +17,14 @@ import org.opensearch.common.Strings; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.rest.RestStatus; import java.io.IOException; import java.util.Arrays; @@ -33,14 +35,19 @@ import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; public class LuceneEngineIT extends KNNRestTestCase { private static final int DIMENSION = 3; private static final String DOC_ID = "doc1"; + private static final String DOC_ID_2 = "doc2"; + private static final String DOC_ID_3 = "doc3"; private static final int EF_CONSTRUCTION = 128; private static final String INDEX_NAME = "test-index-1"; private static final String FIELD_NAME = "test-field-1"; + private static final String COLOR_FIELD_NAME = "color"; + private static final String TASTE_FIELD_NAME = "taste"; private static final int M = 16; private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f 
}, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; @@ -246,6 +253,107 @@ public void testDeleteDoc() throws Exception { assertEquals(0, getDocCount(INDEX_NAME)); } + public void testQueryWithFilter() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + + addKnnDocWithAttributes( + DOC_ID, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + refreshAllIndices(); + + final float[] searchVector = { 6.0f, 6.0f, 4.1f }; + int kGreaterThanFilterResult = 5; + List expectedDocIds = Arrays.asList(DOC_ID, DOC_ID_3); + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIds.size(), knnResults.size()); + assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds)); + + int kLimitsFilterResult = 1; + List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), 
knnResultsKLimitsFilterResult.size()); + assertTrue( + knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + } + + public void testQuery_filterWithNonLuceneEngine() throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, NMSLIB_NAME) + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + createKnnIndex(INDEX_NAME, mapping); + + addKnnDocWithAttributes( + DOC_ID, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + final float[] searchVector = { 6.0f, 6.0f, 5.6f }; + int k = 5; + expectThrows( + ResponseException.class, + () -> searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, k, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + k + ) + ); + } + + private void addKnnDocWithAttributes(String docId, float[] vector, Map fieldValues) throws IOException { + Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true"); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector); + for (String fieldName : fieldValues.keySet()) { + builder.field(fieldName, fieldValues.get(fieldName)); + } + builder.endObject(); + request.setJsonEntity(Strings.toString(builder)); 
+ client().performRequest(request); + + request = new Request("POST", "/" + INDEX_NAME + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index c3d40cbc7..4ebcf9ec4 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -5,6 +5,14 @@ package org.opensearch.knn.index.query; +import com.google.common.collect.ImmutableMap; +import org.apache.lucene.search.KnnVectorQuery; +import org.apache.lucene.search.Query; +import org.opensearch.cluster.ClusterModule; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; @@ -12,16 +20,22 @@ import org.opensearch.index.Index; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.plugins.SearchPlugin; import 
java.io.IOException; +import java.util.List; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; public class KNNQueryBuilderTests extends KNNTestCase { @@ -74,6 +88,36 @@ public void testFromXcontent() throws Exception { actualBuilder.equals(knnQueryBuilder); } + public void testFromXcontent_WithFilter() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value")); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + actualBuilder.equals(knnQueryBuilder); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List list = ClusterModule.getNamedXWriteables(); + SearchPlugin.QuerySpec spec = new SearchPlugin.QuerySpec<>( + TermQueryBuilder.NAME, + TermQueryBuilder::new, + TermQueryBuilder::fromXContent + ); + list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p))); + NamedXContentRegistry registry = new NamedXContentRegistry(list); + return registry; + } + public void testDoToQuery_Normal() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); @@ -89,6 
+133,23 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value")); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query instanceof KnnVectorQuery); + } + public void testDoToQuery_FromModel() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); From 2e18ae84eefd6feeb7891ab6b8171492bd6a448a Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 5 Oct 2022 09:57:25 -0700 Subject: [PATCH 3/8] Adding serialization for filter field in KnnQueryBuilder (#564) * Adding serialization/deserialization for filter field in Lucene knn query Signed-off-by: Martin Gaievski --- .../opensearch/knn/bwc/LuceneFilteringIT.java | 86 +++++++++++++ .../knn/index/KNNClusterContext.java | 69 ++++++++++ .../knn/index/query/KNNQueryBuilder.java | 33 ++++- .../org/opensearch/knn/plugin/KNNPlugin.java | 2 + 
.../knn/index/KNNClusterContextTests.java | 53 ++++++++ .../knn/index/KNNClusterTestUtils.java | 48 +++++++ .../knn/index/query/KNNQueryBuilderTests.java | 118 +++++++++++++++--- 7 files changed, 393 insertions(+), 16 deletions(-) create mode 100644 qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java create mode 100644 src/main/java/org/opensearch/knn/index/KNNClusterContext.java create mode 100644 src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java create mode 100644 src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java new file mode 100644 index 000000000..3ea611cbf --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.bwc; + +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; + +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; + +import java.io.IOException; + +import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; + +/** + * Tests scenarios specific to filtering functionality in k-NN in case Lucene is set as an engine + */ +public class LuceneFilteringIT extends AbstractRollingUpgradeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 50; + private static final int K = 10; + private static final int NUM_DOCS = 100; + private 
static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("_id", "100"); + + public void testLuceneFiltering() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0]; + switch (getClusterType()) { + case OLD: + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMappingWithLuceneField(TEST_FIELD, DIMENSIONS)); + bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS); + validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + break; + case MIXED: + validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + break; + case UPGRADED: + searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + deleteKNNIndex(testIndex); + break; + } + } + + protected String createKnnIndexMappingWithLuceneField(final String fieldName, int dimension) throws IOException { + return Strings.toString( + XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", Integer.toString(dimension)) + .startObject("method") + .field("name", "hnsw") + .field("engine", "lucene") + .field("space_type", "l2") + .endObject() + .endObject() + .endObject() + .endObject() + ); + } + + private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); + knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject().endObject(); + + Request request = new Request("POST", "/" + index + "/_search"); + + request.addParameter("size", Integer.toString(resultSize)); + request.addParameter("explain", Boolean.toString(true)); + 
request.addParameter("search_type", "query_then_fetch"); + request.setJsonEntity(Strings.toString(builder)); + + expectThrows(ResponseException.class, () -> client().performRequest(request)); + } +} diff --git a/src/main/java/org/opensearch/knn/index/KNNClusterContext.java b/src/main/java/org/opensearch/knn/index/KNNClusterContext.java new file mode 100644 index 000000000..a98cc8bea --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/KNNClusterContext.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import com.carrotsearch.hppc.cursors.ObjectCursor; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.ImmutableOpenMap; + +/** + * Class abstracts information related to underlying OpenSearch cluster + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +@Log4j2 +public class KNNClusterContext { + + private ClusterService clusterService; + private static KNNClusterContext instance; + + /** + * Return instance of the cluster context, must be initialized first for proper usage + * @return instance of cluster context + */ + public static synchronized KNNClusterContext instance() { + if (instance == null) { + instance = new KNNClusterContext(); + } + return instance; + } + + /** + * Initializes instance of cluster context by injecting dependencies + * @param clusterService + */ + public void initialize(final ClusterService clusterService) { + this.clusterService = clusterService; + } + + /** + * Return minimal OpenSearch version based on all nodes currently discoverable in the cluster + * @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version + */ + public Version getClusterMinVersion() { + 
Version minVersion = Version.CURRENT; + ImmutableOpenMap clusterDiscoveryNodes = ImmutableOpenMap.of(); + log.debug("Reading cluster min version"); + try { + clusterDiscoveryNodes = this.clusterService.state().getNodes().getNodes(); + } catch (Exception exception) { + log.error("Cannot get cluster nodes", exception); + } + for (final ObjectCursor discoveryNodeCursor : clusterDiscoveryNodes.values()) { + final Version nodeVersion = discoveryNodeCursor.value.getVersion(); + if (nodeVersion.before(minVersion)) { + minVersion = nodeVersion; + log.debug("Update cluster min version to {} based on node {}", nodeVersion, discoveryNodeCursor.value.toString()); + } + } + log.debug("Return cluster min version {}", minVersion); + return minVersion; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index aeefdbff4..b94b40400 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -6,8 +6,10 @@ package org.opensearch.knn.index.query; import lombok.extern.log4j.Log4j2; +import org.opensearch.Version; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.index.KNNClusterContext; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; @@ -52,6 +54,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final float[] vector; private int k = 0; private QueryBuilder filter; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_3_0_0; /** * Constructs a new knn query @@ -109,8 +112,11 @@ public KNNQueryBuilder(StreamInput in) throws IOException { fieldName = in.readString(); vector = in.readFloatArray(); k = in.readInt(); + if 
(isClusterOnOrAfterMinRequiredVersion()) { + filter = in.readOptionalNamedWriteable(QueryBuilder.class); + } } catch (IOException ex) { - throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder: " + ex); + throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } } @@ -152,7 +158,23 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String tokenName = parser.currentName(); if (FILTER_FIELD.getPreferredName().equals(tokenName)) { log.debug(String.format("Start parsing filter for field [%s]", fieldName)); - filter = parseInnerQueryBuilder(parser); + if (isClusterOnOrAfterMinRequiredVersion()) { + filter = parseInnerQueryBuilder(parser); + } else { + log.debug( + String.format( + "This version of k-NN doesn't support [filter] field, minimal required version is [%s]", + MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER + ) + ); + throw new IllegalArgumentException( + String.format( + "%s field is supported from version %s", + FILTER_FIELD.getPreferredName(), + MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER + ) + ); + } } else { throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); } @@ -181,6 +203,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeFloatArray(vector); out.writeInt(k); + if (isClusterOnOrAfterMinRequiredVersion()) { + out.writeOptionalNamedWriteable(filter); + } } /** @@ -294,4 +319,8 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static boolean isClusterOnOrAfterMinRequiredVersion() { + return KNNClusterContext.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index c2564f179..c72198c7d 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ 
b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -11,6 +11,7 @@ import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.knn.index.KNNCircuitBreaker; +import org.opensearch.knn.index.KNNClusterContext; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -179,6 +180,7 @@ public Collection createComponents( NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); KNNSettings.state().initialize(client, clusterService); + KNNClusterContext.instance().initialize(clusterService); ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance()); diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java b/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java new file mode 100644 index 000000000..55e6bbde2 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import org.opensearch.Version; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.knn.KNNTestCase; + +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; + +public class KNNClusterContextTests extends KNNTestCase { + + public void testSingleNodeCluster() { + ClusterService clusterService = mockClusterService(List.of(Version.V_2_4_0)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + 
knnClusterContext.initialize(clusterService); + + final Version minVersion = knnClusterContext.getClusterMinVersion(); + + assertTrue(Version.V_2_4_0.equals(minVersion)); + } + + public void testMultipleNodesCluster() { + ClusterService clusterService = mockClusterService(List.of(Version.V_3_0_0, Version.V_2_3_0, Version.V_3_0_0)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + + final Version minVersion = knnClusterContext.getClusterMinVersion(); + + assertTrue(Version.V_2_3_0.equals(minVersion)); + } + + public void testWhenErrorOnClusterStateDiscover() { + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready")); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + + final Version minVersion = knnClusterContext.getClusterMinVersion(); + + assertTrue(Version.CURRENT.equals(minVersion)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java new file mode 100644 index 000000000..f58584898 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.ImmutableOpenMap; + +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; + +/** + * Collection of util methods required 
for testing and related to OpenSearch cluster setup and functionality + */ +public class KNNClusterTestUtils { + + /** + * Create new mock for ClusterService + * @param versions list of versions for cluster nodes + * @return + */ + public static ClusterService mockClusterService(final List versions) { + ClusterService clusterService = mock(ClusterService.class); + ClusterState clusterState = mock(ClusterState.class); + when(clusterService.state()).thenReturn(clusterState); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); + for (Version version : versions) { + DiscoveryNode clusterNode = mock(DiscoveryNode.class); + when(clusterNode.getVersion()).thenReturn(version); + builder.put(randomAlphaOfLength(10), clusterNode); + } + ImmutableOpenMap mapOfNodes = builder.build(); + when(discoveryNodes.getNodes()).thenReturn(mapOfNodes); + + return clusterService; + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 4ebcf9ec4..435987f7e 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -8,7 +8,13 @@ import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.index.query.QueryBuilder; 
import org.opensearch.index.query.QueryBuilders; @@ -20,6 +26,7 @@ import org.opensearch.index.Index; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.KNNClusterContext; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -31,31 +38,38 @@ import java.io.IOException; import java.util.List; +import java.util.Optional; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; public class KNNQueryBuilderTests extends KNNTestCase { + private static final String FIELD_NAME = "myvector"; + private static final int K = 1; + private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); + private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + public void testInvalidK() { float[] queryVector = { 1.0f, 1.0f }; /** * -ve k */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, -1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, -K)); /** * zero k */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, 0)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, 0)); /** * k > KNNQueryBuilder.K_MAX */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, KNNQueryBuilder.K_MAX + 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, KNNQueryBuilder.K_MAX + K)); } public void testEmptyVector() { @@ -63,18 +77,18 @@ public void testEmptyVector() { * 
null query vector */ float[] queryVector = null; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, K)); /** * empty query vector */ float[] queryVector1 = {}; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector1, 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K)); } public void testFromXcontent() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -89,8 +103,13 @@ public void testFromXcontent() throws Exception { } public void testFromXcontent_WithFilter() throws Exception { + final ClusterService clusterService = mockClusterService(List.of(Version.CURRENT)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value")); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -105,6 +124,28 @@ public void testFromXcontent_WithFilter() throws Exception { actualBuilder.equals(knnQueryBuilder); } + public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Exception { + final ClusterService clusterService = mockClusterService(List.of(Version.V_2_3_0)); + + final KNNClusterContext 
knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + final XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + final XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + + expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParser)); + } + @Override protected NamedXContentRegistry xContentRegistry() { List list = ClusterModule.getNamedXWriteables(); @@ -118,9 +159,17 @@ protected NamedXContentRegistry xContentRegistry() { return registry; } + @Override + protected NamedWriteableRegistry writableRegistry() { + final List entries = ClusterModule.getNamedWriteables(); + entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, KNNQueryBuilder.NAME, KNNQueryBuilder::new)); + entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + return new NamedWriteableRegistry(entries); + } + public void testDoToQuery_Normal() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = 
mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -135,7 +184,7 @@ public void testDoToQuery_Normal() throws Exception { public void testDoToQuery_KnnQueryWithFilter() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value")); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -152,14 +201,14 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { public void testDoToQuery_FromModel() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); // Dimension is -1. 
In this case, model metadata will need to provide dimension - when(mockKNNVectorField.getDimension()).thenReturn(-1); + when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; when(mockKNNVectorField.getModelId()).thenReturn(modelId); @@ -181,7 +230,7 @@ public void testDoToQuery_FromModel() { public void testDoToQuery_InvalidDimensions() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -189,13 +238,13 @@ public void testDoToQuery_InvalidDimensions() { when(mockKNNVectorField.getDimension()).thenReturn(400); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); - when(mockKNNVectorField.getDimension()).thenReturn(1); + when(mockKNNVectorField.getDimension()).thenReturn(K); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } public void testDoToQuery_InvalidFieldType() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("mynumber", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("mynumber", queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); NumberFieldMapper.NumberFieldType mockNumberField = mock(NumberFieldMapper.NumberFieldType.class); @@ -203,4 +252,45 @@ public void 
testDoToQuery_InvalidFieldType() throws IOException { when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockNumberField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + + public void testSerialization() throws Exception { + assertSerialization(Version.CURRENT, Optional.empty()); + + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY)); + + assertSerialization(Version.V_2_3_0, Optional.empty()); + } + + private void assertSerialization(final Version version, final Optional queryBuilderOptional) throws Exception { + final KNNQueryBuilder knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + + final ClusterService clusterService = mockClusterService(List.of(version)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(version); + output.writeNamedWriteable(knnQueryBuilder); + + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { + in.setVersion(Version.CURRENT); + final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); + + assertNotNull(deserializedQuery); + assertTrue(deserializedQuery instanceof KNNQueryBuilder); + final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery; + assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName()); + assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); + assertEquals(K, deserializedKnnQueryBuilder.getK()); + if (queryBuilderOptional.isPresent()) { + assertNotNull(deserializedKnnQueryBuilder.getFilter()); + assertEquals(queryBuilderOptional.get(), deserializedKnnQueryBuilder.getFilter()); + } else { + 
assertNull(deserializedKnnQueryBuilder.getFilter()); + } + } + } + } } From 9b32e17f6b673e3435978f1e3f1289ce1beaa947 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 14 Oct 2022 14:10:02 -0700 Subject: [PATCH 4/8] Read min cluster version directly from DiscoveryNodes (#581) * Simplify min cluster version lookup Signed-off-by: Martin Gaievski --- .../knn/index/KNNClusterContext.java | 23 +++++-------------- .../knn/index/KNNClusterContextTests.java | 6 ++--- .../knn/index/KNNClusterTestUtils.java | 19 +++------------ .../knn/index/query/KNNQueryBuilderTests.java | 6 ++--- 4 files changed, 14 insertions(+), 40 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNClusterContext.java b/src/main/java/org/opensearch/knn/index/KNNClusterContext.java index a98cc8bea..938ed8302 100644 --- a/src/main/java/org/opensearch/knn/index/KNNClusterContext.java +++ b/src/main/java/org/opensearch/knn/index/KNNClusterContext.java @@ -5,14 +5,11 @@ package org.opensearch.knn.index; -import com.carrotsearch.hppc.cursors.ObjectCursor; import lombok.AccessLevel; import lombok.NoArgsConstructor; import lombok.extern.log4j.Log4j2; import org.opensearch.Version; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.collect.ImmutableOpenMap; /** * Class abstracts information related to underlying OpenSearch cluster @@ -48,22 +45,14 @@ public void initialize(final ClusterService clusterService) { * @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version */ public Version getClusterMinVersion() { - Version minVersion = Version.CURRENT; - ImmutableOpenMap clusterDiscoveryNodes = ImmutableOpenMap.of(); - log.debug("Reading cluster min version"); try { - clusterDiscoveryNodes = this.clusterService.state().getNodes().getNodes(); + return this.clusterService.state().getNodes().getMinNodeVersion(); } catch (Exception exception) { - 
log.error("Cannot get cluster nodes", exception); + log.error( + String.format("Failed to get cluster minimum node version, returning current node version %s instead.", Version.CURRENT), + exception + ); + return Version.CURRENT; } - for (final ObjectCursor discoveryNodeCursor : clusterDiscoveryNodes.values()) { - final Version nodeVersion = discoveryNodeCursor.value.getVersion(); - if (nodeVersion.before(minVersion)) { - minVersion = nodeVersion; - log.debug("Update cluster min version to {} based on node {}", nodeVersion, discoveryNodeCursor.value.toString()); - } - } - log.debug("Return cluster min version {}", minVersion); - return minVersion; } } diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java b/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java index 55e6bbde2..5c8ed970e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java @@ -9,8 +9,6 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; -import java.util.List; - import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; @@ -18,7 +16,7 @@ public class KNNClusterContextTests extends KNNTestCase { public void testSingleNodeCluster() { - ClusterService clusterService = mockClusterService(List.of(Version.V_2_4_0)); + ClusterService clusterService = mockClusterService(Version.V_2_4_0); final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); knnClusterContext.initialize(clusterService); @@ -29,7 +27,7 @@ public void testSingleNodeCluster() { } public void testMultipleNodesCluster() { - ClusterService clusterService = mockClusterService(List.of(Version.V_3_0_0, Version.V_2_3_0, Version.V_3_0_0)); + ClusterService clusterService = mockClusterService(Version.V_2_3_0); final KNNClusterContext knnClusterContext = 
KNNClusterContext.instance(); knnClusterContext.initialize(clusterService); diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java index f58584898..6ded05d17 100644 --- a/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java +++ b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java @@ -7,16 +7,11 @@ import org.opensearch.Version; import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.collect.ImmutableOpenMap; - -import java.util.List; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; /** * Collection of util methods required for testing and related to OpenSearch cluster setup and functionality @@ -25,24 +20,16 @@ public class KNNClusterTestUtils { /** * Create new mock for ClusterService - * @param versions list of versions for cluster nodes + * @param version min version for cluster nodes * @return */ - public static ClusterService mockClusterService(final List versions) { + public static ClusterService mockClusterService(final Version version) { ClusterService clusterService = mock(ClusterService.class); ClusterState clusterState = mock(ClusterState.class); when(clusterService.state()).thenReturn(clusterState); DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); when(clusterState.getNodes()).thenReturn(discoveryNodes); - ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); - for (Version version : versions) { - DiscoveryNode clusterNode = mock(DiscoveryNode.class); - when(clusterNode.getVersion()).thenReturn(version); - builder.put(randomAlphaOfLength(10), clusterNode); - } - ImmutableOpenMap mapOfNodes = builder.build(); - 
when(discoveryNodes.getNodes()).thenReturn(mapOfNodes); - + when(discoveryNodes.getMinNodeVersion()).thenReturn(version); return clusterService; } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 435987f7e..3be622da0 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -103,7 +103,7 @@ public void testFromXcontent() throws Exception { } public void testFromXcontent_WithFilter() throws Exception { - final ClusterService clusterService = mockClusterService(List.of(Version.CURRENT)); + final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); knnClusterContext.initialize(clusterService); @@ -125,7 +125,7 @@ public void testFromXcontent_WithFilter() throws Exception { } public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Exception { - final ClusterService clusterService = mockClusterService(List.of(Version.V_2_3_0)); + final ClusterService clusterService = mockClusterService(Version.V_2_3_0); final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); knnClusterContext.initialize(clusterService); @@ -266,7 +266,7 @@ private void assertSerialization(final Version version, final Optional Date: Wed, 19 Oct 2022 09:34:13 -0700 Subject: [PATCH 5/8] Refactor kNN codec related classes (#582) * Refactor codec related classes, create KNNCodecVersion abstraction Signed-off-by: Martin Gaievski --- .../codec/BasePerFieldKnnVectorsFormat.java | 79 ++++++++++++++++ .../index/codec/KNN910Codec/KNN910Codec.java | 13 +-- .../index/codec/KNN920Codec/KNN920Codec.java | 16 +--- .../KNN920PerFieldKnnVectorsFormat.java | 70 ++------------ .../index/codec/KNN940Codec/KNN940Codec.java | 14 +-- .../KNN940PerFieldKnnVectorsFormat.java | 66 
++------------ .../knn/index/codec/KNNCodecFactory.java | 51 ----------- .../knn/index/codec/KNNCodecService.java | 7 +- .../knn/index/codec/KNNCodecVersion.java | 91 +++++++++++++++++++ .../knn/index/codec/KNNFormatFactory.java | 53 ----------- .../codec/KNN920Codec/KNN920CodecTests.java | 6 +- .../codec/KNN940Codec/KNN940CodecTests.java | 8 +- .../knn/index/codec/KNNCodecFactoryTests.java | 43 ++++----- .../knn/index/codec/KNNCodecTestCase.java | 3 +- .../index/codec/KNNFormatFactoryTests.java | 51 ----------- 15 files changed, 237 insertions(+), 334 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java delete mode 100644 src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java new file mode 100644 index 000000000..d10ad9821 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; + +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +/** + * Base class for 
PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version + */ +@AllArgsConstructor +@Log4j2 +public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat { + + private final Optional mapperService; + private final int defaultMaxConnections; + private final int defaultBeamWidth; + private final Supplier defaultFormatSupplier; + private final BiFunction formatSupplier; + + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { + if (isKnnVectorFieldType(field) == false) { + log.debug( + "Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"", + field, + defaultMaxConnections, + defaultBeamWidth + ); + return defaultFormatSupplier.get(); + } + var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow( + () -> new IllegalStateException( + String.format("Cannot read field type for field [%s] because mapper service is not available", field) + ) + ).fieldType(field); + var params = type.getKnnMethodContext().getMethodComponent().getParameters(); + int maxConnections = getMaxConnections(params); + int beamWidth = getBeamWidth(params); + log.debug( + "Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"", + field, + maxConnections, + beamWidth + ); + return formatSupplier.apply(maxConnections, beamWidth); + } + + private boolean isKnnVectorFieldType(final String field) { + return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType; + } + + private int getMaxConnections(final Map params) { + if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) { + return (int) params.get(KNNConstants.METHOD_PARAMETER_M); + } + return defaultMaxConnections; + } + + private int getBeamWidth(final Map params) { + if (params != null && 
params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { + return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); + } + return defaultBeamWidth; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java index 0acaccfbf..77783dc29 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java @@ -8,10 +8,8 @@ import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FilterCodec; +import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; -import org.opensearch.knn.index.codec.KNNFormatFactory; - -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate; /** * Extends the Codec to support a new file format for KNN index @@ -19,15 +17,14 @@ * */ public final class KNN910Codec extends FilterCodec { - - private static final String KNN910 = "KNN910Codec"; + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_1_0; private final KNNFormatFacade knnFormatFacade; /** * No arg constructor that uses Lucene91 as the delegate */ public KNN910Codec() { - this(createKNN91DefaultDelegate()); + this(VERSION.getDefaultCodecDelegate()); } /** @@ -36,8 +33,8 @@ public KNN910Codec() { * @param delegate codec that will perform all operations this codec does not override */ public KNN910Codec(Codec delegate) { - super(KNN910, delegate); - knnFormatFacade = KNNFormatFactory.createKNN910Format(delegate); + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java 
b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java index 26abcea60..b79c1b4f2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java @@ -12,21 +12,15 @@ import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; -import org.opensearch.knn.index.codec.KNNFormatFactory; - -import java.util.Optional; - -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate; /** * KNN codec that is based on Lucene92 codec */ @Log4j2 public final class KNN920Codec extends FilterCodec { - - private static final String KNN920 = "KNN920Codec"; - + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_2_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; @@ -34,7 +28,7 @@ public final class KNN920Codec extends FilterCodec { * No arg constructor that uses Lucene91 as the delegate */ public KNN920Codec() { - this(createKNN92DefaultDelegate(), new KNN920PerFieldKnnVectorsFormat(Optional.empty())); + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); } /** @@ -45,8 +39,8 @@ public KNN920Codec() { */ @Builder public KNN920Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { - super(KNN920, delegate); - knnFormatFacade = KNNFormatFactory.createKNN920Format(delegate); + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java 
b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java index 0286e829a..ae1ef206c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java @@ -5,74 +5,24 @@ package org.opensearch.knn.index.codec.KNN920Codec; -import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; -import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; -import java.util.Map; import java.util.Optional; /** * Class provides per field format implementation for Lucene Knn vector type */ -@AllArgsConstructor -@Log4j2 -public class KNN920PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat { - - private final Optional mapperService; - - @Override - public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { - if (isNotKnnVectorFieldType(field)) { - log.debug( - String.format( - "Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"", - field, - Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN, - Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH - ) - ); - return new Lucene92HnswVectorsFormat(); - } - var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow( - () -> new IllegalStateException( - String.format("Cannot read field type for field [%s] because mapper service is not available", field) - ) - ).fieldType(field); - var params = type.getKnnMethodContext().getMethodComponent().getParameters(); - int maxConnections = getMaxConnections(params); - int 
beamWidth = getBeamWidth(params); - log.debug( - String.format( - "Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"", - field, - maxConnections, - beamWidth - ) +public class KNN920PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat { + + public KNN920PerFieldKnnVectorsFormat(final Optional mapperService) { + super( + mapperService, + Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + () -> new Lucene92HnswVectorsFormat(), + (maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth) ); - return new Lucene92HnswVectorsFormat(maxConnections, beamWidth); - } - - private boolean isNotKnnVectorFieldType(final String field) { - return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType); - } - - private int getMaxConnections(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_M); - } - return Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN; - } - - private int getBeamWidth(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); - } - return Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940Codec.java index 43a348cee..a056581d6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940Codec.java @@ -12,15 +12,11 @@ import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import 
org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; -import org.opensearch.knn.index.codec.KNNFormatFactory; - -import java.util.Optional; - -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate; public class KNN940Codec extends FilterCodec { - private static final String KNN940 = "KNN940Codec"; + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_4_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; @@ -28,7 +24,7 @@ public class KNN940Codec extends FilterCodec { * No arg constructor that uses Lucene94 as the delegate */ public KNN940Codec() { - this(createKNN94DefaultDelegate(), new KNN940PerFieldKnnVectorsFormat(Optional.empty())); + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); } /** @@ -40,8 +36,8 @@ public KNN940Codec() { */ @Builder protected KNN940Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { - super(KNN940, delegate); - knnFormatFacade = KNNFormatFactory.createKNN940Format(delegate); + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java index feb819fdf..d80c757c9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java @@ -5,70 +5,24 @@ package org.opensearch.knn.index.codec.KNN940Codec; -import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat; -import 
org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; -import java.util.Map; import java.util.Optional; /** * Class provides per field format implementation for Lucene Knn vector type */ -@AllArgsConstructor -@Log4j2 -public class KNN940PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat { - - private final Optional mapperService; - - @Override - public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { - if (isNotKnnVectorFieldType(field)) { - log.debug( - "Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"", - field, - Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN, - Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH - ); - return new Lucene94HnswVectorsFormat(); - } - var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow( - () -> new IllegalStateException( - String.format("Cannot read field type for field [%s] because mapper service is not available", field) - ) - ).fieldType(field); - var params = type.getKnnMethodContext().getMethodComponent().getParameters(); - int maxConnections = getMaxConnections(params); - int beamWidth = getBeamWidth(params); - log.debug( - "Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"", - field, - maxConnections, - beamWidth +public class KNN940PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat { + + public KNN940PerFieldKnnVectorsFormat(final Optional mapperService) { + super( + mapperService, + Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + () -> new Lucene94HnswVectorsFormat(), + (maxConnm, beamWidth) -> new 
Lucene94HnswVectorsFormat(maxConnm, beamWidth) ); - return new Lucene94HnswVectorsFormat(maxConnections, beamWidth); - } - - private boolean isNotKnnVectorFieldType(final String field) { - return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType); - } - - private int getMaxConnections(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_M); - } - return Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN; - } - - private int getBeamWidth(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); - } - return Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java deleted file mode 100644 index e53e1dd2a..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.knn.index.codec; - -import lombok.AllArgsConstructor; -import org.apache.lucene.codecs.Codec; -import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; -import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; -import org.apache.lucene.codecs.lucene94.Lucene94Codec; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; -import org.opensearch.knn.index.codec.KNN940Codec.KNN940PerFieldKnnVectorsFormat; - -import java.util.Optional; - -/** - * Factory abstraction for KNN codec - */ -@AllArgsConstructor -public class KNNCodecFactory { - - private final MapperService mapperService; - - public Codec createKNNCodec(final Codec userCodec) { - var codec = 
KNN940Codec.builder() - .delegate(userCodec) - .knnVectorsFormat(new KNN940PerFieldKnnVectorsFormat(Optional.of(mapperService))) - .build(); - return codec; - } - - /** - * Factory abstraction for codec delegate - */ - public static class CodecDelegateFactory { - - public static Codec createKNN91DefaultDelegate() { - return new Lucene91Codec(); - } - - public static Codec createKNN92DefaultDelegate() { - return new Lucene92Codec(); - } - - public static Codec createKNN94DefaultDelegate() { - return new Lucene94Codec(); - } - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java index 8ce5e6928..d56e09a3f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -8,17 +8,18 @@ import org.opensearch.index.codec.CodecServiceConfig; import org.apache.lucene.codecs.Codec; import org.opensearch.index.codec.CodecService; +import org.opensearch.index.mapper.MapperService; /** * KNNCodecService to inject the right KNNCodec version */ public class KNNCodecService extends CodecService { - private final KNNCodecFactory knnCodecFactory; + private final MapperService mapperService; public KNNCodecService(CodecServiceConfig codecServiceConfig) { super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger()); - knnCodecFactory = new KNNCodecFactory(codecServiceConfig.getMapperService()); + mapperService = codecServiceConfig.getMapperService(); } /** @@ -29,6 +30,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) { */ @Override public Codec codec(String name) { - return knnCodecFactory.createKNNCodec(super.codec(name)); + return KNNCodecVersion.current().getKnnCodecSupplier().apply(super.codec(name), mapperService); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java 
b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java new file mode 100644 index 000000000..adbbb01ca --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; +import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene94.Lucene94Codec; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; +import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; +import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec; +import org.opensearch.knn.index.codec.KNN920Codec.KNN920PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; +import org.opensearch.knn.index.codec.KNN940Codec.KNN940PerFieldKnnVectorsFormat; + +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * Abstraction for k-NN codec version, aggregates all details for specific version such as codec name, corresponding + * Lucene codec, formats including one for k-NN vector etc. 
+ */ +@AllArgsConstructor +@Getter +public enum KNNCodecVersion { + + V_9_1_0( + "KNN910Codec", + new Lucene91Codec(), + null, + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> new KNN910Codec(userCodec), + KNN910Codec::new + ), + + V_9_2_0( + "KNN920Codec", + new Lucene92Codec(), + new KNN920PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> KNN920Codec.builder() + .delegate(userCodec) + .knnVectorsFormat(new KNN920PerFieldKnnVectorsFormat(Optional.of(mapperService))) + .build(), + KNN920Codec::new + ), + + V_9_4_0( + "KNN940Codec", + new Lucene94Codec(), + new KNN940PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> KNN940Codec.builder() + .delegate(userCodec) + .knnVectorsFormat(new KNN940PerFieldKnnVectorsFormat(Optional.of(mapperService))) + .build(), + KNN940Codec::new + ); + + private static final KNNCodecVersion CURRENT = V_9_4_0; + + private final String codecName; + private final Codec defaultCodecDelegate; + private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final Function knnFormatFacadeSupplier; + private final BiFunction knnCodecSupplier; + private final Supplier defaultKnnCodecSupplier; + + public static final KNNCodecVersion current() { + return CURRENT; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java b/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java deleted file mode 100644 index ee17189e3..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java +++ /dev/null @@ 
-1,53 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.knn.index.codec; - -import org.apache.lucene.codecs.Codec; -import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; -import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; - -/** - * Factory abstraction for KNN format facade creation - */ -public class KNNFormatFactory { - - /** - * Return facade class that abstracts format specific to KNN910 codec - * @param delegate delegate codec that is wrapped by KNN codec - * @return - */ - public static KNNFormatFacade createKNN910Format(final Codec delegate) { - final KNNFormatFacade knnFormatFacade = new KNNFormatFacade( - new KNN80DocValuesFormat(delegate.docValuesFormat()), - new KNN80CompoundFormat(delegate.compoundFormat()) - ); - return knnFormatFacade; - } - - /** - * Return facade class that abstracts format specific to KNN920 codec - * @param delegate delegate codec that is wrapped by KNN codec - * @return - */ - public static KNNFormatFacade createKNN920Format(final Codec delegate) { - final KNNFormatFacade knnFormatFacade = new KNNFormatFacade( - new KNN80DocValuesFormat(delegate.docValuesFormat()), - new KNN80CompoundFormat(delegate.compoundFormat()) - ); - return knnFormatFacade; - } - - /** - * Return facade class that abstracts format specific to KNN940 codec - * @param delegate delegate codec that is wrapped by KNN codec - */ - public static KNNFormatFacade createKNN940Format(final Codec delegate) { - final KNNFormatFacade knnFormatFacade = new KNNFormatFacade( - new KNN80DocValuesFormat(delegate.docValuesFormat()), - new KNN80CompoundFormat(delegate.compoundFormat()) - ); - return knnFormatFacade; - } -} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java index 06cc7fad8..8cdfc2d69 100644 --- 
a/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java @@ -9,15 +9,15 @@ import java.io.IOException; import java.util.concurrent.ExecutionException; -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_2_0; public class KNN920CodecTests extends KNNCodecTestCase { public void testMultiFieldsKnnIndex() throws Exception { - testMultiFieldsKnnIndex(KNN920Codec.builder().delegate(createKNN92DefaultDelegate()).build()); + testMultiFieldsKnnIndex(KNN920Codec.builder().delegate(V_9_2_0.getDefaultCodecDelegate()).build()); } public void testBuildFromModelTemplate() throws InterruptedException, ExecutionException, IOException { - testBuildFromModelTemplate((KNN920Codec.builder().delegate(createKNN92DefaultDelegate()).build())); + testBuildFromModelTemplate((KNN920Codec.builder().delegate(V_9_2_0.getDefaultCodecDelegate()).build())); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940CodecTests.java index 578f88f9f..1101d93bb 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940CodecTests.java @@ -14,16 +14,16 @@ import java.util.concurrent.ExecutionException; import java.util.function.Function; -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_4_0; public class KNN940CodecTests extends KNNCodecTestCase { public void testMultiFieldsKnnIndex() throws Exception { - testMultiFieldsKnnIndex(KNN940Codec.builder().delegate(createKNN94DefaultDelegate()).build()); + 
testMultiFieldsKnnIndex(KNN940Codec.builder().delegate(V_9_4_0.getDefaultCodecDelegate()).build()); } public void testBuildFromModelTemplate() throws InterruptedException, ExecutionException, IOException { - testBuildFromModelTemplate((KNN940Codec.builder().delegate(createKNN94DefaultDelegate()).build())); + testBuildFromModelTemplate((KNN940Codec.builder().delegate(V_9_4_0.getDefaultCodecDelegate()).build())); } public void testKnnVectorIndex() throws Exception { @@ -31,7 +31,7 @@ public void testKnnVectorIndex() throws Exception { mapperService) -> new KNN940PerFieldKnnVectorsFormat(Optional.of(mapperService)); Function knnCodecProvider = (knnVectorFormat) -> KNN940Codec.builder() - .delegate(createKNN94DefaultDelegate()) + .delegate(V_9_4_0.getDefaultCodecDelegate()) .knnVectorsFormat(knnVectorFormat) .build(); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java index 6e1c96bcb..d918f5439 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java @@ -5,42 +5,39 @@ package org.opensearch.knn.index.codec; +import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; -import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; import org.apache.lucene.codecs.lucene94.Lucene94Codec; -import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; -import static org.mockito.Mockito.mock; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_1_0; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_2_0; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_4_0; public class KNNCodecFactoryTests extends KNNTestCase { - public void 
testKNN91DefaultDelegate() { - Codec knn91DefaultDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate(); - assertNotNull(knn91DefaultDelegate); - assertTrue(knn91DefaultDelegate instanceof Lucene91Codec); + public void testKNN910Codec() { + assertDelegateForVersion(V_9_1_0, Lucene91Codec.class); + assertNull(V_9_1_0.getPerFieldKnnVectorsFormat()); + assertNotNull(V_9_1_0.getKnnFormatFacadeSupplier().apply(V_9_1_0.getDefaultCodecDelegate())); } - public void testKNN92DefaultDelegate() { - Codec knn92DefaultDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate(); - assertNotNull(knn92DefaultDelegate); - assertTrue(knn92DefaultDelegate instanceof Lucene92Codec); + public void testKNN920Codec() { + assertDelegateForVersion(V_9_2_0, Lucene92Codec.class); + assertNotNull(V_9_2_0.getPerFieldKnnVectorsFormat()); + assertNotNull(V_9_2_0.getKnnFormatFacadeSupplier().apply(V_9_2_0.getDefaultCodecDelegate())); } - public void testKNN94DefaultDelegate() { - Codec knn94DefaultDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate(); - assertNotNull(knn94DefaultDelegate); - assertTrue(knn94DefaultDelegate instanceof Lucene94Codec); + public void testKNN940Codec() { + assertDelegateForVersion(V_9_4_0, Lucene94Codec.class); + assertNotNull(V_9_4_0.getPerFieldKnnVectorsFormat()); + assertNotNull(V_9_4_0.getKnnFormatFacadeSupplier().apply(V_9_4_0.getDefaultCodecDelegate())); } - public void testKNNDefaultCodec() { - MapperService mapperService = mock(MapperService.class); - KNNCodecFactory knnCodecFactory = new KNNCodecFactory(mapperService); - Codec knnCodec = knnCodecFactory.createKNNCodec(KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate()); - assertNotNull(knnCodec); - assertTrue(knnCodec instanceof KNN940Codec); - assertEquals("KNN940Codec", knnCodec.getName()); + private void assertDelegateForVersion(final KNNCodecVersion codecVersion, final Class expectedCodecClass) { + final Codec defaultDelegate = 
codecVersion.getDefaultCodecDelegate(); + assertNotNull(defaultDelegate); + assertTrue(defaultDelegate.getClass().isAssignableFrom(expectedCodecClass)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 623f2dc74..43ae19320 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -19,7 +19,6 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.query.KNNQueryFactory; -import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.KNNSettings; @@ -79,7 +78,7 @@ */ public class KNNCodecTestCase extends KNNTestCase { - private static final KNN940Codec ACTUAL_CODEC = new KNN940Codec(); + private static final Codec ACTUAL_CODEC = KNNCodecVersion.current().getDefaultKnnCodecSupplier().get(); private static FieldType sampleFieldType; static { sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java deleted file mode 100644 index bdaed33e4..000000000 --- a/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec; - -import org.apache.lucene.codecs.Codec; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.KNNTestCase; - -import static org.mockito.Mockito.mock; - -public class KNNFormatFactoryTests extends KNNTestCase { - - public void testKNN91Format() { - final Codec lucene91CodecDelegate = 
KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate(); - MapperService mapperService = mock(MapperService.class); - KNNCodecFactory knnCodecFactory = new KNNCodecFactory(mapperService); - final Codec knnCodec = knnCodecFactory.createKNNCodec(lucene91CodecDelegate); - KNNFormatFacade knnFormatFacade = KNNFormatFactory.createKNN910Format(knnCodec); - - assertNotNull(knnFormatFacade); - assertNotNull(knnFormatFacade.compoundFormat()); - assertNotNull(knnFormatFacade.docValuesFormat()); - } - - public void testKNN92Format() { - MapperService mapperService = mock(MapperService.class); - final Codec lucene92CodecDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate(); - KNNCodecFactory knnCodecFactory = new KNNCodecFactory(mapperService); - final Codec knnCodec = knnCodecFactory.createKNNCodec(lucene92CodecDelegate); - KNNFormatFacade knnFormatFacade = KNNFormatFactory.createKNN920Format(knnCodec); - - assertNotNull(knnFormatFacade); - assertNotNull(knnFormatFacade.compoundFormat()); - assertNotNull(knnFormatFacade.docValuesFormat()); - } - - public void testKNN94Format() { - MapperService mapperService = mock(MapperService.class); - Codec lucene94CodecDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate(); - KNNCodecFactory knnCodecFactory = new KNNCodecFactory(mapperService); - Codec knnCodec = knnCodecFactory.createKNNCodec(lucene94CodecDelegate); - KNNFormatFacade knnFormatFacade = KNNFormatFactory.createKNN940Format(knnCodec); - - assertNotNull(knnFormatFacade); - assertNotNull(knnFormatFacade.compoundFormat()); - assertNotNull(knnFormatFacade.docValuesFormat()); - } -} From 52e2b6b7e8611c3d1dce0683215bb2c320a25945 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 20 Oct 2022 10:34:26 -0700 Subject: [PATCH 6/8] Adding stat for query with filter (#587) Signed-off-by: Martin Gaievski --- .../knn/index/query/KNNQueryBuilder.java | 1 + .../knn/plugin/stats/KNNCounter.java | 3 ++- 
.../knn/plugin/stats/KNNStatsConfig.java | 4 ++++ .../knn/plugin/stats/StatNames.java | 3 ++- .../plugin/action/RestKNNStatsHandlerIT.java | 20 +++++++++++++++++++ 5 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index b94b40400..a27baca3f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -158,6 +158,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String tokenName = parser.currentName(); if (FILTER_FIELD.getPreferredName().equals(tokenName)) { log.debug(String.format("Start parsing filter for field [%s]", fieldName)); + KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment(); if (isClusterOnOrAfterMinRequiredVersion()) { filter = parseInnerQueryBuilder(parser); } else { diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java index d933ce66d..ce04c9078 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java @@ -21,7 +21,8 @@ public enum KNNCounter { SCRIPT_QUERY_REQUESTS("script_query_requests"), SCRIPT_QUERY_ERRORS("script_query_errors"), TRAINING_REQUESTS("training_requests"), - TRAINING_ERRORS("training_errors"); + TRAINING_ERRORS("training_errors"), + KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"); private String name; private AtomicLong count; diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java index c41170b32..8769e0e46 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java @@ -55,6 +55,10 @@ public class 
KNNStatsConfig { .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, new KNNCircuitBreakerSupplier())) .put(StatNames.MODEL_INDEX_STATUS.getName(), new KNNStat<>(true, new ModelIndexStatusSupplier<>(ModelDao::getHealthStatus))) .put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS))) + .put( + StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS)) + ) .put(StatNames.SCRIPT_COMPILATIONS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.SCRIPT_COMPILATIONS))) .put( StatNames.SCRIPT_COMPILATION_ERRORS.getName(), diff --git a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java index ffe5882bb..a098dd8b5 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java @@ -40,7 +40,8 @@ public enum StatNames { TRAINING_ERRORS(KNNCounter.TRAINING_ERRORS.getName()), TRAINING_MEMORY_USAGE("training_memory_usage"), TRAINING_MEMORY_USAGE_PERCENTAGE("training_memory_usage_percentage"), - SCRIPT_QUERY_ERRORS(KNNCounter.SCRIPT_QUERY_ERRORS.getName()); + SCRIPT_QUERY_ERRORS(KNNCounter.SCRIPT_QUERY_ERRORS.getName()), + KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); private String name; diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 6d5b01daf..a454fedb3 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -22,6 +22,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import 
org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.SpaceType; @@ -67,6 +68,7 @@ public class RestKNNStatsHandlerIT extends KNNRestTestCase { private boolean isDebuggingRemoteCluster = System.getProperty("cluster.debug", "false").equals("true"); private static final String FIELD_NAME_2 = "test_field_two"; private static final String FIELD_NAME_3 = "test_field_three"; + private static final String FIELD_LUCENE_NAME = "lucene_test_field"; private static final int DIMENSION = 4; private static int DOC_ID = 0; private static final int NUM_DOCS = 10; @@ -106,6 +108,7 @@ public void testStatsValueCheck() throws Exception { Map nodeStats0 = parseNodeStatsResponse(responseBody).get(0); Integer hitCount0 = (Integer) nodeStats0.get(StatNames.HIT_COUNT.getName()); Integer missCount0 = (Integer) nodeStats0.get(StatNames.MISS_COUNT.getName()); + Integer knnQueryWithFilterCount0 = (Integer) nodeStats0.get(StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); // Setup index createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); @@ -124,9 +127,11 @@ public void testStatsValueCheck() throws Exception { Map nodeStats1 = parseNodeStatsResponse(responseBody).get(0); Integer hitCount1 = (Integer) nodeStats1.get(StatNames.HIT_COUNT.getName()); Integer missCount1 = (Integer) nodeStats1.get(StatNames.MISS_COUNT.getName()); + Integer knnQueryWithFilterCount1 = (Integer) nodeStats1.get(StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); assertEquals(hitCount0, hitCount1); assertEquals((Integer) (missCount0 + 1), missCount1); + assertEquals(knnQueryWithFilterCount0, knnQueryWithFilterCount1); // Second search: Ensure that hits=1 searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); @@ -137,9 +142,24 @@ public void testStatsValueCheck() throws Exception { Map nodeStats2 = parseNodeStatsResponse(responseBody).get(0); Integer hitCount2 = 
(Integer) nodeStats2.get(StatNames.HIT_COUNT.getName()); Integer missCount2 = (Integer) nodeStats2.get(StatNames.MISS_COUNT.getName()); + Integer knnQueryWithFilterCount2 = (Integer) nodeStats2.get(StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); assertEquals(missCount1, missCount2); assertEquals((Integer) (hitCount1 + 1), hitCount2); + assertEquals(knnQueryWithFilterCount0, knnQueryWithFilterCount2); + + putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_LUCENE_NAME, 2, METHOD_HNSW, LUCENE_NAME)); + addKnnDoc(INDEX_NAME, "2", FIELD_LUCENE_NAME, vector); + + searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_LUCENE_NAME, qvector, 1, QueryBuilders.termQuery("_id", "1")), 1); + + response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + responseBody = EntityUtils.toString(response.getEntity()); + + Map nodeStats3 = parseNodeStatsResponse(responseBody).get(0); + Integer knnQueryWithFilterCount3 = (Integer) nodeStats3.get(StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); + + assertEquals((Integer) (knnQueryWithFilterCount0 + 1), knnQueryWithFilterCount3); } /** From a2b92b1daff474f142aa26332994e2e7da2aecd7 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 24 Oct 2022 11:04:33 -0700 Subject: [PATCH 7/8] Rename context class, adjust lucene IT Signed-off-by: Martin Gaievski --- .../opensearch/knn/bwc/LuceneFilteringIT.java | 42 +++++++++---------- ...lusterContext.java => KNNClusterUtil.java} | 8 ++-- .../knn/index/query/KNNQueryBuilder.java | 4 +- .../org/opensearch/knn/plugin/KNNPlugin.java | 4 +- ...extTests.java => KNNClusterUtilTests.java} | 20 ++++----- .../knn/index/query/KNNQueryBuilderTests.java | 14 +++---- 6 files changed, 46 insertions(+), 46 deletions(-) rename src/main/java/org/opensearch/knn/index/{KNNClusterContext.java => KNNClusterUtil.java} (89%) rename src/test/java/org/opensearch/knn/index/{KNNClusterContextTests.java => KNNClusterUtilTests.java} (60%) diff --git 
a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java index 3ea611cbf..3a7d0329d 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.bwc; +import org.hamcrest.MatcherAssert; import org.opensearch.knn.TestUtils; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -19,7 +20,11 @@ import java.io.IOException; +import static org.hamcrest.CoreMatchers.anyOf; +import static org.hamcrest.CoreMatchers.containsString; import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; /** * Tests scenarios specific to filtering functionality in k-NN in case Lucene is set as an engine @@ -36,7 +41,11 @@ public void testLuceneFiltering() throws Exception { float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0]; switch (getClusterType()) { case OLD: - createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMappingWithLuceneField(TEST_FIELD, DIMENSIONS)); + createKnnIndex( + testIndex, + getKNNDefaultIndexSettings(), + createKnnIndexMapping(TEST_FIELD, DIMENSIONS, METHOD_HNSW, LUCENE_NAME) + ); bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS); validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); break; @@ -50,25 +59,6 @@ public void testLuceneFiltering() throws Exception { } } - protected String createKnnIndexMappingWithLuceneField(final String fieldName, int dimension) throws IOException { - return Strings.toString( - XContentFactory.jsonBuilder() - .startObject() - 
.startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", Integer.toString(dimension)) - .startObject("method") - .field("name", "hnsw") - .field("engine", "lucene") - .field("space_type", "l2") - .endObject() - .endObject() - .endObject() - .endObject() - ); - } - private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); @@ -81,6 +71,16 @@ private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQuery request.addParameter("search_type", "query_then_fetch"); request.setJsonEntity(Strings.toString(builder)); - expectThrows(ResponseException.class, () -> client().performRequest(request)); + Exception exception = expectThrows(ResponseException.class, () -> client().performRequest(request)); + // assert for two possible exception messages, first one can come from current version in case serialized request is coming from + // lower version, + // second exception is vice versa, when lower version node receives request with filter field from higher version + MatcherAssert.assertThat( + exception.getMessage(), + anyOf( + containsString("filter field is supported from version"), + containsString("[knn] unknown token [START_OBJECT] after [filter]") + ) + ); } } diff --git a/src/main/java/org/opensearch/knn/index/KNNClusterContext.java b/src/main/java/org/opensearch/knn/index/KNNClusterUtil.java similarity index 89% rename from src/main/java/org/opensearch/knn/index/KNNClusterContext.java rename to src/main/java/org/opensearch/knn/index/KNNClusterUtil.java index 938ed8302..63a49f095 100644 --- a/src/main/java/org/opensearch/knn/index/KNNClusterContext.java +++ b/src/main/java/org/opensearch/knn/index/KNNClusterUtil.java @@ -16,18 +16,18 @@ */ @NoArgsConstructor(access = AccessLevel.PRIVATE) 
@Log4j2 -public class KNNClusterContext { +public class KNNClusterUtil { private ClusterService clusterService; - private static KNNClusterContext instance; + private static KNNClusterUtil instance; /** * Return instance of the cluster context, must be initialized first for proper usage * @return instance of cluster context */ - public static synchronized KNNClusterContext instance() { + public static synchronized KNNClusterUtil instance() { if (instance == null) { - instance = new KNNClusterContext(); + instance = new KNNClusterUtil(); } return instance; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index a27baca3f..08ed23d05 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -9,7 +9,7 @@ import org.opensearch.Version; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.knn.index.KNNClusterContext; +import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; @@ -322,6 +322,6 @@ public String getWriteableName() { } private static boolean isClusterOnOrAfterMinRequiredVersion() { - return KNNClusterContext.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); + return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index c72198c7d..670294802 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -11,7 +11,7 @@ import 
org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.knn.index.KNNCircuitBreaker; -import org.opensearch.knn.index.KNNClusterContext; +import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -180,7 +180,7 @@ public Collection createComponents( NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); KNNSettings.state().initialize(client, clusterService); - KNNClusterContext.instance().initialize(clusterService); + KNNClusterUtil.instance().initialize(clusterService); ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance()); diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java b/src/test/java/org/opensearch/knn/index/KNNClusterUtilTests.java similarity index 60% rename from src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java rename to src/test/java/org/opensearch/knn/index/KNNClusterUtilTests.java index 5c8ed970e..0e00a7f75 100644 --- a/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNClusterUtilTests.java @@ -13,15 +13,15 @@ import static org.mockito.Mockito.when; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; -public class KNNClusterContextTests extends KNNTestCase { +public class KNNClusterUtilTests extends KNNTestCase { public void testSingleNodeCluster() { ClusterService clusterService = mockClusterService(Version.V_2_4_0); - final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); - knnClusterContext.initialize(clusterService); + final KNNClusterUtil knnClusterUtil = 
KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); - final Version minVersion = knnClusterContext.getClusterMinVersion(); + final Version minVersion = knnClusterUtil.getClusterMinVersion(); assertTrue(Version.V_2_4_0.equals(minVersion)); } @@ -29,10 +29,10 @@ public void testSingleNodeCluster() { public void testMultipleNodesCluster() { ClusterService clusterService = mockClusterService(Version.V_2_3_0); - final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); - knnClusterContext.initialize(clusterService); + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); - final Version minVersion = knnClusterContext.getClusterMinVersion(); + final Version minVersion = knnClusterUtil.getClusterMinVersion(); assertTrue(Version.V_2_3_0.equals(minVersion)); } @@ -41,10 +41,10 @@ public void testWhenErrorOnClusterStateDiscover() { ClusterService clusterService = mock(ClusterService.class); when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready")); - final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); - knnClusterContext.initialize(clusterService); + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); - final Version minVersion = knnClusterContext.getClusterMinVersion(); + final Version minVersion = knnClusterUtil.getClusterMinVersion(); assertTrue(Version.CURRENT.equals(minVersion)); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 3be622da0..e3376dda9 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -26,7 +26,7 @@ import org.opensearch.index.Index; import org.opensearch.index.mapper.NumberFieldMapper; import 
org.opensearch.index.query.QueryShardContext; -import org.opensearch.knn.index.KNNClusterContext; +import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -105,8 +105,8 @@ public void testFromXcontent() throws Exception { public void testFromXcontent_WithFilter() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); - final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); - knnClusterContext.initialize(clusterService); + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -127,8 +127,8 @@ public void testFromXcontent_WithFilter() throws Exception { public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Exception { final ClusterService clusterService = mockClusterService(Version.V_2_3_0); - final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); - knnClusterContext.initialize(clusterService); + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -268,8 +268,8 @@ private void assertSerialization(final Version version, final Optional Date: Mon, 24 Oct 2022 14:43:18 -0700 Subject: [PATCH 8/8] Adding code comments Signed-off-by: Martin Gaievski --- .../org/opensearch/knn/index/query/KNNQueryBuilder.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 08ed23d05..ebf721304 
100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -112,6 +112,8 @@ public KNNQueryBuilder(StreamInput in) throws IOException { fieldName = in.readString(); vector = in.readFloatArray(); k = in.readInt(); + // We're checking if all cluster nodes have at least that version or higher. This check is required + // to avoid issues with cluster upgrade if (isClusterOnOrAfterMinRequiredVersion()) { filter = in.readOptionalNamedWriteable(QueryBuilder.class); } @@ -159,6 +161,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (FILTER_FIELD.getPreferredName().equals(tokenName)) { log.debug(String.format("Start parsing filter for field [%s]", fieldName)); KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment(); + // Query filters are supported starting from a certain k-NN version only, exact version is defined by + // MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER variable. + // Here we're checking if all cluster nodes have at least that version or higher. This check is required + // to avoid issues with rolling cluster upgrade if (isClusterOnOrAfterMinRequiredVersion()) { filter = parseInnerQueryBuilder(parser); } else { @@ -204,6 +210,8 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeFloatArray(vector); out.writeInt(k); + // We're checking if all cluster nodes have at least that version or higher. This check is required + // to avoid issues with cluster upgrade if (isClusterOnOrAfterMinRequiredVersion()) { out.writeOptionalNamedWriteable(filter); }