From 77db3ab8fe6a2a2fcb17bc14663c57398c89c689 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Tue, 11 Jul 2023 15:51:50 -0500 Subject: [PATCH] Add Querying Support to Lucene Byte Sized Vector (#956) * Add Querying Support to Lucene Byte Sized Vector Signed-off-by: Naveen Tatikonda * Add CHANGELOG Signed-off-by: Naveen Tatikonda * Address Review Comments Signed-off-by: Naveen Tatikonda --------- Signed-off-by: Naveen Tatikonda --- .../knn/index/query/KNNQueryBuilder.java | 21 ++- .../knn/index/query/KNNQueryFactory.java | 68 +++++++++- .../opensearch/knn/index/LuceneEngineIT.java | 111 ++++++++++------ .../knn/index/VectorDataTypeIT.java | 123 ++++++++++++++++++ .../knn/index/codec/KNNCodecTestCase.java | 19 ++- .../knn/index/query/KNNQueryBuilderTests.java | 4 + .../knn/index/query/KNNQueryFactoryTests.java | 20 ++- 7 files changed, 315 insertions(+), 51 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 16277720d..0b4730279 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -10,6 +10,7 @@ import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -30,6 +31,8 @@ import java.util.List; import java.util.Objects; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; + /** * Helper class to build the KNN query */ @@ -235,6 +238,7 @@ protected Query doToQuery(QueryShardContext context) { int fieldDimension = knnVectorFieldType.getDimension(); KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); KNNEngine knnEngine = KNNEngine.DEFAULT; + VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); if (fieldDimension == -1) { // If dimension is not set, the field uses a model and the information needs to be retrieved from there @@ -252,9 +256,18 @@ protected Query doToQuery(QueryShardContext context) { ); } + byte[] byteVector = new byte[0]; + if (VectorDataType.BYTE == vectorDataType) { + byteVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + validateByteVectorValue(vector[i]); + byteVector[i] = (byte) vector[i]; + } + } + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) - && filter != null - && !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) { + && filter != null + && !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) { throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine)); } @@ -263,7 +276,9 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(this.vector) + .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) + .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) + .vectorDataType(vectorDataType) .k(this.k) .filter(this.filter) .context(context) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 20c456c4a..65c15499d 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -11,15 +11,21 @@ import lombok.NonNull; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Locale; import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; + /** * Creates the Lucene k-NN queries */ @@ -36,12 +42,20 @@ public class KNNQueryFactory { * @param k the number of nearest neighbors to return * @return Lucene Query */ - public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) { + public static Query create( + KNNEngine knnEngine, + String indexName, + String fieldName, + float[] vector, + int k, + VectorDataType vectorDataType + ) { final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(indexName) .fieldName(fieldName) .vector(vector) + .vectorDataType(vectorDataType) .k(k) .build(); return create(createQueryRequest); @@ -59,6 +73,8 @@ public static Query create(CreateQueryRequest createQueryRequest) { final String fieldName = createQueryRequest.getFieldName(); final int k = createQueryRequest.getK(); final float[] vector = createQueryRequest.getVector(); + final byte[] byteVector = createQueryRequest.getByteVector(); + final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { @@ -77,14 +93,54 @@ public static Query create(CreateQueryRequest createQueryRequest) { return new KNNQuery(fieldName, vector, k, indexName); } + if (VectorDataType.BYTE == vectorDataType) { + return getKnnByteVectorQuery(indexName, fieldName, byteVector, k, filterQuery); + } else if (VectorDataType.FLOAT == vectorDataType) { + return getKnnFloatVectorQuery(indexName, fieldName, vector, k, filterQuery); + } else { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES + ) + ); + } + } + + private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) { + if (filterQuery != null) { + log.debug( + String.format( + Locale.ROOT, + "Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", + indexName, + fieldName, + k + ) + ); + return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery); + } + log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + return new KnnByteVectorQuery(fieldName, byteVector, k); + } + + private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) { if (filterQuery != null) { log.debug( - String.format("Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k) + String.format( + Locale.ROOT, + "Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", + indexName, + fieldName, + k + ) ); - return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery); + return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery); } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnFloatVectorQuery(fieldName, vector, k); + return new KnnFloatVectorQuery(fieldName, floatVector, k); } private static Query getFilterQuery(CreateQueryRequest createQueryRequest) { @@ -126,6 +182,10 @@ static class CreateQueryRequest { @Getter private float[] vector; @Getter + private byte[] byteVector; + @Getter + private VectorDataType vectorDataType; + @Getter private int k; // can be null in cases filter not passed with the knn query private QueryBuilder filter; diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index cb84698db..594daeaeb 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Floats; +import lombok.SneakyThrows; import org.apache.commons.lang.math.RandomUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.lucene.index.VectorSimilarityFunction; @@ -34,8 +35,10 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class LuceneEngineIT extends KNNRestTestCase { @@ -110,7 +113,7 @@ public void testQuery_innerProduct_notSupported() throws Exception { public void testQuery_invalidVectorDimensionInQuery() throws Exception { - createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } @@ -127,7 +130,7 @@ public void testQuery_documentsMissingField() throws Exception { SpaceType spaceType = SpaceType.L2; - createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType); + createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } @@ -224,35 +227,35 @@ public void testAddDoc() throws Exception { Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); - refreshAllIndices(); + refreshIndex(INDEX_NAME); assertEquals(1, getDocCount(INDEX_NAME)); } public void testUpdateDoc() throws Exception { - createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2); + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT); Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); Float[] updatedVector = { 8.0f, 8.0f }; updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector); - refreshAllIndices(); + refreshIndex(INDEX_NAME); assertEquals(1, getDocCount(INDEX_NAME)); } public void testDeleteDoc() throws Exception { - createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2); + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT); Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); deleteKnnDoc(INDEX_NAME, DOC_ID); - refreshAllIndices(); + refreshIndex(INDEX_NAME); assertEquals(0, getDocCount(INDEX_NAME)); } - public void testQueryWithFilter() throws Exception { - createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + public void testQueryWithFilterUsingFloatVectorDataType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); addKnnDocWithAttributes( DOC_ID, @@ -262,39 +265,28 @@ public void testQueryWithFilter() throws Exception { addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); - refreshAllIndices(); + refreshIndex(INDEX_NAME); final float[] searchVector = { 6.0f, 6.0f, 4.1f }; - int kGreaterThanFilterResult = 5; - List expectedDocIds = Arrays.asList(DOC_ID, DOC_ID_3); - final Response response = searchKNNIndex( - INDEX_NAME, - new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), - kGreaterThanFilterResult - ); - final String responseBody = EntityUtils.toString(response.getEntity()); - final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + List expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3); + List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); + validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult); + } - assertEquals(expectedDocIds.size(), knnResults.size()); - assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds)); + @SneakyThrows + public void testQueryWithFilterUsingByteVectorDataType() { + createKnnIndexMappingWithLuceneEngine(3, SpaceType.L2, VectorDataType.BYTE); - int kLimitsFilterResult = 1; - List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); - final Response responseKLimitsFilterResult = searchKNNIndex( - INDEX_NAME, - new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), - kLimitsFilterResult - ); - final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); - final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.0f, 3.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.0f, 2.0f, 4.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.0f, 5.0f, 7.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); - assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); - assertTrue( - knnResultsKLimitsFilterResult.stream() - .map(KNNResult::getDocId) - .collect(Collectors.toList()) - .containsAll(expectedDocIdsKLimitsFilterResult) - ); + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 6.0f, 4.0f }; + List expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3); + List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); + validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult); } public void testQuery_filterWithNonLuceneEngine() throws Exception { @@ -337,7 +329,7 @@ public void testQuery_filterWithNonLuceneEngine() throws Exception { } public void testIndexReopening() throws Exception { - createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); @@ -358,13 +350,14 @@ public void testIndexReopening() throws Exception { assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } - private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception { + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .startObject(PROPERTIES_FIELD_NAME) .startObject(FIELD_NAME) .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) .field(DIMENSION_FIELD_NAME, dimension) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) .startObject(KNNConstants.KNN_METHOD) .field(KNNConstants.NAME, KNNEngine.LUCENE.getMethod(METHOD_HNSW).getMethodComponent().getName()) .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) @@ -384,7 +377,7 @@ private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spac private void baseQueryTest(SpaceType spaceType) throws Exception { - createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType); + createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } @@ -419,4 +412,42 @@ private List queryResults(final float[] searchVector, final int k) thro assertNotNull(knnResults); return knnResults.stream().map(KNNResult::getVector).collect(Collectors.toUnmodifiableList()); } + + @SneakyThrows + private void validateQueryResultsWithFilters( + float[] searchVector, + int kGreaterThanFilterResult, + int kLimitsFilterResult, + List expectedDocIdsKGreaterThanFilterResult, + List expectedDocIdsKLimitsFilterResult + ) { + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIdsKGreaterThanFilterResult.size(), knnResults.size()); + assertTrue( + knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIdsKGreaterThanFilterResult) + ); + + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); + assertTrue( + knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + } } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index 80ec9164f..711160cf9 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -6,15 +6,20 @@ package org.opensearch.knn.index; import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.EntityUtils; import org.junit.After; +import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.common.Strings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.util.KNNEngine; +import java.util.List; import java.util.Locale; import static org.opensearch.knn.common.KNNConstants.DIMENSION; @@ -77,6 +82,73 @@ public void testDeleteDocWithByteVector() { assertEquals(0, getDocCount(INDEX_NAME)); } + @SneakyThrows + public void testSearchWithByteVector() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + Byte[] queryVector = { 1, 1 }; + Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, convertByteToFloatArray(queryVector), 4), 4); + + validateL2SearchResults(response); + } + + @SneakyThrows + public void testSearchWithInvalidByteVector() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + // Validate search with floats instead of byte vectors + float[] queryVector = { -10.76f, 15.89f }; + ResponseException ex = expectThrows( + ResponseException.class, + () -> searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, queryVector, 4), 4) + ); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + ) + ); + + // validate search with search vectors outside of byte range + float[] queryVector1 = { -1000.0f, 200.0f }; + ResponseException ex1 = expectThrows( + ResponseException.class, + () -> searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, queryVector1, 4), 4) + ); + + assertTrue( + ex1.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ) + ); + } + + @SneakyThrows + public void testSearchWithFloatVectorDataType() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue()); + ingestL2FloatTestData(); + + float[] queryVector = { 1.0f, 1.0f }; + Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, queryVector, 4), 4); + + validateL2SearchResults(response); + } + // Set an invalid value for data_type field while creating the index which should throw an exception public void testInvalidVectorDataType() { String vectorDataType = "invalidVectorType"; @@ -176,6 +248,36 @@ public void testByteVectorDataTypeWithNmslibEngine() { ); } + @SneakyThrows + private void ingestL2ByteTestData() { + Byte[] b1 = { 6, 6 }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, b1); + + Byte[] b2 = { 2, 2 }; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, b2); + + Byte[] b3 = { 4, 4 }; + addKnnDoc(INDEX_NAME, "3", FIELD_NAME, b3); + + Byte[] b4 = { 3, 3 }; + addKnnDoc(INDEX_NAME, "4", FIELD_NAME, b4); + } + + @SneakyThrows + private void ingestL2FloatTestData() { + Float[] f1 = { 6.0f, 6.0f }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); + + Float[] f2 = { 2.0f, 2.0f }; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); + + Float[] f3 = { 4.0f, 4.0f }; + addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); + + Float[] f4 = { 3.0f, 3.0f }; + addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); + } + private void createKnnIndexMappingWithNmslibEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception { createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.NMSLIB.getName()); } @@ -209,4 +311,25 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac String mapping = Strings.toString(builder); createKnnIndex(INDEX_NAME, mapping); } + + @SneakyThrows + private void validateL2SearchResults(Response response) { + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + assertEquals(4, results.size()); + + String[] expectedDocIDs = { "2", "4", "3", "1" }; + for (int i = 0; i < results.size(); i++) { + assertEquals(expectedDocIDs[i], results.get(i).getDocId()); + } + } + + private float[] convertByteToFloatArray(Byte[] arr) { + float[] floatArray = new float[arr.length]; + for (int i = 0; i < arr.length; i++) { + floatArray[i] = arr[i]; + } + return floatArray; + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 6bfde31bb..6c7631216 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -69,6 +69,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.Version.CURRENT; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; @@ -338,7 +339,14 @@ public void testKnnVectorIndex( verify(perFieldKnnVectorsFormatSpy, atLeastOnce()).getKnnVectorsFormatForField(eq(FIELD_NAME_ONE)); IndexSearcher searcher = new IndexSearcher(reader); - Query query = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", FIELD_NAME_ONE, new float[] { 1.0f, 0.0f, 0.0f }, 1); + Query query = KNNQueryFactory.create( + KNNEngine.LUCENE, + "dummy", + FIELD_NAME_ONE, + new float[] { 1.0f, 0.0f, 0.0f }, + 1, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); assertEquals(1, searcher.count(query)); @@ -365,7 +373,14 @@ public void testKnnVectorIndex( verify(perFieldKnnVectorsFormatSpy, atLeastOnce()).getKnnVectorsFormatForField(eq(FIELD_NAME_TWO)); IndexSearcher searcher1 = new IndexSearcher(reader1); - Query query1 = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", FIELD_NAME_TWO, new float[] { 1.0f, 0.0f }, 1); + Query query1 = KNNQueryFactory.create( + KNNEngine.LUCENE, + "dummy", + FIELD_NAME_TWO, + new float[] { 1.0f, 0.0f }, + 1, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); assertEquals(1, searcher1.count(query1)); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 978011610..c98f74e62 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -30,6 +30,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -153,6 +154,7 @@ public void testDoToQuery_Normal() throws Exception { KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); @@ -168,6 +170,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); @@ -219,6 +222,7 @@ public void testDoToQuery_FromModel() { // Dimension is -1. In this case, model metadata will need to provide dimension when(mockKNNVectorField.getDimension()).thenReturn(-K); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; when(mockKNNVectorField.getModelId()).thenReturn(modelId); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 674d1be39..4dccfd087 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -24,6 +24,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; public class KNNQueryFactoryTests extends KNNTestCase { private static final String FILTER_FILED_NAME = "foo"; @@ -38,7 +39,14 @@ public class KNNQueryFactoryTests extends KNNTestCase { public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { - Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK); + Query query = KNNQueryFactory.create( + knnEngine, + testIndexName, + testFieldName, + testQueryVector, + testK, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); @@ -53,7 +61,14 @@ public void testCreateLuceneDefaultQuery() { .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) .collect(Collectors.toList()); for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { - Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK); + Query query = KNNQueryFactory.create( + knnEngine, + testIndexName, + testFieldName, + testQueryVector, + testK, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } } @@ -71,6 +86,7 @@ public void testCreateLuceneQueryWithFilter() { .indexName(testIndexName) .fieldName(testFieldName) .vector(testQueryVector) + .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) .k(testK) .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER)