diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java new file mode 100644 index 0000000000..3ea611cbfe --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.bwc; + +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; + +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; + +import java.io.IOException; + +import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; + +/** + * Tests scenarios specific to filtering functionality in k-NN in case Lucene is set as an engine + */ +public class LuceneFilteringIT extends AbstractRollingUpgradeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 50; + private static final int K = 10; + private static final int NUM_DOCS = 100; + private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("_id", "100"); + + public void testLuceneFiltering() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0]; + switch (getClusterType()) { + case OLD: + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMappingWithLuceneField(TEST_FIELD, DIMENSIONS)); + bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS); + validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + break; + case MIXED: + validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + break; + case UPGRADED: + searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + deleteKNNIndex(testIndex); + break; + } + } + + protected String createKnnIndexMappingWithLuceneField(final String fieldName, int dimension) throws IOException { + return Strings.toString( + XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", Integer.toString(dimension)) + .startObject("method") + .field("name", "hnsw") + .field("engine", "lucene") + .field("space_type", "l2") + .endObject() + .endObject() + .endObject() + .endObject() + ); + } + + private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); + knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject().endObject(); + + Request request = new Request("POST", "/" + index + "/_search"); + + request.addParameter("size", Integer.toString(resultSize)); + request.addParameter("explain", Boolean.toString(true)); + request.addParameter("search_type", "query_then_fetch"); + request.setJsonEntity(Strings.toString(builder)); + + expectThrows(ResponseException.class, () -> client().performRequest(request)); + } +} diff --git a/src/main/java/org/opensearch/knn/index/KNNClusterContext.java b/src/main/java/org/opensearch/knn/index/KNNClusterContext.java new file mode 100644 index 0000000000..a98cc8bea8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/KNNClusterContext.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import com.carrotsearch.hppc.cursors.ObjectCursor; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.ImmutableOpenMap; + +/** + * Class abstracts information related to underlying OpenSearch cluster + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +@Log4j2 +public class KNNClusterContext { + + private ClusterService clusterService; + private static KNNClusterContext instance; + + /** + * Return instance of the cluster context, must be initialized first for proper usage + * @return instance of cluster context + */ + public static synchronized KNNClusterContext instance() { + if (instance == null) { + instance = new KNNClusterContext(); + } + return instance; + } + + /** + * Initializes instance of cluster context by injecting dependencies + * @param clusterService + */ + public void initialize(final ClusterService clusterService) { + this.clusterService = clusterService; + } + + /** + * Return minimal OpenSearch version based on all nodes currently discoverable in the cluster + * @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version + */ + public Version getClusterMinVersion() { + Version minVersion = Version.CURRENT; + ImmutableOpenMap clusterDiscoveryNodes = ImmutableOpenMap.of(); + log.debug("Reading cluster min version"); + try { + clusterDiscoveryNodes = this.clusterService.state().getNodes().getNodes(); + } catch (Exception exception) { + log.error("Cannot get cluster nodes", exception); + } + for (final ObjectCursor discoveryNodeCursor : clusterDiscoveryNodes.values()) { + final Version nodeVersion = discoveryNodeCursor.value.getVersion(); + if (nodeVersion.before(minVersion)) { + minVersion = nodeVersion; + log.debug("Update cluster min version to {} based on node {}", nodeVersion, discoveryNodeCursor.value.toString()); + } + } + log.debug("Return cluster min version {}", minVersion); + return minVersion; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index aeefdbff4c..b94b404007 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -6,8 +6,10 @@ package org.opensearch.knn.index.query; import lombok.extern.log4j.Log4j2; +import org.opensearch.Version; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.index.KNNClusterContext; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; @@ -52,6 +54,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final float[] vector; private int k = 0; private QueryBuilder filter; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_3_0_0; /** * Constructs a new knn query @@ -109,8 +112,11 @@ public KNNQueryBuilder(StreamInput in) throws IOException { fieldName = in.readString(); vector = in.readFloatArray(); k = in.readInt(); + if (isClusterOnOrAfterMinRequiredVersion()) { + filter = in.readOptionalNamedWriteable(QueryBuilder.class); + } } catch (IOException ex) { - throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder: " + ex); + throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } } @@ -152,7 +158,23 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String tokenName = parser.currentName(); if (FILTER_FIELD.getPreferredName().equals(tokenName)) { log.debug(String.format("Start parsing filter for field [%s]", fieldName)); - filter = parseInnerQueryBuilder(parser); + if (isClusterOnOrAfterMinRequiredVersion()) { + filter = parseInnerQueryBuilder(parser); + } else { + log.debug( + String.format( + "This version of k-NN doesn't support [filter] field, minimal required version is [%s]", + MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER + ) + ); + throw new IllegalArgumentException( + String.format( + "%s field is supported from version %s", + FILTER_FIELD.getPreferredName(), + MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER + ) + ); + } } else { throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); } @@ -181,6 +203,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeFloatArray(vector); out.writeInt(k); + if (isClusterOnOrAfterMinRequiredVersion()) { + out.writeOptionalNamedWriteable(filter); + } } /** @@ -294,4 +319,8 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static boolean isClusterOnOrAfterMinRequiredVersion() { + return KNNClusterContext.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index c2564f1795..c72198c7d7 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -11,6 +11,7 @@ import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.knn.index.KNNCircuitBreaker; +import org.opensearch.knn.index.KNNClusterContext; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -179,6 +180,7 @@ public Collection createComponents( NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); KNNSettings.state().initialize(client, clusterService); + KNNClusterContext.instance().initialize(clusterService); ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance()); diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java b/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java new file mode 100644 index 0000000000..55e6bbde29 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import org.opensearch.Version; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.knn.KNNTestCase; + +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; + +public class KNNClusterContextTests extends KNNTestCase { + + public void testSingleNodeCluster() { + ClusterService clusterService = mockClusterService(List.of(Version.V_2_4_0)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + + final Version minVersion = knnClusterContext.getClusterMinVersion(); + + assertTrue(Version.V_2_4_0.equals(minVersion)); + } + + public void testMultipleNodesCluster() { + ClusterService clusterService = mockClusterService(List.of(Version.V_3_0_0, Version.V_2_3_0, Version.V_3_0_0)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + + final Version minVersion = knnClusterContext.getClusterMinVersion(); + + assertTrue(Version.V_2_3_0.equals(minVersion)); + } + + public void testWhenErrorOnClusterStateDiscover() { + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready")); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + + final Version minVersion = knnClusterContext.getClusterMinVersion(); + + assertTrue(Version.CURRENT.equals(minVersion)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java new file mode 100644 index 0000000000..f58584898b --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.ImmutableOpenMap; + +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; + +/** + * Collection of util methods required for testing and related to OpenSearch cluster setup and functionality + */ +public class KNNClusterTestUtils { + + /** + * Create new mock for ClusterService + * @param versions list of versions for cluster nodes + * @return + */ + public static ClusterService mockClusterService(final List versions) { + ClusterService clusterService = mock(ClusterService.class); + ClusterState clusterState = mock(ClusterState.class); + when(clusterService.state()).thenReturn(clusterState); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); + for (Version version : versions) { + DiscoveryNode clusterNode = mock(DiscoveryNode.class); + when(clusterNode.getVersion()).thenReturn(version); + builder.put(randomAlphaOfLength(10), clusterNode); + } + ImmutableOpenMap mapOfNodes = builder.build(); + when(discoveryNodes.getNodes()).thenReturn(mapOfNodes); + + return clusterService; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index d95e1dc3cd..75a5243f41 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -21,6 +21,7 @@ import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Sort; import org.apache.lucene.store.ChecksumIndexInput; @@ -178,6 +179,7 @@ public FieldInfo build() { pointIndexDimensionCount, pointNumBytes, vectorDimension, + VectorEncoding.BYTE, vectorSimilarityFunction, softDeletes ); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 4ebcf9ec48..435987f7e4 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -8,7 +8,13 @@ import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -20,6 +26,7 @@ import org.opensearch.index.Index; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.KNNClusterContext; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -31,31 +38,38 @@ import java.io.IOException; import java.util.List; +import java.util.Optional; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; public class KNNQueryBuilderTests extends KNNTestCase { + private static final String FIELD_NAME = "myvector"; + private static final int K = 1; + private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); + private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + public void testInvalidK() { float[] queryVector = { 1.0f, 1.0f }; /** * -ve k */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, -1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, -K)); /** * zero k */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, 0)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, 0)); /** * k > KNNQueryBuilder.K_MAX */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, KNNQueryBuilder.K_MAX + 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, KNNQueryBuilder.K_MAX + K)); } public void testEmptyVector() { @@ -63,18 +77,18 @@ public void testEmptyVector() { * null query vector */ float[] queryVector = null; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, K)); /** * empty query vector */ float[] queryVector1 = {}; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector1, 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K)); } public void testFromXcontent() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -89,8 +103,13 @@ public void testFromXcontent() throws Exception { } public void testFromXcontent_WithFilter() throws Exception { + final ClusterService clusterService = mockClusterService(List.of(Version.CURRENT)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value")); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -105,6 +124,28 @@ public void testFromXcontent_WithFilter() throws Exception { actualBuilder.equals(knnQueryBuilder); } + public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Exception { + final ClusterService clusterService = mockClusterService(List.of(Version.V_2_3_0)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + final XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + final XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + + expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParser)); + } + @Override protected NamedXContentRegistry xContentRegistry() { List list = ClusterModule.getNamedXWriteables(); @@ -118,9 +159,17 @@ protected NamedXContentRegistry xContentRegistry() { return registry; } + @Override + protected NamedWriteableRegistry writableRegistry() { + final List entries = ClusterModule.getNamedWriteables(); + entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, KNNQueryBuilder.NAME, KNNQueryBuilder::new)); + entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + return new NamedWriteableRegistry(entries); + } + public void testDoToQuery_Normal() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -135,7 +184,7 @@ public void testDoToQuery_Normal() throws Exception { public void testDoToQuery_KnnQueryWithFilter() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value")); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -152,14 +201,14 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { public void testDoToQuery_FromModel() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); // Dimension is -1. In this case, model metadata will need to provide dimension - when(mockKNNVectorField.getDimension()).thenReturn(-1); + when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; when(mockKNNVectorField.getModelId()).thenReturn(modelId); @@ -181,7 +230,7 @@ public void testDoToQuery_FromModel() { public void testDoToQuery_InvalidDimensions() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -189,13 +238,13 @@ public void testDoToQuery_InvalidDimensions() { when(mockKNNVectorField.getDimension()).thenReturn(400); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); - when(mockKNNVectorField.getDimension()).thenReturn(1); + when(mockKNNVectorField.getDimension()).thenReturn(K); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } public void testDoToQuery_InvalidFieldType() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("mynumber", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("mynumber", queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); NumberFieldMapper.NumberFieldType mockNumberField = mock(NumberFieldMapper.NumberFieldType.class); @@ -203,4 +252,45 @@ public void testDoToQuery_InvalidFieldType() throws IOException { when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockNumberField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + + public void testSerialization() throws Exception { + assertSerialization(Version.CURRENT, Optional.empty()); + + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY)); + + assertSerialization(Version.V_2_3_0, Optional.empty()); + } + + private void assertSerialization(final Version version, final Optional queryBuilderOptional) throws Exception { + final KNNQueryBuilder knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + + final ClusterService clusterService = mockClusterService(List.of(version)); + + final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); + knnClusterContext.initialize(clusterService); + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(version); + output.writeNamedWriteable(knnQueryBuilder); + + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { + in.setVersion(Version.CURRENT); + final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); + + assertNotNull(deserializedQuery); + assertTrue(deserializedQuery instanceof KNNQueryBuilder); + final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery; + assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName()); + assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); + assertEquals(K, deserializedKnnQueryBuilder.getK()); + if (queryBuilderOptional.isPresent()) { + assertNotNull(deserializedKnnQueryBuilder.getFilter()); + assertEquals(queryBuilderOptional.get(), deserializedKnnQueryBuilder.getFilter()); + } else { + assertNull(deserializedKnnQueryBuilder.getFilter()); + } + } + } + } }