From 45282e056a9eb875aa53b08f7ee5865b277e651b Mon Sep 17 00:00:00 2001
From: Junqiu Lei
Date: Tue, 9 Jan 2024 12:59:41 -0800
Subject: [PATCH] Throw proper exception to invalid k-NN query (#1380)

* Throw proper exception to invalid k-NN query

Signed-off-by: Junqiu Lei

* Move PR to enhancement in CHANGELOG.md

Signed-off-by: Junqiu Lei

* Resolve PR feedback

Signed-off-by: Junqiu Lei

* Resolve PR feedback

Signed-off-by: Junqiu Lei

* Revert IT tests

Signed-off-by: Junqiu Lei

---------

Signed-off-by: Junqiu Lei
---
 CHANGELOG.md                                  |  1 +
 .../knn/index/query/KNNQueryBuilder.java      |  6 ++
 .../knn/index/VectorDataTypeIT.java           | 51 +++++++++++++++
 .../knn/index/query/KNNQueryBuilderTests.java | 65 +++++++++++++++++++
 4 files changed, 123 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index aafc2e585..a9c34eaf9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 * Increase Lucene max dimension limit to 16,000 [#1346](https://github.com/opensearch-project/k-NN/pull/1346)
 * Tuned default values for ef_search and ef_construction for better indexing and search performance for vector search [#1353](https://github.com/opensearch-project/k-NN/pull/1353)
 * Enabled Filtering on Nested Vector fields with top level filters [#1372](https://github.com/opensearch-project/k-NN/pull/1372)
+* Throw proper exception to invalid k-NN query [#1380](https://github.com/opensearch-project/k-NN/pull/1380)
 ### Bug Fixes
 * Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
 * Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
index 6caf2ed9b..096c2e30b 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
@@ -100,8 +100,14 @@ public static void initialize(ModelDao modelDao) {
     }
 
     private static float[] ObjectsToFloats(List objs) {
+        if (Objects.isNull(objs) || objs.isEmpty()) {
+            throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME));
+        }
         float[] vec = new float[objs.size()];
         for (int i = 0; i < objs.size(); i++) {
+            if ((objs.get(i) instanceof Number) == false) {
+                throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME));
+            }
             vec[i] = ((Number) objs.get(i)).floatValue();
         }
         return vec;
diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
index 2d3f53580..e55a3be42 100644
--- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
+++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
@@ -24,6 +24,7 @@
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.script.Script;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -425,6 +426,56 @@ public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception {
         );
     }
 
+    @SneakyThrows
+    public void testSearchWithInvalidSearchVectorType() {
+        createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
+        ingestL2FloatTestData();
+        Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
+        List invalidTypeQueryVector = new ArrayList<>();
+        invalidTypeQueryVector.add(1.5);
+        invalidTypeQueryVector.add(2.5);
+        invalidTypeQueryVector.add("a");
+        invalidTypeQueryVector.add(null);
+        XContentBuilder builder = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject("query")
+            .startObject("knn")
+            .startObject(FIELD_NAME)
+            .field("vector", invalidTypeQueryVector)
+            .field("k", 4)
+            .endObject()
+            .endObject()
+            .endObject()
+            .endObject();
+        request.setJsonEntity(builder.toString());
+
+        ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
+        assertEquals(400, ex.getResponse().getStatusLine().getStatusCode());
+        assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
+    }
+
+    @SneakyThrows
+    public void testSearchWithMissingQueryVector() {
+        createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
+        ingestL2FloatTestData();
+        Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
+        XContentBuilder builder = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject("query")
+            .startObject("knn")
+            .startObject(FIELD_NAME)
+            .field("k", 4)
+            .endObject()
+            .endObject()
+            .endObject()
+            .endObject();
+        request.setJsonEntity(builder.toString());
+
+        ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
+        assertEquals(400, ex.getResponse().getStatusLine().getStatusCode());
+        assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
+    }
+
     @SneakyThrows
     private void ingestL2ByteTestData() {
         Byte[] b1 = { 6, 6 };
diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
index e540725fc..a981c684e 100644
--- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
@@ -39,6 +39,7 @@
 import org.opensearch.plugins.SearchPlugin;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
 
@@ -127,6 +128,70 @@ public void testFromXcontent_WithFilter() throws Exception {
         actualBuilder.equals(knnQueryBuilder);
     }
 
+    public void testFromXContent_invalidQueryVectorType() throws Exception {
+        final ClusterService clusterService = mockClusterService(Version.CURRENT);
+
+        final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
+        knnClusterUtil.initialize(clusterService);
+
+        List invalidTypeQueryVector = new ArrayList<>();
+        invalidTypeQueryVector.add(1.5);
+        invalidTypeQueryVector.add(2.5);
+        invalidTypeQueryVector.add("a");
+        invalidTypeQueryVector.add(null);
+
+        XContentBuilder builder = XContentFactory.jsonBuilder();
+        builder.startObject();
+        builder.startObject(FIELD_NAME);
+        builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector);
+        builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
+        builder.endObject();
+        builder.endObject();
+        XContentParser contentParser = createParser(builder);
+        contentParser.nextToken();
+        IllegalArgumentException exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> KNNQueryBuilder.fromXContent(contentParser)
+        );
+        assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
+    }
+
+    public void testFromXContent_missingQueryVector() throws Exception {
+        final ClusterService clusterService = mockClusterService(Version.CURRENT);
+
+        final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
+        knnClusterUtil.initialize(clusterService);
+
+        // Test without vector field
+        XContentBuilder builderWithoutVectorField = XContentFactory.jsonBuilder();
+        builderWithoutVectorField.startObject();
+        builderWithoutVectorField.startObject(FIELD_NAME);
+        builderWithoutVectorField.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
+        builderWithoutVectorField.endObject();
+        builderWithoutVectorField.endObject();
+        XContentParser contentParserWithoutVectorField = createParser(builderWithoutVectorField);
+        contentParserWithoutVectorField.nextToken();
+        IllegalArgumentException exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> KNNQueryBuilder.fromXContent(contentParserWithoutVectorField)
+        );
+        assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
+
+        // Test empty vector field
+        List emptyQueryVector = new ArrayList<>();
+        XContentBuilder builderWithEmptyVector = XContentFactory.jsonBuilder();
+        builderWithEmptyVector.startObject();
+        builderWithEmptyVector.startObject(FIELD_NAME);
+        builderWithEmptyVector.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), emptyQueryVector);
+        builderWithEmptyVector.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
+        builderWithEmptyVector.endObject();
+        builderWithEmptyVector.endObject();
+        XContentParser contentParserWithEmptyVector = createParser(builderWithEmptyVector);
+        contentParserWithEmptyVector.nextToken();
+        exception = expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParserWithEmptyVector));
+        assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
+    }
+
     @Override
     protected NamedXContentRegistry xContentRegistry() {
         List list = ClusterModule.getNamedXWriteables();