diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d9d35ef03..728871ddbc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x) ### Features ### Enhancements +* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index f4caa4f203..85f037c0f1 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -5,9 +5,10 @@ package org.opensearch.knn.index; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; @@ -39,10 +40,29 @@ public long ramBytesUsed() { @Override public ScriptDocValues getScriptValues() { try { - BinaryDocValues values = DocValues.getBinary(reader, fieldName); - return new KNNVectorScriptDocValues(values, fieldName, vectorDataType); + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName); + if (fieldInfo == null) { + return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); + } + + DocIdSetIterator values; + if (fieldInfo.hasVectorValues()) { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + values = reader.getFloatVectorValues(fieldName); + break; + case BYTE: + values = reader.getByteVectorValues(fieldName); + break; + default: + throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding()); + } + } else { + values = DocValues.getBinary(reader, fieldName); + } + return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType); } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); + throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e); } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 9f7d522053..c733c534e3 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -5,18 +5,22 @@ package org.opensearch.knn.index; +import java.io.IOException; +import java.util.Objects; +import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import java.io.IOException; - -@RequiredArgsConstructor -public final class KNNVectorScriptDocValues extends ScriptDocValues { +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public abstract class KNNVectorScriptDocValues extends ScriptDocValues { - private final BinaryDocValues binaryDocValues; + private final DocIdSetIterator vectorValues; private final String fieldName; @Getter private final VectorDataType vectorDataType; @@ -24,11 +28,7 @@ public final class KNNVectorScriptDocValues extends ScriptDocValues { @Override public void setNextDocId(int docId) throws IOException { - if (binaryDocValues.advanceExact(docId)) { - docExists = true; - return; - } - docExists = false; + docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId; } public float[] getValue() { @@ -43,12 +43,14 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue()); + return doGetValue(); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } } + protected abstract float[] doGetValue() throws IOException; + @Override public int size() { return docExists ? 1 : 0; @@ -58,4 +60,89 @@ public int size() { public float[] get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } + + /** + * Creates a KNNVectorScriptDocValues object based on the provided parameters. + * + * @param values The DocIdSetIterator representing the vector values. + * @param fieldName The name of the field. + * @param vectorDataType The data type of the vector. + * @return A KNNVectorScriptDocValues object based on the type of the values. + * @throws IllegalArgumentException If the type of values is unsupported. + */ + public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + Objects.requireNonNull(values, "values must not be null"); + if (values instanceof ByteVectorValues) { + return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof FloatVectorValues) { + return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof BinaryDocValues) { + return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); + } else { + throw new IllegalArgumentException("Unsupported values type: " + values.getClass()); + } + } + + private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { + private final ByteVectorValues values; + + KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + byte[] bytes = values.vectorValue(); + float[] value = new float[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + value[i] = (float) bytes[i]; + } + return value; + } + } + + private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { + private final FloatVectorValues values; + + KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return values.vectorValue(); + } + } + + private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { + private final BinaryDocValues values; + + KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return getVectorDataType().getVectorFromDocValues(values.binaryValue()); + } + } + + /** + * Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type. + * + * @param fieldName The name of the field. + * @param type The data type of the vector. + * @return An empty KNNVectorScriptDocValues object. + */ + public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + @Override + protected float[] doGetValue() throws IOException { + throw new UnsupportedOperationException("empty values"); + } + }; + } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 13d1bc64d6..72ffd2b66e 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -147,7 +147,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), @@ -257,7 +257,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), @@ -827,7 +827,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index a0df3ce640..66e2893c0e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -5,6 +5,15 @@ package org.opensearch.knn.index; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; @@ -13,7 +22,6 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.store.Directory; import org.junit.Assert; import org.junit.Before; @@ -24,6 +32,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name"; private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; + private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 }; private KNNVectorScriptDocValues scriptDocValues; private Directory directory; private DirectoryReader reader; @@ -32,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - createKNNVectorDocument(directory); + Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); + createKNNVectorDocument(directory, valuesClass); reader = DirectoryReader.open(directory); - LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( - leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), - MOCK_INDEX_FIELD_NAME, - VectorDataType.FLOAT - ); + LeafReader leafReader = reader.getContext().leaves().get(0).reader(); + DocIdSetIterator vectorValues; + if (BinaryDocValues.class.equals(valuesClass)) { + vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME); + } else if (ByteVectorValues.class.equals(valuesClass)) { + vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME); + } else { + vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME); + } + + scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); } - private void createKNNVectorDocument(Directory directory) throws IOException { + private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( + Field field; + if (BinaryDocValues.class.equals(valuesClass)) { + field = new BinaryDocValuesField( MOCK_INDEX_FIELD_NAME, new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + ); + } else if (ByteVectorValues.class.equals(valuesClass)) { + field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); + } else { + field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + } + + knnDocument.add(field); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -83,4 +105,18 @@ public void testSize() throws IOException { public void testGet() throws IOException { expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); } + + public void testUnsupportedValues() throws IOException { + expectThrows( + IllegalArgumentException.class, + () -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT) + ); + } + + public void testEmptyValues() throws IOException { + KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + assertEquals(0, values.size()); + scriptDocValues.setNextDocId(0); + assertEquals(0, values.size()); + } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 083b5b3704..c53fa4456b 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -7,7 +7,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.primitives.Floats; import org.apache.http.util.EntityUtils; import lombok.SneakyThrows; import org.apache.commons.lang.math.RandomUtils; @@ -307,14 +306,14 @@ public void testIndexReopening() throws Exception { final float[] searchVector = TEST_QUERY_VECTORS[0]; final int k = 1 + RandomUtils.nextInt(TEST_INDEX_VECTORS.length); - final List knnResultsBeforeIndexClosure = queryResults(searchVector, k); + final List knnResultsBeforeIndexClosure = queryResults(searchVector, k); closeIndex(INDEX_NAME); openIndex(INDEX_NAME); ensureGreen(INDEX_NAME); - final List knnResultsAfterIndexClosure = queryResults(searchVector, k); + final List knnResultsAfterIndexClosure = queryResults(searchVector, k); assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } @@ -365,7 +364,7 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws IOExc List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, queryVector); float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance); assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(j), 0.0001); @@ -373,7 +372,7 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws IOExc } } - private List queryResults(final float[] searchVector, final int k) throws Exception { + private List queryResults(final float[] searchVector, final int k) throws Exception { final String responseBody = EntityUtils.toString( searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity() ); diff --git a/src/test/java/org/opensearch/knn/index/NmslibIT.java b/src/test/java/org/opensearch/knn/index/NmslibIT.java index cb15a0eb95..b76a26e69e 100644 --- a/src/test/java/org/opensearch/knn/index/NmslibIT.java +++ b/src/test/java/org/opensearch/knn/index/NmslibIT.java @@ -30,11 +30,9 @@ import java.io.IOException; import java.net.URL; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.TreeMap; -import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; @@ -115,7 +113,7 @@ public void testEndToEnd() throws IOException, InterruptedException { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 24dcf7b1c7..f8948db266 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -39,7 +39,6 @@ import java.util.List; import java.util.Map; import java.util.TreeMap; -import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; @@ -143,7 +142,7 @@ public void testEndToEnd() throws IOException, InterruptedException { List actualScores = parseSearchResponseScore(responseBody, fieldName1); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( knnEngine1.score(1 - KNNScoringUtil.cosinesimil(testData.queries[i], primitiveArray), spaceType1), actualScores.get(j), @@ -159,7 +158,7 @@ public void testEndToEnd() throws IOException, InterruptedException { actualScores = parseSearchResponseScore(responseBody, fieldName2); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( knnEngine2.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType2), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 4423c85d8f..19270717d1 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { createKNNFloatVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, VectorDataType.FLOAT @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { createKNNByteVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, VectorDataType.BYTE diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 8c43a4acf3..22110accd0 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( + scriptDocValues = KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName, VectorDataType.FLOAT diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 214ecd1584..58cdb31121 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -5,15 +5,18 @@ package org.opensearch.knn.plugin.script; -import org.opensearch.core.xcontent.MediaTypeRegistry; +import java.util.function.BiFunction; +import java.util.function.Function; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.apache.http.util.EntityUtils; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; @@ -22,6 +25,9 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.script.Script; import java.util.ArrayList; @@ -38,214 +44,19 @@ public class KNNScriptScoringIT extends KNNRestTestCase { public void testKNNL2ScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 2.0f, 2.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 4.0f, 4.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [2.0, 2.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.L2.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("2", "4", "3", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("2", results.get(0).getDocId()); - assertEquals("4", results.get(1).getDocId()); - assertEquals("3", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.L2); } public void testKNNL1ScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 4.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 5.0f, 5.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [1.0, 1.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.L1); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("2", "4", "3", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("2", results.get(0).getDocId()); - assertEquals("3", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.L1); } public void testKNNLInfScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 4.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 5.0f, 5.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [1.0, 1.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.LINF.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("3", "2", "4", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("3", results.get(0).getDocId()); - assertEquals("2", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.LINF); } public void testKNNCosineScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { 1.0f, -1.0f }; - addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1); - - Float[] f2 = { 1.0f, 0.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f2); - - Float[] f3 = { 1.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f3); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "query_value": [2.0, 2.0], - * "space_type": "L2" - * } - * - * - */ - float[] queryVector = { 2.0f, -2.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.COSINESIMIL.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("0", "1", "2"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(3, results.size()); - - // assert document order - assertEquals("0", results.get(0).getDocId()); - assertEquals("1", results.get(1).getDocId()); - assertEquals("2", results.get(2).getDocId()); + testKNNScriptScore(SpaceType.COSINESIMIL); } public void testKNNInvalidSourceScript() throws Exception { @@ -395,10 +206,7 @@ public void testKNNScoreforNonVectorDocument() throws Exception { List hits = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody).map() .get("hits")).get("hits"); - List docIds = hits.stream().map(hit -> { - String id = ((String) ((Map) hit).get("_id")); - return id; - }).collect(Collectors.toList()); + List docIds = hits.stream().map(hit -> ((String) ((Map) hit).get("_id"))).collect(Collectors.toList()); // assert document order assertEquals("1", docIds.get(0)); assertEquals("0", docIds.get(1)); @@ -624,57 +432,7 @@ public void testHammingScriptScore_Base64() throws Exception { } public void testKNNInnerProdScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { -2.0f, -2.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 1.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 2.0f, 2.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 2.0f, -2.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "query_value": [1.0, 1.0], - * "space_type": "innerproduct", - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.INNER_PRODUCT.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("3", "2", "4", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("3", results.get(0).getDocId()); - assertEquals("2", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.INNER_PRODUCT); } public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { @@ -782,4 +540,121 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { // assert that the request cache was hit at second request assertEquals(1, secondQueryCacheMap.get("hit_count")); } + + private List createMappers(int dimensions) throws Exception { + return List.of( + createKnnIndexMapping(FIELD_NAME, dimensions), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + true + ), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + false + ) + ); + } + + private float[] randomVector(int dimensions) { + final float[] vector = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + vector[i] = randomFloat(); + } + return vector; + } + + private Map createDataset(Function scoreFunction, int dimensions, int numDocs) { + final Map dataset = new HashMap<>(numDocs); + for (int i = 0; i < numDocs; i++) { + final float[] vector = randomVector(dimensions); + final float score = scoreFunction.apply(vector); + dataset.put(Integer.toString(i), new KNNResult(Integer.toString(i), vector, score)); + } + return dataset; + } + + private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector) { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + FIELD_NAME, + Collections.emptyMap(), + queryVector.length, + VectorDataType.FLOAT, + null + ); + List target = new ArrayList<>(queryVector.length); + for (float f : queryVector) { + target.add(f); + } + KNNScoringSpace knnScoringSpace = KNNScoringSpaceFactory.create(spaceType.getValue(), target, knnVectorFieldType); + switch (spaceType) { + case L1: + return ((KNNScoringSpace.L1) knnScoringSpace).scoringMethod; + case L2: + return ((KNNScoringSpace.L2) knnScoringSpace).scoringMethod; + case LINF: + return ((KNNScoringSpace.LInf) knnScoringSpace).scoringMethod; + case COSINESIMIL: + return ((KNNScoringSpace.CosineSimilarity) knnScoringSpace).scoringMethod; + case INNER_PRODUCT: + return ((KNNScoringSpace.InnerProd) knnScoringSpace).scoringMethod; + default: + throw new IllegalArgumentException(); + } + } + + private void testKNNScriptScore(SpaceType spaceType) throws Exception { + final int dims = randomIntBetween(2, 10); + final float[] queryVector = randomVector(dims); + final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + for (String mapper : createMappers(dims)) { + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector); + } + } + + private void createIndexAndAssertScriptScore( + String mapper, + SpaceType spaceType, + BiFunction scoreFunction, + int dimensions, + float[] queryVector + ) throws Exception { + /* + * Create knn index and populate data + */ + createKnnIndex(INDEX_NAME, mapper); + Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, randomIntBetween(4, 10)); + for (Map.Entry entry : dataset.entrySet()) { + addKnnDoc(INDEX_NAME, entry.getKey(), FIELD_NAME, entry.getValue().getVector()); + } + + /** + * Construct Search Request + */ + QueryBuilder qb = new MatchAllQueryBuilder(); + Map params = new HashMap<>(); + /* + * params": { + * "field": FIELD_NAME, + * "vector": queryVector + * } + */ + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", spaceType.getValue()); + Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + assertTrue(results.stream().allMatch(r -> dataset.get(r.getDocId()).equals(r))); + deleteKNNIndex(INDEX_NAME); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java index 15e3732b24..47e914d040 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java @@ -53,6 +53,10 @@ protected String createMapping(List properties) throws IOExcept builder.field("dimension", property.getDimension()); } + if (property.getDocValues() != null) { + builder.field("doc_values", property.getDocValues()); + } + if (property.getKnnMethodContext() != null) { builder.startObject(KNNConstants.KNN_METHOD); property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -554,12 +558,14 @@ public void testScriptedMetricIsSupported() throws Exception { public void testL2ScriptingWithLuceneBackedIndex() throws Exception { List properties = new ArrayList<>(); KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.NMSLIB, + KNNEngine.LUCENE, SpaceType.DEFAULT, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); properties.add( - new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext) + new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2") + .knnMethodContext(knnMethodContext) + .docValues(randomBoolean()) ); String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME); @@ -585,6 +591,7 @@ static class MappingProperty { private String dimension; private KNNMethodContext knnMethodContext; + private Boolean docValues; MappingProperty(String name, String type) { this.name = name; @@ -601,6 +608,11 @@ MappingProperty knnMethodContext(KNNMethodContext knnMethodContext) { return this; } + MappingProperty docValues(boolean docValues) { + this.docValues = docValues; + return this; + } + KNNMethodContext getKnnMethodContext() { return knnMethodContext; } @@ -616,5 +628,9 @@ String getName() { String getType() { return type; } + + Boolean getDocValues() { + return docValues; + } } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 6897091afb..60c4016480 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -243,10 +243,16 @@ protected List parseSearchResponse(String responseBody, String fieldN @SuppressWarnings("unchecked") List knnSearchResponses = hits.stream().map(hit -> { @SuppressWarnings("unchecked") - Float[] vector = Arrays.stream( - ((ArrayList) ((Map) ((Map) hit).get("_source")).get(fieldName)).toArray() - ).map(Object::toString).map(Float::valueOf).toArray(Float[]::new); - return new KNNResult((String) ((Map) hit).get("_id"), vector); + final float[] vector = Floats.toArray( + Arrays.stream( + ((ArrayList) ((Map) ((Map) hit).get("_source")).get(fieldName)).toArray() + ).map(Object::toString).map(Float::valueOf).collect(Collectors.toList()) + ); + return new KNNResult( + (String) ((Map) hit).get("_id"), + vector, + ((Double) ((Map) hit).get("_score")).floatValue() + ); }).collect(Collectors.toList()); return knnSearchResponses; @@ -323,20 +329,7 @@ protected String createKnnIndexMapping(String fieldName, Integer dimensions) thr * Utility to create a Knn Index Mapping with specific algorithm and engine */ protected String createKnnIndexMapping(String fieldName, Integer dimensions, String algoName, String knnEngine) throws IOException { - return XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimensions.toString()) - .startObject("method") - .field("name", algoName) - .field("engine", knnEngine) - .endObject() - .endObject() - .endObject() - .endObject() - .toString(); + return this.createKnnIndexMapping(fieldName, dimensions, algoName, knnEngine, SpaceType.DEFAULT.getValue()); } /** @@ -344,12 +337,27 @@ protected String createKnnIndexMapping(String fieldName, Integer dimensions, Str */ protected String createKnnIndexMapping(String fieldName, Integer dimensions, String algoName, String knnEngine, String spaceType) throws IOException { + return this.createKnnIndexMapping(fieldName, dimensions, algoName, knnEngine, spaceType, true); + } + + /** + * Utility to create a Knn Index Mapping with specific algorithm, engine, spaceType and docValues + */ + protected String createKnnIndexMapping( + String fieldName, + Integer dimensions, + String algoName, + String knnEngine, + String spaceType, + boolean docValues + ) throws IOException { return XContentFactory.jsonBuilder() .startObject() .startObject("properties") .startObject(fieldName) .field(KNNConstants.TYPE, KNNConstants.TYPE_KNN_VECTOR) .field(KNNConstants.DIMENSION, dimensions.toString()) + .field("doc_values", docValues) .startObject(KNNConstants.KNN_METHOD) .field(KNNConstants.NAME, algoName) .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType) @@ -474,7 +482,7 @@ protected void forceMergeKnnIndex(String index, int maxSegments) throws Exceptio /** * Add a single KNN Doc to an index */ - protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { + protected void addKnnDoc(String index, String docId, String fieldName, T vector) throws IOException { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject(); @@ -1014,8 +1022,7 @@ public float[][] getIndexVectorsFromIndex(String testIndex, String testField, in int i = 0; for (KNNResult result : results) { - float[] primitiveArray = Floats.toArray(Arrays.stream(result.getVector()).collect(Collectors.toList())); - vectors[i++] = primitiveArray; + vectors[i++] = result.getVector(); } return vectors; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNResult.java b/src/testFixtures/java/org/opensearch/knn/KNNResult.java index 803c2ae720..ee2ba39f7e 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNResult.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNResult.java @@ -5,20 +5,41 @@ package org.opensearch.knn; +import java.util.Arrays; +import java.util.Objects; + public class KNNResult { + private final static float delta = 1e-3f; + private String docId; - private Float[] vector; + private float[] vector; + private Float score; - public KNNResult(String docId, Float[] vector) { + public KNNResult(String docId, float[] vector, Float score) { this.docId = docId; this.vector = vector; + this.score = score; } public String getDocId() { return docId; } - public Float[] getVector() { + public float[] getVector() { return vector; } + + public Float getScore() { + return score; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KNNResult knnResult = (KNNResult) o; + return Objects.equals(docId, knnResult.docId) + && Arrays.equals(vector, knnResult.vector) + && (Float.compare(score, knnResult.score) == 0 || Math.abs(score - knnResult.score) <= delta); + } }