diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
index 80ec9164f..6b5ecf615 100644
--- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
+++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
@@ -6,16 +6,29 @@
 package org.opensearch.knn.index;
 
 import lombok.SneakyThrows;
+import org.apache.hc.core5.http.io.entity.EntityUtils;
 import org.junit.After;
+import org.opensearch.client.Request;
+import org.opensearch.client.Response;
 import org.opensearch.client.ResponseException;
 import org.opensearch.common.Strings;
+import org.opensearch.common.settings.Settings;
 import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.index.query.MatchAllQueryBuilder;
+import org.opensearch.index.query.QueryBuilder;
 import org.opensearch.knn.KNNRestTestCase;
+import org.opensearch.knn.KNNResult;
 import org.opensearch.knn.common.KNNConstants;
 import org.opensearch.knn.index.util.KNNEngine;
+import org.opensearch.rest.RestStatus;
+import org.opensearch.script.Script;
 
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Locale;
+import java.util.Map;
 
 import static org.opensearch.knn.common.KNNConstants.DIMENSION;
 import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
@@ -32,6 +45,7 @@ public class VectorDataTypeIT extends KNNRestTestCase {
     private static final String KNN_VECTOR_TYPE = "knn_vector";
     private static final int EF_CONSTRUCTION = 128;
     private static final int M = 16;
+    private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder();
 
     @After
     @SneakyThrows
@@ -176,6 +190,202 @@ public void testByteVectorDataTypeWithNmslibEngine() {
         );
     }
 
+    @SneakyThrows
+    public void testByteVectorDataTypeWithLegacyFieldMapperKnnIndexSetting() {
+        // Create an index with byte vector data_type and index.knn as true without setting KnnMethodContext,
+        // which should throw an exception because the LegacyFieldMapper will use NMSLIB engine and byte data_type
+        // is not supported for NMSLIB engine.
+        XContentBuilder builder = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject(PROPERTIES_FIELD)
+            .startObject(FIELD_NAME)
+            .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
+            .field(DIMENSION, 2)
+            .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue())
+            .endObject()
+            .endObject()
+            .endObject();
+
+        String mapping = Strings.toString(builder);
+
+        ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping));
+        assertTrue(
+            ex.getMessage()
+                .contains(
+                    String.format(
+                        Locale.ROOT,
+                        "[%s] field with value [%s] is only supported for [%s] engine",
+                        VECTOR_DATA_TYPE_FIELD,
+                        VectorDataType.BYTE.getValue(),
+                        LUCENE_NAME
+                    )
+                )
+        );
+
+    }
+
+    public void testDocValuesWithByteVectorDataTypeLuceneEngine() throws Exception {
+        createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue());
+        ingestL2ByteTestData();
+
+        Byte[] queryVector = { 1, 1 };
+        Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+
+        validateL2SearchResults(response);
+    }
+
+    public void testDocValuesWithFloatVectorDataTypeLuceneEngine() throws Exception {
+        createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
+        ingestL2FloatTestData();
+
+        Float[] queryVector = { 1.0f, 1.0f };
+        Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+
+        validateL2SearchResults(response);
+    }
+
+    public void testL2ScriptScoreWithByteVectorDataType() throws Exception {
+        createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue());
+        ingestL2ByteTestData();
+
+        Byte[] queryVector = { 1, 1 };
+        Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+
+        validateL2SearchResults(response);
+    }
+
+    public void testL2ScriptScoreWithFloatVectorDataType() throws Exception {
+        createKnnIndexMappingForScripting(2, VectorDataType.FLOAT.getValue());
+        ingestL2FloatTestData();
+
+        Float[] queryVector = { 1.0f, 1.0f };
+        Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+
+        validateL2SearchResults(response);
+    }
+
+    public void testL2PainlessScriptingWithByteVectorDataType() throws Exception {
+        createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue());
+        ingestL2ByteTestData();
+
+        String source = String.format(Locale.ROOT, "1/(1 + l2Squared([1, 1], doc['%s']))", FIELD_NAME);
+        Request request = constructScriptScoreContextSearchRequest(
+            INDEX_NAME,
+            MATCH_ALL_QUERY_BUILDER,
+            Collections.emptyMap(),
+            Script.DEFAULT_SCRIPT_LANG,
+            source,
+            4
+        );
+
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+
+        validateL2SearchResults(response);
+    }
+
+    public void testL2PainlessScriptingWithFloatVectorDataType() throws Exception {
+        createKnnIndexMappingForScripting(2, VectorDataType.FLOAT.getValue());
+        ingestL2FloatTestData();
+
+        String source = String.format(Locale.ROOT, "1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
+        Request request = constructScriptScoreContextSearchRequest(
+            INDEX_NAME,
+            MATCH_ALL_QUERY_BUILDER,
+            Collections.emptyMap(),
+            Script.DEFAULT_SCRIPT_LANG,
+            source,
+            4
+        );
+
+        Response response = client().performRequest(request);
+        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+
+        validateL2SearchResults(response);
+    }
+
+    public void testKNNScriptScoreWithInvalidVectorDataType() {
+        // Set an invalid value for data_type field while creating the index for script scoring which should throw an exception
+        ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndexMappingForScripting(2, "invalid_data_type"));
+        assertTrue(
+            ex.getMessage()
+                .contains(
+                    String.format(
+                        Locale.ROOT,
+                        "Invalid value provided for [%s] field. Supported values are [%s]",
+                        VECTOR_DATA_TYPE_FIELD,
+                        SUPPORTED_VECTOR_DATA_TYPES
+                    )
+                )
+        );
+    }
+
+    public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception {
+        // Create an index with byte vector data_type, add docs and run a scoring script query with decimal values
+        // which should throw exception
+        createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue());
+
+        Byte[] f1 = { 6, 6 };
+        addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);
+
+        Byte[] f2 = { 2, 2 };
+        addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2);
+
+        // Construct Search Request with query vector having decimal values
+        Float[] queryVector = { 10.67f, 19.78f };
+        Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
+        ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
+        assertTrue(
+            ex.getMessage()
+                .contains(
+                    String.format(
+                        Locale.ROOT,
+                        "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
+                        VECTOR_DATA_TYPE_FIELD,
+                        VectorDataType.BYTE.getValue()
+                    )
+                )
+        );
+    }
+
+    @SneakyThrows
+    private void ingestL2ByteTestData() {
+        Byte[] b1 = { 6, 6 };
+        addKnnDoc(INDEX_NAME, "1", FIELD_NAME, b1);
+
+        Byte[] b2 = { 2, 2 };
+        addKnnDoc(INDEX_NAME, "2", FIELD_NAME, b2);
+
+        Byte[] b3 = { 4, 4 };
+        addKnnDoc(INDEX_NAME, "3", FIELD_NAME, b3);
+
+        Byte[] b4 = { 3, 3 };
+        addKnnDoc(INDEX_NAME, "4", FIELD_NAME, b4);
+    }
+
+    @SneakyThrows
+    private void ingestL2FloatTestData() {
+        Float[] f1 = { 6.0f, 6.0f };
+        addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);
+
+        Float[] f2 = { 2.0f, 2.0f };
+        addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2);
+
+        Float[] f3 = { 4.0f, 4.0f };
+        addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3);
+
+        Float[] f4 = { 3.0f, 3.0f };
+        addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4);
+    }
+
     private void createKnnIndexMappingWithNmslibEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception {
         createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.NMSLIB.getName());
     }
@@ -209,4 +419,55 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac
         String mapping = Strings.toString(builder);
         createKnnIndex(INDEX_NAME, mapping);
     }
+
+    /** Creates INDEX_NAME with Settings.EMPTY and a knn_vector mapping that sets only type, dimension and data_type. */
+    private void createKnnIndexMappingForScripting(int dimension, String vectorDataType) throws Exception {
+        XContentBuilder builder = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject(PROPERTIES_FIELD)
+            .startObject(FIELD_NAME)
+            .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
+            .field(DIMENSION, dimension)
+            .field(VECTOR_DATA_TYPE_FIELD, vectorDataType)
+            .endObject()
+            .endObject()
+            .endObject();
+
+        String mapping = Strings.toString(builder);
+        createKnnIndex(INDEX_NAME, Settings.EMPTY, mapping);
+    }
+
+    /** Builds a script query request on FIELD_NAME from a byte query vector, a space type and an inner query. */
+    @SneakyThrows
+    private Request createScriptQueryRequest(Byte[] queryVector, String spaceType, QueryBuilder qb) {
+        Map<String, Object> params = new HashMap<>();
+        params.put("field", FIELD_NAME);
+        params.put("query_value", queryVector);
+        params.put("space_type", spaceType);
+        return constructKNNScriptQueryRequest(INDEX_NAME, qb, params);
+    }
+
+    /** Builds a script query request on FIELD_NAME from a float query vector, a space type and an inner query. */
+    @SneakyThrows
+    private Request createScriptQueryRequest(Float[] queryVector, String spaceType, QueryBuilder qb) {
+        Map<String, Object> params = new HashMap<>();
+        params.put("field", FIELD_NAME);
+        params.put("query_value", queryVector);
+        params.put("space_type", spaceType);
+        return constructKNNScriptQueryRequest(INDEX_NAME, qb, params);
+    }
+
+    /** Asserts that the response holds exactly four hits, returned in the doc id order 2, 4, 3, 1. */
+    @SneakyThrows
+    private void validateL2SearchResults(Response response) {
+
+        List<KNNResult> results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME);
+
+        assertEquals(4, results.size());
+
+        String[] expectedDocIDs = { "2", "4", "3", "1" };
+        for (int i = 0; i < results.size(); i++) {
+            assertEquals(expectedDocIDs[i], results.get(i).getDocId());
+        }
+    }
 }
diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java
new file mode 100644
index 000000000..4423c85d8
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index;
+
+import lombok.SneakyThrows;
+import org.apache.lucene.document.BinaryDocValuesField;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.FieldType;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.analysis.MockAnalyzer;
+import org.junit.Assert;
+import org.opensearch.knn.KNNTestCase;
+
+import java.io.IOException;
+
+/** Unit tests that read float and byte vector doc values back through KNNVectorScriptDocValues. */
+public class VectorDataTypeTests extends KNNTestCase {
+
+    private static final String MOCK_FLOAT_INDEX_FIELD_NAME = "test-float-index-field-name";
+    private static final String MOCK_BYTE_INDEX_FIELD_NAME = "test-byte-index-field-name";
+    private static final float[] SAMPLE_FLOAT_VECTOR_DATA = new float[] { 10.0f, 25.0f };
+    private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 10, 25 };
+    private Directory directory;
+    private DirectoryReader reader;
+
+    @SneakyThrows
+    public void testGetDocValuesWithFloatVectorDataType() {
+        KNNVectorScriptDocValues scriptDocValues = getKNNFloatVectorScriptDocValues();
+        try {
+            scriptDocValues.setNextDocId(0);
+            Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f);
+        } finally {
+            // Close the reader and directory even when the assertion above fails
+            reader.close();
+            directory.close();
+        }
+    }
+
+    @SneakyThrows
+    public void testGetDocValuesWithByteVectorDataType() {
+        KNNVectorScriptDocValues scriptDocValues = getKNNByteVectorScriptDocValues();
+        try {
+            scriptDocValues.setNextDocId(0);
+            // The stored byte values { 10, 25 } are expected back as the floats { 10.0f, 25.0f }
+            Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f);
+        } finally {
+            // Close the reader and directory even when the assertion above fails
+            reader.close();
+            directory.close();
+        }
+    }
+
+    /** Indexes one float vector document, opens {@code reader} on it and wraps its binary doc values as FLOAT. */
+    @SneakyThrows
+    private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() {
+        directory = newDirectory();
+        createKNNFloatVectorDocument(directory);
+        reader = DirectoryReader.open(directory);
+        LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
+        return new KNNVectorScriptDocValues(
+            leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME),
+            VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME,
+            VectorDataType.FLOAT
+        );
+    }
+
+    /** Indexes one byte vector document, opens {@code reader} on it and wraps its binary doc values as BYTE. */
+    @SneakyThrows
+    private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() {
+        directory = newDirectory();
+        createKNNByteVectorDocument(directory);
+        reader = DirectoryReader.open(directory);
+        LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
+        return new KNNVectorScriptDocValues(
+            leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME),
+            VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME,
+            VectorDataType.BYTE
+        );
+    }
+
+    /** Writes and commits a single document holding SAMPLE_FLOAT_VECTOR_DATA as a binary doc value. */
+    private void createKNNFloatVectorDocument(Directory directory) throws IOException {
+        IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random()));
+        try (IndexWriter writer = new IndexWriter(directory, conf)) {
+            Document knnDocument = new Document();
+            knnDocument.add(
+                new BinaryDocValuesField(
+                    MOCK_FLOAT_INDEX_FIELD_NAME,
+                    new VectorField(MOCK_FLOAT_INDEX_FIELD_NAME, SAMPLE_FLOAT_VECTOR_DATA, new FieldType()).binaryValue()
+                )
+            );
+            writer.addDocument(knnDocument);
+            writer.commit();
+        }
+    }
+
+    /** Writes and commits a single document holding SAMPLE_BYTE_VECTOR_DATA as a binary doc value. */
+    private void createKNNByteVectorDocument(Directory directory) throws IOException {
+        IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random()));
+        try (IndexWriter writer = new IndexWriter(directory, conf)) {
+            Document knnDocument = new Document();
+            knnDocument.add(
+                new BinaryDocValuesField(
+                    MOCK_BYTE_INDEX_FIELD_NAME,
+                    new VectorField(MOCK_BYTE_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA, new FieldType()).binaryValue()
+                )
+            );
+            writer.addDocument(knnDocument);
+            writer.commit();
+        }
+    }
+}