From 12f9bece9feaf6f3e1f1b17af7c596e3e47e03c6 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 21 Mar 2022 13:01:17 -0700 Subject: [PATCH] Initial integration with OS 2.0 Alpha1 Signed-off-by: Martin Gaievski --- build.gradle | 2 +- .../org/opensearch/knn/index/KNNQuery.java | 6 + .../org/opensearch/knn/index/KNNWeight.java | 4 - .../index/codec/KNN80Codec/KNN80Codec.java | 6 + .../codec/KNN80Codec/KNN80CompoundFormat.java | 2 +- .../KNN80Codec/KNN80DocValuesFormat.java | 2 +- .../index/codec/KNN84Codec/KNN84Codec.java | 6 + .../index/codec/KNN86Codec/KNN86Codec.java | 6 + .../index/codec/KNN87Codec/KNN87Codec.java | 2 +- .../KNN91Codec/KNN91BinaryDocValues.java | 64 +++ .../index/codec/KNN91Codec/KNN91Codec.java | 58 +++ .../codec/KNN91Codec/KNN91CompoundFormat.java | 68 +++ .../KNN91Codec/KNN91DocValuesConsumer.java | 257 +++++++++++ .../KNN91Codec/KNN91DocValuesFormat.java | 46 ++ .../KNN91Codec/KNN91DocValuesReader.java | 53 +++ .../knn/index/codec/KNNCodecService.java | 4 +- .../org/opensearch/knn/indices/ModelDao.java | 11 +- .../services/org.apache.lucene.codecs.Codec | 3 +- .../java/org/opensearch/knn/TestUtils.java | 4 +- .../index/KNNCreateIndexFromModelTests.java | 25 +- .../index/KNNVectorDVLeafFieldDataTests.java | 2 +- .../index/KNNVectorIndexFieldDataTests.java | 2 +- .../index/KNNVectorScriptDocValuesTests.java | 2 +- .../codec/KNN87Codec/KNN87CodecTests.java | 2 + .../KNN91Codec/KNN91BinaryDocValuesTests.java | 69 +++ .../codec/KNN91Codec/KNN91CodecTests.java | 22 + .../KNN91Codec/KNN91CompoundFormatTests.java | 92 ++++ .../KNN91DocValuesConsumerTests.java | 412 ++++++++++++++++++ .../knn/index/codec/KNNCodecTestCase.java | 6 +- .../knn/index/codec/KNNCodecTestUtil.java | 22 +- .../plugin/script/KNNScoringUtilTests.java | 2 +- 31 files changed, 1226 insertions(+), 36 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91BinaryDocValues.java create mode 100644 
src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CompoundFormat.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesConsumer.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesFormat.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesReader.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91BinaryDocValuesTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CodecTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CompoundFormatTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesConsumerTests.java diff --git a/build.gradle b/build.gradle index 97e9225166..b950cad51a 100644 --- a/build.gradle +++ b/build.gradle @@ -15,7 +15,7 @@ buildscript { // ".0--" opensearch_version = System.getProperty("opensearch.version", "2.0.0-SNAPSHOT") knn_bwc_version = System.getProperty("bwc.version", "1.2.0.0-SNAPSHOT") - version_qualifier = System.getProperty("build.version_qualifier", "") + version_qualifier = System.getProperty("build.version_qualifier", "alpha1") opensearch_bwc_version = "${knn_bwc_version}" - ".0-SNAPSHOT" opensearch_group = "org.opensearch" diff --git a/src/main/java/org/opensearch/knn/index/KNNQuery.java b/src/main/java/org/opensearch/knn/index/KNNQuery.java index f3bf23aeef..76709ae7e2 100644 --- a/src/main/java/org/opensearch/knn/index/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/KNNQuery.java @@ -7,6 +7,7 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; @@ -59,6 +60,11 @@ public Weight 
createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo return new KNNWeight(this, boost); } + @Override + public void visit(QueryVisitor visitor) { + + } + @Override public String toString(String field) { return field; diff --git a/src/main/java/org/opensearch/knn/index/KNNWeight.java b/src/main/java/org/opensearch/knn/index/KNNWeight.java index 6c8c41cb99..3f777c9703 100644 --- a/src/main/java/org/opensearch/knn/index/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/KNNWeight.java @@ -77,10 +77,6 @@ public Explanation explain(LeafReaderContext context, int doc) { return Explanation.match(1.0f, "No Explanation"); } - @Override - public void extractTerms(Set terms) { - } - @Override public Scorer scorer(LeafReaderContext context) throws IOException { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java index 59655762e4..5eb36e748b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java @@ -11,6 +11,7 @@ import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FieldInfosFormat; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.LiveDocsFormat; import org.apache.lucene.codecs.NormsFormat; import org.apache.lucene.codecs.PointsFormat; @@ -112,4 +113,9 @@ public CompoundFormat compoundFormat() { public PointsFormat pointsFormat() { return getDelegatee().pointsFormat(); } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + throw new UnsupportedOperationException("Codec does not support knn vector format"); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java 
b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java index 8a8ed558e1..d001cd1baa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java @@ -5,7 +5,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import org.apache.lucene.codecs.lucene50.Lucene50CompoundFormat; +import org.apache.lucene.backward_codecs.lucene50.Lucene50CompoundFormat; import org.opensearch.knn.common.KNNConstants; import org.apache.lucene.codecs.CompoundDirectory; import org.apache.lucene.codecs.CompoundFormat; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java index cd8362cf21..fe329eb1c8 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java @@ -8,7 +8,7 @@ import org.apache.lucene.codecs.DocValuesConsumer; import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.codecs.lucene80.Lucene80DocValuesFormat; +import org.apache.lucene.backward_codecs.lucene80.Lucene80DocValuesFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java index a50f396a48..c4ae7ab2e9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.KNN84Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import 
org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.apache.logging.log4j.LogManager; @@ -116,4 +117,9 @@ public CompoundFormat compoundFormat() { public PointsFormat pointsFormat() { return getDelegatee().pointsFormat(); } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return getDelegatee().knnVectorsFormat(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java index 70c75e09b5..154f44f1f9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.KNN86Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.apache.logging.log4j.LogManager; @@ -125,4 +126,9 @@ public CompoundFormat compoundFormat() { public PointsFormat pointsFormat() { return getDelegatee().pointsFormat(); } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return getDelegatee().knnVectorsFormat(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java index 6ec5ec05c1..20799648c1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index.codec.KNN87Codec; import org.apache.lucene.codecs.FilterCodec; -import org.apache.lucene.codecs.lucene87.Lucene87Codec; +import org.apache.lucene.backward_codecs.lucene87.Lucene87Codec; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import 
org.apache.lucene.codecs.Codec;
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91BinaryDocValues.java b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91BinaryDocValues.java
new file mode 100644
index 0000000000..f9f89e5894
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91BinaryDocValues.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.codec.KNN91Codec;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.index.DocIDMerger;
+import org.apache.lucene.util.BytesRef;
+import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
+
+import java.io.IOException;
+
+/**
+ * Forward-only {@link BinaryDocValues} over the subs of a {@link DocIDMerger}; only nextDoc(), docID() and binaryValue() are supported.
+ */
+class KNN91BinaryDocValues extends BinaryDocValues {
+
+    private final DocIDMerger<BinaryDocValuesSub> docIDMerger; // source of the subs, consumed one per nextDoc() call
+
+    KNN91BinaryDocValues(DocIDMerger<BinaryDocValuesSub> docIdMerger) {
+        this.docIDMerger = docIdMerger;
+    }
+
+    private BinaryDocValuesSub current; // sub returned by the last nextDoc(); null before the first call and once exhausted
+    private int docID = -1; // -1 until nextDoc() is called, NO_MORE_DOCS once the merger returns null
+
+    @Override
+    public int docID() {
+        return docID;
+    }
+
+    @Override
+    public int nextDoc() throws IOException {
+        current = docIDMerger.next();
+        if (current == null) {
+            docID = NO_MORE_DOCS;
+        } else {
+            docID = current.mappedDocID;
+        }
+        return docID;
+    }
+
+    @Override
+    public int advance(int target) throws IOException {
+        throw new UnsupportedOperationException(); // only sequential iteration via nextDoc() is implemented
+    }
+
+    @Override
+    public boolean advanceExact(int target) throws IOException {
+        throw new UnsupportedOperationException(); // only sequential iteration via nextDoc() is implemented
+    }
+
+    @Override
+    public long cost() {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public BytesRef binaryValue() throws IOException {
+        return current.getValues().binaryValue(); // valid only while positioned on a document (current != null)
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java new file mode 100644 index 0000000000..1d3bb95801 --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.codec.KNN91Codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.lucene91.Lucene91Codec; + +/** + * Extends the Codec to support a new file format for KNN index + * based on the mappings. + * + */ +public final class KNN91Codec extends FilterCodec { + + private final DocValuesFormat docValuesFormat; + private final CompoundFormat compoundFormat; + + public static final String KNN_91 = "KNN91Codec"; + + /** + * No arg constructor that uses Lucene91 as the delegate + */ + public KNN91Codec() { + this(new Lucene91Codec()); + } + /** + * Constructor that takes a Codec delegate to delegate all methods this code does not implement to. 
+ * + * @param delegate codec that will perform all operations this codec does not override + */ + public KNN91Codec(Codec delegate) { + super(KNN_91, delegate); + this.docValuesFormat = new KNN91DocValuesFormat(delegate.docValuesFormat()); + this.compoundFormat = new KNN91CompoundFormat(delegate.compoundFormat()); + } + + @Override + public DocValuesFormat docValuesFormat() { + return this.docValuesFormat; + } + + @Override + public CompoundFormat compoundFormat() { + return this.compoundFormat; + } +}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CompoundFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CompoundFormat.java
new file mode 100644
index 0000000000..6d8b0fdb76
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CompoundFormat.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.codec.KNN91Codec;
+
+import org.apache.lucene.codecs.CompoundDirectory;
+import org.apache.lucene.codecs.CompoundFormat;
+import org.apache.lucene.index.SegmentInfo;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.opensearch.knn.common.KNNConstants;
+import org.opensearch.knn.index.util.KNNEngine;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * CompoundFormat wrapper: engine files are copied to a name suffixed with COMPOUND_EXTENSION and dropped from the segment file list; the rest is delegated.
+ */
+public class KNN91CompoundFormat extends CompoundFormat {
+
+    private final CompoundFormat delegate;
+
+    /**
+     * Constructor that takes a delegate to handle non-overridden methods
+     *
+     * @param delegate CompoundFormat that will handle non-overridden methods
+     */
+    public KNN91CompoundFormat(CompoundFormat delegate) {
+        this.delegate = delegate;
+    }
+
+    @Override
+    public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException {
+        return delegate.getCompoundReader(dir, si, context);
+    }
+
+    @Override
+    public void write(Directory dir, SegmentInfo si, IOContext context) throws IOException {
+        for (KNNEngine knnEngine : KNNEngine.values()) {
+            writeEngineFiles(dir, si, context, knnEngine.getExtension());
+        }
+        delegate.write(dir, si, context); // runs after the engine files have been removed from si's file list
+    }
+
+    private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) throws IOException {
+        /*
+         * If engine file present, remove it from the compounding file list to avoid header/footer checks
+         * and create a new compounding file format with extension engine + c.
+         */
+        Set<String> engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)).collect(Collectors.toSet());
+
+        Set<String> segmentFiles = new HashSet<>(si.files());
+
+        if (!engineFiles.isEmpty()) {
+            for (String engineFile : engineFiles) {
+                String engineCompoundFile = engineFile + KNNConstants.COMPOUND_EXTENSION;
+                dir.copyFrom(dir, engineFile, engineCompoundFile, context);
+            }
+            segmentFiles.removeAll(engineFiles);
+            si.setFiles(segmentFiles);
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesConsumer.java new file mode 100644 index 0000000000..058ab755b7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesConsumer.java @@ -0,0 +1,257 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN91Codec; + +import com.google.common.collect.ImmutableMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.DocValuesConsumer; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValuesType; +import 
org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.store.FilterDirectory;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IOUtils;
+import org.opensearch.common.xcontent.DeprecationHandler;
+import org.opensearch.common.xcontent.NamedXContentRegistry;
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.knn.common.KNNConstants;
+import org.opensearch.knn.index.KNNSettings;
+import org.opensearch.knn.index.KNNVectorFieldMapper;
+import org.opensearch.knn.index.SpaceType;
+import org.opensearch.knn.index.codec.util.KNNCodecUtil;
+import org.opensearch.knn.index.util.KNNEngine;
+import org.opensearch.knn.indices.Model;
+import org.opensearch.knn.indices.ModelCache;
+import org.opensearch.knn.jni.JNIService;
+import org.opensearch.knn.plugin.stats.KNNCounter;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
+import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
+import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
+
+/**
+ * Writes doc values through the delegate and, for binary fields flagged as k-NN fields, also builds the engine index file for the segment
+ */
+class KNN91DocValuesConsumer extends DocValuesConsumer implements Closeable {
+
+    private static final Logger logger = LogManager.getLogger(KNN91DocValuesConsumer.class);
+
+    private static final String TEMP_SUFFIX = "tmp"; // appended to the engine file name while the footer has not been written yet
+    private final DocValuesConsumer delegatee;
+    private final SegmentWriteState state;
+
+    KNN91DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) throws IOException {
+        this.delegatee = delegatee;
+        this.state = state;
+    }
+
+    @Override
+    public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
+        delegatee.addBinaryField(field, valuesProducer);
+        if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
+            addKNNBinaryField(field, valuesProducer);
+        }
+    }
+
+    public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
+
+        // Get values to be indexed
+        BinaryDocValues values = valuesProducer.getBinary(field);
+        KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
+        if (pair.vectors.length == 0 || pair.docs.length == 0) {
+            logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
+            return;
+        }
+
+        // Increment counter for number of graph index requests
+        KNNCounter.GRAPH_INDEX_REQUESTS.increment();
+
+        // Create library index either from model or from scratch
+        String engineFileName;
+        String indexPath;
+        String tmpEngineFileName;
+
+        if (field.attributes().containsKey(MODEL_ID)) {
+
+            String modelId = field.attributes().get(MODEL_ID);
+            Model model = ModelCache.getInstance().get(modelId);
+
+            KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();
+
+            engineFileName = buildEngineFileName(
+                state.segmentInfo.name,
+                knnEngine.getLatestBuildVersion(),
+                field.name,
+                knnEngine.getExtension()
+            );
+            indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
+                .toString();
+            tmpEngineFileName = engineFileName + TEMP_SUFFIX;
+            String tempIndexPath = indexPath + TEMP_SUFFIX;
+
+            if (model.getModelBlob() == null) {
+                throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
+            }
+
+            createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
+        } else {
+
+            // Get engine to be used for indexing
+            String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
+            KNNEngine knnEngine = KNNEngine.getEngine(engineName);
+
+            engineFileName = buildEngineFileName(
+                state.segmentInfo.name,
+                knnEngine.getLatestBuildVersion(),
+                field.name,
+                knnEngine.getExtension()
+            );
+            indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
+                .toString();
+            tmpEngineFileName = engineFileName + TEMP_SUFFIX;
+            String tempIndexPath = indexPath + TEMP_SUFFIX;
+
+            createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
+        }
+
+        /*
+         * Adds Footer to the serialized graph
+         * 1. Copies the serialized graph to new file.
+         * 2. Adds Footer to the new file.
+         *
+         * We had to create new file here because adding footer directly to the
+         * existing file will miss calculating checksum for the serialized graph
+         * bytes and result in index corruption issues.
+         */
+        // TODO: I think this can be refactored to avoid this copy and then write
+        // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
+        try (
+            IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
+            IndexOutput os = state.directory.createOutput(engineFileName, state.context)
+        ) {
+            os.copyBytes(is, is.length());
+            CodecUtil.writeFooter(os);
+        } catch (Exception ex) {
+            KNNCounter.GRAPH_INDEX_ERRORS.increment();
+            throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex, ex); // keep the cause and its stack trace
+        } finally {
+            IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
+        }
+    }
+
+    private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
+        Map<String, Object> parameters = ImmutableMap.of(
+            KNNConstants.INDEX_THREAD_QTY,
+            KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
+        );
+        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
+            JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName());
+            return null;
+        });
+    }
+
+    private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
+        throws IOException {
+        Map<String, Object> parameters = new HashMap<>();
+        Map<String, String> fieldAttributes = fieldInfo.attributes();
+        String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS);
+
+        // parametersString will be null when legacy mapper is used
+        if (parametersString == null) {
+            parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()));
+
+            String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION);
+            Map<String, Object> algoParams = new HashMap<>();
+            if (efConstruction != null) {
+                algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction));
+            }
+
+            String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M);
+            if (m != null) {
+                algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m));
+            }
+            parameters.put(PARAMETERS, algoParams);
+        } else {
+            parameters.putAll(
+                XContentFactory.xContent(XContentType.JSON)
+                    .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString)
+                    .map()
+            );
+        }
+
+        // Used to determine how many threads to use when indexing
+        parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
+
+        // Pass the path for the nms library to save the file
+        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
+            JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName());
+            return null;
+        });
+    }
+
+    /**
+     * Merges in the fields from the readers in mergeState
+     *
+     * @param mergeState Holds common state used during segment merging
+     */
+    @Override
+    public void merge(MergeState mergeState) {
+        try {
+            assert mergeState != null;
+            assert mergeState.mergeFieldInfos != null;
+            delegatee.merge(mergeState);
+            for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
+                DocValuesType type = fieldInfo.getDocValuesType();
+                if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
+                    addKNNBinaryField(fieldInfo, new KNN91DocValuesReader(mergeState));
+                }
+            }
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void addSortedSetField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
+        delegatee.addSortedSetField(field, valuesProducer);
+    }
+
+    @Override
+    public void addSortedNumericField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
+        delegatee.addSortedNumericField(field, valuesProducer);
+    }
+
+    @Override
+    public void addSortedField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
+        delegatee.addSortedField(field, valuesProducer);
+    }
+
+    @Override
+    public void addNumericField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
+        delegatee.addNumericField(field, valuesProducer);
+    }
+
+    @Override
+    public void close() throws IOException {
+        delegatee.close();
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesFormat.java new file mode 100644 index 0000000000..2d5613ace8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesFormat.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN91Codec; + +import org.apache.lucene.codecs.DocValuesConsumer; +import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +/** + * Encodes/Decodes per document values + */ +public 
class KNN91DocValuesFormat extends DocValuesFormat { + private final DocValuesFormat delegate; + + public KNN91DocValuesFormat() { + this(new Lucene90DocValuesFormat()); + } + + /** + * Constructor that takes delegate in order to handle non-overridden methods + * + * @param delegate DocValuesFormat to handle non-overridden methods + */ + public KNN91DocValuesFormat(DocValuesFormat delegate) { + super(delegate.getName()); + this.delegate = delegate; + } + + @Override + public DocValuesConsumer fieldsConsumer(SegmentWriteState state) throws IOException { + return new KNN91DocValuesConsumer(delegate.fieldsConsumer(state), state); + } + + @Override + public DocValuesProducer fieldsProducer(SegmentReadState state) throws IOException { + return delegate.fieldsProducer(state); + } +}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesReader.java
new file mode 100644
index 0000000000..b6a9622ef2
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesReader.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.codec.KNN91Codec;
+
+import org.apache.lucene.codecs.DocValuesProducer;
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.index.DocIDMerger;
+import org.apache.lucene.index.DocValuesType;
+import org.apache.lucene.index.EmptyDocValuesProducer;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.MergeState;
+import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Merge-time producer: getBinary() combines the binary doc values of every segment in the MergeState into a single merged iterator
+ */
+class KNN91DocValuesReader extends EmptyDocValuesProducer {
+
+    private final MergeState mergeState;
+
+    KNN91DocValuesReader(MergeState mergeState) {
+        this.mergeState = mergeState;
+    }
+
+    @Override
+    public BinaryDocValues getBinary(FieldInfo field) {
+        try {
+            List<BinaryDocValuesSub> subs = new ArrayList<>(this.mergeState.docValuesProducers.length);
+            for (int i = 0; i < this.mergeState.docValuesProducers.length; i++) {
+                BinaryDocValues values = null;
+                DocValuesProducer docValuesProducer = mergeState.docValuesProducers[i];
+                if (docValuesProducer != null) {
+                    FieldInfo readerFieldInfo = mergeState.fieldInfos[i].fieldInfo(field.name); // look the field up by name in segment i
+                    if (readerFieldInfo != null && readerFieldInfo.getDocValuesType() == DocValuesType.BINARY) {
+                        values = docValuesProducer.getBinary(readerFieldInfo);
+                    }
+                    if (values != null) {
+                        subs.add(new BinaryDocValuesSub(mergeState.docMaps[i], values));
+                    }
+                }
+            }
+            return new KNN91BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort));
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java index 70991dfe8d..8bbf37b039 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -6,9 +6,9 @@ package org.opensearch.knn.index.codec; import org.opensearch.index.codec.CodecServiceConfig; -import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; import org.apache.lucene.codecs.Codec; import org.opensearch.index.codec.CodecService; +import org.opensearch.knn.index.codec.KNN91Codec.KNN91Codec; /** * KNNCodecService to inject the right KNNCodec version @@ -27,6 +27,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) { */ @Override public Codec codec(String name) { - return new KNN87Codec(super.codec(name)); + return new KNN91Codec(super.codec(name)); } } diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 385b30225c..f54c8bc34e 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ 
b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -39,9 +39,14 @@ import org.opensearch.cluster.health.ClusterIndexHealth; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; @@ -199,8 +204,10 @@ public void create(ActionListener actionListener) throws IO if (isCreated()) { return; } - - CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping("_doc", getMapping(), XContentType.JSON) + String mapping = Strings.toString( + JsonXContent.contentBuilder().startObject().startObject(MapperService.SINGLE_MAPPING_NAME).endObject().endObject() + ); + CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(mapping) .settings( Settings.builder() .put("index.hidden", true) diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec index 8e64afa086..98f8a6a139 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -1,4 +1,5 @@ org.opensearch.knn.index.codec.KNN80Codec.KNN80Codec org.opensearch.knn.index.codec.KNN84Codec.KNN84Codec org.opensearch.knn.index.codec.KNN86Codec.KNN86Codec -org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec \ No newline at end of file +org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec 
+org.opensearch.knn.index.codec.KNN91Codec.KNN91Codec \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/TestUtils.java b/src/test/java/org/opensearch/knn/TestUtils.java index f4968a2558..ce09394d23 100644 --- a/src/test/java/org/opensearch/knn/TestUtils.java +++ b/src/test/java/org/opensearch/knn/TestUtils.java @@ -21,7 +21,6 @@ import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.plugin.script.KNNScoringUtil; -import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier; import java.util.Comparator; import java.util.Random; import java.util.Set; @@ -30,7 +29,8 @@ import java.util.List; import java.util.HashSet; import java.util.Map; -import static org.apache.lucene.util.LuceneTestCase.random; + +import static org.apache.lucene.tests.util.LuceneTestCase.random; class DistVector { public float dist; diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index e1feb9f180..3c79051261 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -14,7 +14,10 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.action.ActionListener; import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; +import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.util.KNNEngine; @@ -64,6 +67,17 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException String indexName = "test-index"; String fieldName = "test-field"; + final String mapping = Strings.toString( + 
XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("model_id", modelId) + .endObject() + .endObject() + .endObject()); + modelDao.put(model, ActionListener.wrap(indexResponse -> { CreateIndexRequestBuilder createIndexRequestBuilder = client().admin().indices().prepareCreate(indexName) .setSettings(Settings.builder() @@ -71,16 +85,7 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException .put("number_of_replicas", 0) .put("index.knn", true) .build() - ).addMapping( - "_doc", ImmutableMap.of( - "properties", ImmutableMap.of( - fieldName, ImmutableMap.of( - "type", "knn_vector", - "model_id", modelId - ) - ) - ) - ); + ).setMapping(mapping); client().admin().indices().create(createIndexRequestBuilder.request(), ActionListener.wrap( diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index 87d00e5547..e1d1889f4d 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index; import org.opensearch.knn.KNNTestCase; -import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java index 54435db992..3460526187 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index; import org.opensearch.knn.KNNTestCase; 
-import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 8883bf4ddd..e54e19141e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index; import org.opensearch.knn.KNNTestCase; -import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java index b7f909ec11..48c6e6c0d7 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java @@ -5,11 +5,13 @@ package org.opensearch.knn.index.codec.KNN87Codec; +import org.junit.Ignore; import org.opensearch.knn.index.codec.KNNCodecTestCase; import java.io.IOException; import java.util.concurrent.ExecutionException; +@Ignore public class KNN87CodecTests extends KNNCodecTestCase { public void testMultiFieldsKnnIndex() throws Exception { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91BinaryDocValuesTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91BinaryDocValuesTests.java new file mode 100644 index 0000000000..a548b45445 --- /dev/null +++ 
b/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91BinaryDocValuesTests.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN91Codec; + +import com.google.common.collect.ImmutableList; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.MergeState; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; + +import java.io.IOException; + +public class KNN91BinaryDocValuesTests extends KNNTestCase { + + public void testDocId() { + KNN91BinaryDocValues knn91BinaryDocValues = new KNN91BinaryDocValues(null); + assertEquals(-1, knn91BinaryDocValues.docID()); + } + + public void testNextDoc() throws IOException { + final int expectedDoc = 12; + + BinaryDocValuesSub sub = new BinaryDocValuesSub(new MergeState.DocMap() { + @Override + public int get(int docID) { + return expectedDoc; + } + }, new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f)); + + DocIDMerger docIDMerger = DocIDMerger.of(ImmutableList.of(sub), false); + KNN91BinaryDocValues knn91BinaryDocValues = new KNN91BinaryDocValues(docIDMerger); + assertEquals(expectedDoc, knn91BinaryDocValues.nextDoc()); + } + + public void testAdvance() { + KNN91BinaryDocValues knn91BinaryDocValues = new KNN91BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, () -> knn91BinaryDocValues.advance(0)); + } + + public void testAdvanceExact() { + KNN91BinaryDocValues knn91BinaryDocValues = new KNN91BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, () -> knn91BinaryDocValues.advanceExact(0)); + } + + public void testCost() { + KNN91BinaryDocValues knn91BinaryDocValues = new KNN91BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, knn91BinaryDocValues::cost); + } + + 
public void testBinaryValue() throws IOException { + BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f); + BinaryDocValuesSub sub = new BinaryDocValuesSub(new MergeState.DocMap() { + @Override + public int get(int docID) { + return docID; + } + }, binaryDocValues); + + DocIDMerger docIDMerger = DocIDMerger.of(ImmutableList.of(sub), false); + KNN91BinaryDocValues knn91BinaryDocValues = new KNN91BinaryDocValues(docIDMerger); + knn91BinaryDocValues.nextDoc(); + assertEquals(binaryDocValues.binaryValue(), knn91BinaryDocValues.binaryValue()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CodecTests.java new file mode 100644 index 0000000000..39dc641dcd --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CodecTests.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN91Codec; + +import org.opensearch.knn.index.codec.KNNCodecTestCase; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +public class KNN91CodecTests extends KNNCodecTestCase { + + public void testMultiFieldsKnnIndex() throws Exception { + testMultiFieldsKnnIndex(new KNN91Codec()); + } + + public void testBuildFromModelTemplate() throws InterruptedException, ExecutionException, IOException { + testBuildFromModelTemplate(new KNN91Codec()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CompoundFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CompoundFormatTests.java new file mode 100644 index 0000000000..dd9a358d70 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91CompoundFormatTests.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + 
+package org.opensearch.knn.index.codec.KNN91Codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundDirectory; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.common.util.set.Sets; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Set; + +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KNN91CompoundFormatTests extends KNNTestCase { + + private static Directory directory; + private static Codec codec; + + @BeforeClass + public static void setStaticVariables() { + directory = newFSDirectory(createTempDir()); + codec = new KNN91Codec(); + } + + @AfterClass + public static void closeStaticVariables() throws IOException { + directory.close(); + } + + public void testGetCompoundReader() throws IOException { + CompoundDirectory dir = mock(CompoundDirectory.class); + CompoundFormat delegate = mock(CompoundFormat.class); + when(delegate.getCompoundReader(null, null, null)).thenReturn(dir); + KNN91CompoundFormat knn91CompoundFormat = new KNN91CompoundFormat(delegate); + assertEquals(dir, knn91CompoundFormat.getCompoundReader(null, null, null)); + } + + public void testWrite() throws IOException { + // Check that all normal engine files correctly get set to compound extension files after write + String segmentName = "_test"; + + Set segmentFiles = Sets.newHashSet( + String.format("%s_nmslib1%s", segmentName, KNNEngine.NMSLIB.getExtension()), + String.format("%s_nmslib2%s", segmentName, KNNEngine.NMSLIB.getExtension()), + 
String.format("%s_nmslib3%s", segmentName, KNNEngine.NMSLIB.getExtension()), + String.format("%s_faiss1%s", segmentName, KNNEngine.FAISS.getExtension()), + String.format("%s_faiss2%s", segmentName, KNNEngine.FAISS.getExtension()), + String.format("%s_faiss3%s", segmentName, KNNEngine.FAISS.getExtension()) + ); + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, segmentFiles.size(), codec).build(); + + for (String name : segmentFiles) { + IndexOutput indexOutput = directory.createOutput(name, IOContext.DEFAULT); + indexOutput.close(); + } + segmentInfo.setFiles(segmentFiles); + + CompoundFormat delegate = mock(CompoundFormat.class); + doNothing().when(delegate).write(directory, segmentInfo, IOContext.DEFAULT); + + KNN91CompoundFormat knn91CompoundFormat = new KNN91CompoundFormat(delegate); + knn91CompoundFormat.write(directory, segmentInfo, IOContext.DEFAULT); + + assertTrue(segmentInfo.files().isEmpty()); + + Arrays.stream(directory.listAll()).forEach(filename -> { + try { + directory.deleteFile(filename); + } catch (IOException e) { + fail(String.format("Failed to delete: %s", filename)); + } + }); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesConsumerTests.java new file mode 100644 index 0000000000..994540cf5a --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91DocValuesConsumerTests.java @@ -0,0 +1,412 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN91Codec; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.DocValuesConsumer; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.FieldInfo; +import 
org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.KNNVectorFieldMapper; +import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.plugin.stats.KNNCounter; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static 
org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.RandomVectorDocValuesProducer; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertFileInCorrectLocation; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertLoadableByEngine; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertValidFooter; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.getRandomVectors; + +public class KNN91DocValuesConsumerTests extends KNNTestCase { + + private static Directory directory; + private static Codec codec; + + @BeforeClass + public static void setStaticVariables() { + directory = newFSDirectory(createTempDir()); + codec = new KNN87Codec(); + } + + @AfterClass + public static void closeStaticVariables() throws IOException { + directory.close(); + } + + public void testAddBinaryField_withKNN() throws IOException { + // Confirm that addKNNBinaryField will get called if the k-NN parameter is true + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field") + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .build(); + DocValuesProducer docValuesProducer = mock(DocValuesProducer.class); + + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addBinaryField(fieldInfo, docValuesProducer); + + final boolean[] called = { false }; + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(delegate, null) { + + @Override + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) { + called[0] = true; + } + }; + + knn91DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); + + verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); + assertTrue(called[0]); + } + + public void 
testAddBinaryField_withoutKNN() throws IOException { + // Confirm that the KNN91DocValuesConsumer will just call delegate AddBinaryField when k-NN parameter is + // not set + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").build(); + DocValuesProducer docValuesProducer = mock(DocValuesProducer.class); + + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addBinaryField(fieldInfo, docValuesProducer); + + final boolean[] called = { false }; + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(delegate, null) { + + @Override + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) { + called[0] = true; + } + }; + + knn91DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); + + verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); + assertFalse(called[0]); + } + + public void testAddKNNBinaryField_noVectors() throws IOException { + // When there are no new vectors, no more graph index requests should be added + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(0, 128); + Long initialGraphIndexRequests = KNNCounter.GRAPH_INDEX_REQUESTS.getCount(); + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(null, null); + knn91DocValuesConsumer.addKNNBinaryField(null, randomVectorDocValuesProducer); + assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); + } + + public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException { + // Set information about the segment and the fields + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = 
KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + spaceType, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + + String parameterString = Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext))); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.KNN_ENGINE, knnEngine.getName()) + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn91DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { + // Set information about the segment and the fields + String segmentName = 
String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") + .addAttribute(KNNConstants.HNSW_ALGO_M, "16") + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn91DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException { + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = 
String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.FAISS; + SpaceType spaceType = SpaceType.INNER_PRODUCT; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + spaceType, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + + String parameterString = Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext))); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.KNN_ENGINE, knnEngine.getName()) + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn91DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by faiss + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, 
dimension); + } + + /** Trains a faiss model through JNIService, serves it from a mocked ModelDao via ModelCache.initialize, writes a model-backed field with KNN91DocValuesConsumer.addKNNBinaryField, then checks the resulting engine file with assertFileInCorrectLocation, assertValidFooter and assertLoadableByEngine. NOTE(review): the ModelMetadata is built with SpaceType.INNER_PRODUCT while the training parameters pass SpaceType.L2 - confirm the mismatch is intended. NOTE(review): parameters is declared as a raw Map here; its type arguments may have been lost - confirm against the real source. */ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, ExecutionException, InterruptedException { + // Generate a trained faiss model + KNNEngine knnEngine = KNNEngine.FAISS; + SpaceType spaceType = SpaceType.INNER_PRODUCT; + int dimension = 16; + String modelId = "test-model-id"; + + float[][] trainingData = getRandomVectors(200, dimension); + long trainingPtr = JNIService.transferVectors(0, trainingData); + + Map parameters = ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + "IVF4,Flat", + KNNConstants.SPACE_TYPE, + SpaceType.L2.getValue() + ); + + byte[] modelBytes = JNIService.trainIndex(parameters, dimension, trainingPtr, knnEngine.getName()); + Model model = new Model( + new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, "timestamp", "Empty description", ""), + modelBytes, + modelId + ); + JNIService.freeVectors(trainingPtr); + + // Setup the model cache to return the correct model + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(model); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getSettings()).thenReturn(Settings.EMPTY); + + ClusterSettings clusterSettings = new ClusterSettings( + Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10kb").build(), + ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING) + ); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ModelCache.initialize(modelDao, clusterService); + + // Build the segment and field info + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) +
.addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(MODEL_ID, modelId) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn91DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by faiss + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + /** Expects merge(null) to throw a RuntimeException on a consumer constructed with a null delegate and a null write state. */ public void testMerge_exception() throws IOException { + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(null, null); + expectThrows(RuntimeException.class, () -> knn91DocValuesConsumer.merge(null)); + } + + /** Checks that addSortedSetField(null, null) on the consumer reaches the mocked delegate exactly once. */ public void testAddSortedSetField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addSortedSetField(null, null); + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(delegate, null); + knn91DocValuesConsumer.addSortedSetField(null, null); + verify(delegate, times(1)).addSortedSetField(null, null); + } + + /** Checks that addSortedNumericField(null, null) on the consumer reaches the mocked delegate exactly once. */ public void testAddSortedNumericField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); +
doNothing().when(delegate).addSortedNumericField(null, null); + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(delegate, null); + knn91DocValuesConsumer.addSortedNumericField(null, null); + verify(delegate, times(1)).addSortedNumericField(null, null); + } + + /** Checks that addSortedField(null, null) on the consumer reaches the mocked delegate exactly once. */ public void testAddSortedField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addSortedField(null, null); + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(delegate, null); + knn91DocValuesConsumer.addSortedField(null, null); + verify(delegate, times(1)).addSortedField(null, null); + } + + /** Checks that addNumericField(null, null) on the consumer reaches the mocked delegate exactly once. */ public void testAddNumericField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addNumericField(null, null); + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(delegate, null); + knn91DocValuesConsumer.addNumericField(null, null); + verify(delegate, times(1)).addNumericField(null, null); + } + + /** Checks that close() on the consumer reaches the mocked delegate exactly once. */ public void testClose() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).close(); + KNN91DocValuesConsumer knn91DocValuesConsumer = new KNN91DocValuesConsumer(delegate, null); + knn91DocValuesConsumer.close(); + verify(delegate, times(1)).close(); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 3454c7d524..dd7b82cf3d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -11,6 +11,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.knn.KNNTestCase; import
org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.codec.KNN91Codec.KNN91Codec; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.KNNQuery; import org.opensearch.knn.index.KNNSettings; @@ -24,7 +25,7 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.store.Directory; @@ -61,6 +62,7 @@ */ public class KNNCodecTestCase extends KNNTestCase { + /** Shared KNN91Codec instance; testMultiFieldsKnnIndex sets it on its iwc1 writer config in place of the former new KNN87Codec(). */ private static final KNN91Codec ACTUAL_CODEC = new KNN91Codec(); private static FieldType sampleFieldType; static { sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); @@ -107,7 +109,7 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { */ IndexWriterConfig iwc1 = newIndexWriterConfig(); iwc1.setMergeScheduler(new SerialMergeScheduler()); - iwc1.setCodec(new KNN87Codec()); + iwc1.setCodec(ACTUAL_CODEC); writer = new RandomIndexWriter(random(), dir, iwc1); float[] array1 = { 6.0f, 14.0f }; VectorField vectorField1 = new VectorField("my_vector", array1, sampleFieldType); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index de335a115e..153fd2faf1 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -20,6 +20,7 @@ import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Sort; import org.apache.lucene.store.ChecksumIndexInput; import
org.apache.lucene.store.Directory; @@ -145,6 +146,8 @@ public static class FieldInfoBuilder { private int pointDimensionCount; private int pointIndexDimensionCount; private int pointNumBytes; + private int vectorDimension; + private VectorSimilarityFunction vectorSimilarityFunction; private boolean softDeletes; public static FieldInfoBuilder builder(String fieldName) { @@ -164,6 +167,8 @@ private FieldInfoBuilder(String fieldName) { this.pointDimensionCount = 0; this.pointIndexDimensionCount = 0; this.pointNumBytes = 0; + this.vectorDimension = 0; + this.vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN; this.softDeletes = false; } @@ -222,6 +227,16 @@ public FieldInfoBuilder pointNumBytes(int pointNumBytes) { return this; } + /** Sets the vector dimension that build() includes in its argument list (initialised to 0 by the builder constructor) and returns this builder for chaining. */ public FieldInfoBuilder vectorDimension(int vectorDimension) { + this.vectorDimension = vectorDimension; + return this; + } + + /** Sets the vector similarity function that build() includes in its argument list (initialised to VectorSimilarityFunction.EUCLIDEAN by the builder constructor) and returns this builder for chaining. */ public FieldInfoBuilder vectorSimilarityFunction(VectorSimilarityFunction vectorSimilarityFunction) { + this.vectorSimilarityFunction = vectorSimilarityFunction; + return this; + } + public FieldInfoBuilder softDeletes(boolean softDeletes) { this.softDeletes = softDeletes; return this; @@ -241,6 +256,8 @@ public FieldInfo build() { pointDimensionCount, pointIndexDimensionCount, pointNumBytes, + vectorDimension, + vectorSimilarityFunction, softDeletes ); } @@ -364,11 +381,6 @@ public void checkIntegrity() { public void close() throws IOException { } - - @Override - public long ramBytesUsed() { - return 0; - } } public static void assertFileInCorrectLocation(SegmentWriteState state, String expectedFile) throws IOException { diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 08cede77cb..0375b6b489 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -8,7 +8,7 @@ import
org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.opensearch.knn.index.VectorField; -import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType;