diff --git a/CHANGELOG.md b/CHANGELOG.md index 5daf3b564c..c594766f32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) +* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 8dc8d1f853..1f21ba6ea4 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -41,17 +41,18 @@ public long ramBytesUsed() { @Override public ScriptDocValues getScriptValues() { try { - DocIdSetIterator values = null; FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName); - System.out.println(fieldInfo); + if (fieldInfo == null) { + return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); + } + + DocIdSetIterator values = null; if (fieldInfo.hasVectorValues()) { values = fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32 ? 
reader.getFloatVectorValues(fieldName) : reader.getByteVectorValues(fieldName); - System.out.println("use vector values"); } else { values = DocValues.getBinary(reader, fieldName); - System.out.println("use binary values"); } return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType); } catch (IOException e) { diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 496997bc61..c733c534e3 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index; import java.io.IOException; +import java.util.Objects; import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -60,7 +61,17 @@ public float[] get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } + /** + * Creates a KNNVectorScriptDocValues object based on the provided parameters. + * + * @param values The DocIdSetIterator representing the vector values. + * @param fieldName The name of the field. + * @param vectorDataType The data type of the vector. + * @return A KNNVectorScriptDocValues object based on the type of the values. + * @throws IllegalArgumentException If the type of values is unsupported. 
+ */ public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + Objects.requireNonNull(values, "values must not be null"); if (values instanceof ByteVectorValues) { return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); } else if (values instanceof FloatVectorValues) { @@ -118,4 +129,20 @@ protected float[] doGetValue() throws IOException { return getVectorDataType().getVectorFromDocValues(values.binaryValue()); } } + + /** + * Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type. + * + * @param fieldName The name of the field. + * @param type The data type of the vector. + * @return An empty KNNVectorScriptDocValues object. + */ + public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + @Override + protected float[] doGetValue() throws IOException { + throw new UnsupportedOperationException("empty values"); + } + }; + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 2e3531b18c..66e2893c0e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -5,6 +5,15 @@ package org.opensearch.knn.index; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.KNNTestCase; 
import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; @@ -13,7 +22,6 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.store.Directory; import org.junit.Assert; import org.junit.Before; @@ -24,6 +32,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name"; private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; + private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 }; private KNNVectorScriptDocValues scriptDocValues; private Directory directory; private DirectoryReader reader; @@ -32,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - createKNNVectorDocument(directory); + Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); + createKNNVectorDocument(directory, valuesClass); reader = DirectoryReader.open(directory); - LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = KNNVectorScriptDocValues.create( - leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), - MOCK_INDEX_FIELD_NAME, - VectorDataType.FLOAT - ); + LeafReader leafReader = reader.getContext().leaves().get(0).reader(); + DocIdSetIterator vectorValues; + if (BinaryDocValues.class.equals(valuesClass)) { + vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME); + } else if (ByteVectorValues.class.equals(valuesClass)) { + vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME); + } else { + vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME); + } + + scriptDocValues = 
KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); } - private void createKNNVectorDocument(Directory directory) throws IOException { + private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( + Field field; + if (BinaryDocValues.class.equals(valuesClass)) { + field = new BinaryDocValuesField( MOCK_INDEX_FIELD_NAME, new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + ); + } else if (ByteVectorValues.class.equals(valuesClass)) { + field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); + } else { + field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + } + + knnDocument.add(field); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -83,4 +105,18 @@ public void testSize() throws IOException { public void testGet() throws IOException { expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); } + + public void testUnsupportedValues() throws IOException { + expectThrows( + IllegalArgumentException.class, + () -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT) + ); + } + + public void testEmptyValues() throws IOException { + KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + assertEquals(0, values.size()); + values.setNextDocId(0); + assertEquals(0, values.size()); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 59c4f8c0ec..aa889100de 100644 ---
a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -5,8 +5,10 @@ package org.opensearch.knn.plugin.script; +import java.io.IOException; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.opensearch.client.Request; @@ -21,6 +23,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.script.Script; import java.util.ArrayList; @@ -35,12 +38,29 @@ import static org.hamcrest.Matchers.containsString; public class KNNScriptScoringIT extends KNNRestTestCase { + private void randomCreateKNNIndex() throws IOException { + if (randomBoolean()) { + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + } else { + createKnnIndex( + INDEX_NAME, + createKnnIndexMapping( + FIELD_NAME, + 2, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + randomBoolean() + ) + ); + } + } public void testKNNL2ScriptScore() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); @@ -93,7 +113,7 @@ public void testKNNL1ScriptScore() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); @@ -146,7 +166,7 @@ public void testKNNLInfScriptScore() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, 
createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); @@ -199,7 +219,7 @@ public void testKNNCosineScriptScore() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); Float[] f1 = { 1.0f, -1.0f }; addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1); @@ -251,7 +271,7 @@ public void testKNNInvalidSourceScript() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); /** * Construct Search Request @@ -293,7 +313,7 @@ public void testInvalidSpace() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); /** * Construct Search Request @@ -316,7 +336,7 @@ public void testMissingParamsInScript() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); /** * Construct Search Request @@ -349,7 +369,7 @@ public void testUnequalDimensions() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); Float[] f1 = { 1.0f, -1.0f }; addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1); @@ -372,7 +392,7 @@ public void testKNNScoreforNonVectorDocument() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); Float[] f1 = { 1.0f, 1.0f }; addDocWithNumericField(INDEX_NAME, "0", "price", 10); addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); @@ -636,7 +656,7 @@ public void testKNNInnerProdScriptScore() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + 
randomCreateKNNIndex(); Float[] f1 = { -2.0f, -2.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); @@ -690,7 +710,7 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { /* * Create knn index and populate data */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + randomCreateKNNIndex(); Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java index 5fa88b0a5f..5325d12053 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java @@ -53,6 +53,10 @@ protected String createMapping(List<MappingProperty> properties) throws IOExcept builder.field("dimension", property.getDimension()); } + if (property.getDocValues() != null) { + builder.field("doc_values", property.getDocValues()); + } + if (property.getKnnMethodContext() != null) { builder.startObject(KNNConstants.KNN_METHOD); property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -554,12 +558,14 @@ public void testScriptedMetricIsSupported() throws Exception { public void testL2ScriptingWithLuceneBackedIndex() throws Exception { List<MappingProperty> properties = new ArrayList<>(); KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.NMSLIB, + KNNEngine.LUCENE, SpaceType.DEFAULT, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); properties.add( - new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext) + new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2") + .knnMethodContext(knnMethodContext) + .docValues(randomBoolean()) ); String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME); @@ -585,6 +591,7 @@ static class MappingProperty { private String dimension; private
KNNMethodContext knnMethodContext; + private Boolean docValues; MappingProperty(String name, String type) { this.name = name; @@ -601,6 +608,11 @@ MappingProperty knnMethodContext(KNNMethodContext knnMethodContext) { return this; } + MappingProperty docValues(boolean docValues) { + this.docValues = docValues; + return this; + } + KNNMethodContext getKnnMethodContext() { return knnMethodContext; } @@ -616,5 +628,9 @@ String getName() { String getType() { return type; } + + Boolean getDocValues() { + return docValues; + } } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 0b6ae3a5e8..68900102b2 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -329,20 +329,7 @@ protected String createKnnIndexMapping(String fieldName, Integer dimensions) thr * Utility to create a Knn Index Mapping with specific algorithm and engine */ protected String createKnnIndexMapping(String fieldName, Integer dimensions, String algoName, String knnEngine) throws IOException { - return XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimensions.toString()) - .startObject("method") - .field("name", algoName) - .field("engine", knnEngine) - .endObject() - .endObject() - .endObject() - .endObject() - .toString(); + return this.createKnnIndexMapping(fieldName, dimensions, algoName, knnEngine, SpaceType.DEFAULT.getValue()); } /** @@ -350,12 +337,27 @@ protected String createKnnIndexMapping(String fieldName, Integer dimensions, Str */ protected String createKnnIndexMapping(String fieldName, Integer dimensions, String algoName, String knnEngine, String spaceType) throws IOException { + return this.createKnnIndexMapping(fieldName, dimensions, algoName, knnEngine, spaceType, true); + } + + /** + * Utility to 
create a Knn Index Mapping with specific algorithm, engine, spaceType and docValues + */ + protected String createKnnIndexMapping( + String fieldName, + Integer dimensions, + String algoName, + String knnEngine, + String spaceType, + boolean docValues + ) throws IOException { return XContentFactory.jsonBuilder() .startObject() .startObject("properties") .startObject(fieldName) .field(KNNConstants.TYPE, KNNConstants.TYPE_KNN_VECTOR) .field(KNNConstants.DIMENSION, dimensions.toString()) + .field("doc_values", docValues) .startObject(KNNConstants.KNN_METHOD) .field(KNNConstants.NAME, algoName) .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType)