diff --git a/gradle.properties b/gradle.properties index 8c7ff252a..86e266f28 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,3 +4,9 @@ # version=1.0.0 + +org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java index 827f0a3d1..832737a6d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java @@ -5,7 +5,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import org.opensearch.knn.index.codec.BinaryDocValuesSub; +import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.util.BytesRef; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java index ca581b6ae..8a8ed558e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java @@ -5,10 +5,8 @@ package org.opensearch.knn.index.codec.KNN80Codec; +import org.apache.lucene.codecs.lucene50.Lucene50CompoundFormat; import org.opensearch.knn.common.KNNConstants; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CompoundDirectory; import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.index.SegmentInfo; @@ -26,14 +24,24 @@ */ public class KNN80CompoundFormat extends CompoundFormat { - private final Logger logger = LogManager.getLogger(KNN80CompoundFormat.class); + private final CompoundFormat delegate; public KNN80CompoundFormat() { + this.delegate = new Lucene50CompoundFormat(); + } + + /** + * Constructor that takes a delegate to handle non-overridden methods + * + * @param delegate CompoundFormat that will handle non-overridden methods + */ + public KNN80CompoundFormat(CompoundFormat delegate) { + this.delegate = delegate; } @Override public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException { - return Codec.getDefault().compoundFormat().getCompoundReader(dir, si, context); + return delegate.getCompoundReader(dir, si, context); } @Override @@ -41,17 +49,15 @@ public void write(Directory dir, SegmentInfo si, IOContext context) throws IOExc for (KNNEngine knnEngine : KNNEngine.values()) { writeEngineFiles(dir, si, context, knnEngine.getExtension()); } - Codec.getDefault().compoundFormat().write(dir, si, context); + delegate.write(dir, si, context); } - private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) - throws IOException { + private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) throws IOException { /* * If engine file present, remove it from the compounding file list to avoid header/footer checks * and create a new compounding file format with extension engine + c. */ - Set engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)) - .collect(Collectors.toSet()); + Set engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)).collect(Collectors.toSet()); Set segmentFiles = new HashSet<>(si.files()); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index d4bc662c5..0bbd6e9ac 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -67,108 +67,117 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable { @Override public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { delegatee.addBinaryField(field, valuesProducer); - addKNNBinaryField(field, valuesProducer); + if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { + addKNNBinaryField(field, valuesProducer); + } } public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { - // Get values to be indexed - BinaryDocValues values = valuesProducer.getBinary(field); - KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); - if (pair.vectors.length == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the documents"); - return; - } + // Get values to be indexed + BinaryDocValues values = valuesProducer.getBinary(field); + KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); + if (pair.vectors.length == 0 || pair.docs.length == 0) { + logger.info("Skipping engine index creation as there are no vectors or docs in the documents"); + return; + } - // Create library index either from model or from scratch - String engineFileName; - String indexPath; - String tmpEngineFileName; + // Increment counter for number of graph index requests + KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - if (field.attributes().containsKey(MODEL_ID)) { + // Create library index either from model or from scratch + String engineFileName; + String indexPath; + String tmpEngineFileName; - String modelId = field.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); + if (field.attributes().containsKey(MODEL_ID)) { - KNNEngine knnEngine = model.getModelMetadata().getKnnEngine(); + String modelId = field.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(), - field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName).toString(); - tmpEngineFileName = engineFileName + TEMP_SUFFIX; - String tempIndexPath = indexPath + TEMP_SUFFIX; + KNNEngine knnEngine = model.getModelMetadata().getKnnEngine(); - if (model.getModelBlob() == null) { - throw new RuntimeException("There is no trained model with id \"" + modelId + "\""); - } + engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getLatestBuildVersion(), + field.name, + knnEngine.getExtension() + ); + indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) + .toString(); + tmpEngineFileName = engineFileName + TEMP_SUFFIX; + String tempIndexPath = indexPath + TEMP_SUFFIX; - createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath); - } else { + if (model.getModelBlob() == null) { + throw new RuntimeException("There is no trained model with id \"" + modelId + "\""); + } - // Get engine to be used for indexing - String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); - KNNEngine knnEngine = KNNEngine.getEngine(engineName); + createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath); + } else { - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(), - field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName).toString(); - tmpEngineFileName = engineFileName + TEMP_SUFFIX; - String tempIndexPath = indexPath + TEMP_SUFFIX; + // Get engine to be used for indexing + String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); + KNNEngine knnEngine = KNNEngine.getEngine(engineName); - createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath); - } + engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getLatestBuildVersion(), + field.name, + knnEngine.getExtension() + ); + indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) + .toString(); + tmpEngineFileName = engineFileName + TEMP_SUFFIX; + String tempIndexPath = indexPath + TEMP_SUFFIX; - /* - * Adds Footer to the serialized graph - * 1. Copies the serialized graph to new file. - * 2. Adds Footer to the new file. - * - * We had to create new file here because adding footer directly to the - * existing file will miss calculating checksum for the serialized graph - * bytes and result in index corruption issues. - */ - //TODO: I think this can be refactored to avoid this copy and then write - // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330 - try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context); - IndexOutput os = state.directory.createOutput(engineFileName, state.context)) { - os.copyBytes(is, is.length()); - CodecUtil.writeFooter(os); - } catch (Exception ex) { - KNNCounter.GRAPH_INDEX_ERRORS.increment(); - throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex); - } finally { - IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName); - } + createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath); + } + + /* + * Adds Footer to the serialized graph + * 1. Copies the serialized graph to new file. + * 2. Adds Footer to the new file. + * + * We had to create new file here because adding footer directly to the + * existing file will miss calculating checksum for the serialized graph + * bytes and result in index corruption issues. + */ + // TODO: I think this can be refactored to avoid this copy and then write + // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330 + try ( + IndexInput is = state.directory.openInput(tmpEngineFileName, state.context); + IndexOutput os = state.directory.createOutput(engineFileName, state.context) + ) { + os.copyBytes(is, is.length()); + CodecUtil.writeFooter(os); + } catch (Exception ex) { + KNNCounter.GRAPH_INDEX_ERRORS.increment(); + throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex); + } finally { + IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName); } } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, - String indexPath) { - Map parameters = ImmutableMap.of(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue( - KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - AccessController.doPrivileged( - (PrivilegedAction) () -> { - JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, - knnEngine.getName()); - return null; - } + private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + Map parameters = ImmutableMap.of( + KNNConstants.INDEX_THREAD_QTY, + KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName()); + return null; + }); } - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, - String indexPath) throws IOException { + private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) + throws IOException { Map parameters = new HashMap<>(); Map fieldAttributes = fieldInfo.attributes(); String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); // parametersString will be null when legacy mapper is used if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, - SpaceType.DEFAULT.getValue())); + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); Map algoParams = new HashMap<>(); @@ -183,22 +192,20 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa parameters.put(PARAMETERS, algoParams); } else { parameters.putAll( - XContentFactory.xContent(XContentType.JSON).createParser(NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString).map() + XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString) + .map() ); } // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue( - KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); // Pass the path for the nms library to save the file - AccessController.doPrivileged( - (PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName()); - return null; - } - ); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName()); + return null; + }); } /** @@ -214,7 +221,7 @@ public void merge(MergeState mergeState) { assert mergeState.mergeFieldInfos != null; for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) { DocValuesType type = fieldInfo.getDocValuesType(); - if (type == DocValuesType.BINARY) { + if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState)); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java index 9114f38ee..cd8362cf2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java @@ -5,11 +5,10 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesConsumer; import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.lucene80.Lucene80DocValuesFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -19,11 +18,20 @@ * Encodes/Decodes per document values */ public class KNN80DocValuesFormat extends DocValuesFormat { - private final Logger logger = LogManager.getLogger(KNN80DocValuesFormat.class); - private final DocValuesFormat delegate = DocValuesFormat.forName(KNN80Codec.LUCENE_80); + private final DocValuesFormat delegate; public KNN80DocValuesFormat() { - super(KNN80Codec.LUCENE_80); + this(new Lucene80DocValuesFormat()); + } + + /** + * Constructor that takes delegate in order to handle non-overridden methods + * + * @param delegate DocValuesFormat to handle non-overridden methods + */ + public KNN80DocValuesFormat(DocValuesFormat delegate) { + super(delegate.getName()); + this.delegate = delegate; } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java index 943b39abc..ccfaa68fc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java @@ -5,7 +5,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import org.opensearch.knn.index.codec.BinaryDocValuesSub; +import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; @@ -22,7 +22,7 @@ */ class KNN80DocValuesReader extends EmptyDocValuesProducer { - private MergeState mergeState; + private final MergeState mergeState; KNN80DocValuesReader(MergeState mergeState) { this.mergeState = mergeState; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java index 6e2e897e0..6ec5ec05c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java @@ -5,124 +5,53 @@ package org.opensearch.knn.index.codec.KNN87Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.lucene87.Lucene87Codec; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.codecs.DocValuesFormat; -import org.apache.lucene.codecs.FieldInfosFormat; -import org.apache.lucene.codecs.LiveDocsFormat; -import org.apache.lucene.codecs.NormsFormat; -import org.apache.lucene.codecs.PointsFormat; -import org.apache.lucene.codecs.PostingsFormat; -import org.apache.lucene.codecs.SegmentInfoFormat; -import org.apache.lucene.codecs.StoredFieldsFormat; -import org.apache.lucene.codecs.TermVectorsFormat; -import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat; /** * Extends the Codec to support a new file format for KNN index * based on the mappings. * */ -public final class KNN87Codec extends Codec { +public final class KNN87Codec extends FilterCodec { - private static final Logger logger = LogManager.getLogger(KNN87Codec.class); private final DocValuesFormat docValuesFormat; - private final DocValuesFormat perFieldDocValuesFormat; private final CompoundFormat compoundFormat; - private Codec lucene87Codec; - private PostingsFormat postingsFormat = null; public static final String KNN_87 = "KNN87Codec"; - public static final String LUCENE_87 = "Lucene87"; // Lucene Codec to be used + /** + * No arg constructor that uses Lucene87 as the delegate + */ public KNN87Codec() { - super(KNN_87); - // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 - // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); - this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { - @Override - public DocValuesFormat getDocValuesFormatForField(String field) { - return docValuesFormat; - } - }; - this.compoundFormat = new KNN80CompoundFormat(); + this(new Lucene87Codec()); } - /* - * This function returns the Codec. + /** + * Constructor that takes a Codec delegate to delegate all methods this code does not implement to. + * + * @param delegate codec that will perform all operations this codec does not override */ - public Codec getDelegatee() { - if (lucene87Codec == null) - lucene87Codec = Codec.forName(LUCENE_87); - return lucene87Codec; + public KNN87Codec(Codec delegate) { + super(KNN_87, delegate); + // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 + // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses + this.docValuesFormat = new KNN80DocValuesFormat(delegate.docValuesFormat()); + this.compoundFormat = new KNN80CompoundFormat(delegate.compoundFormat()); } @Override public DocValuesFormat docValuesFormat() { - return this.perFieldDocValuesFormat; - } - - /* - * For all the below functions, we could have extended FilterCodec, but this brings - * SPI related issues while loading Codec in the tests. So fall back to traditional - * approach of manually overriding. - */ - - - public void setPostingsFormat(PostingsFormat postingsFormat) { - this.postingsFormat = postingsFormat; - } - - @Override - public PostingsFormat postingsFormat() { - if (this.postingsFormat == null) { - return getDelegatee().postingsFormat(); - } - return this.postingsFormat; - } - - @Override - public StoredFieldsFormat storedFieldsFormat() { - return getDelegatee().storedFieldsFormat(); - } - - @Override - public TermVectorsFormat termVectorsFormat() { - return getDelegatee().termVectorsFormat(); - } - - @Override - public FieldInfosFormat fieldInfosFormat() { - return getDelegatee().fieldInfosFormat(); - } - - @Override - public SegmentInfoFormat segmentInfoFormat() { - return getDelegatee().segmentInfoFormat(); - } - - @Override - public NormsFormat normsFormat() { - return getDelegatee().normsFormat(); - } - - @Override - public LiveDocsFormat liveDocsFormat() { - return getDelegatee().liveDocsFormat(); + return this.docValuesFormat; } @Override public CompoundFormat compoundFormat() { return this.compoundFormat; } - - @Override - public PointsFormat pointsFormat() { - return getDelegatee().pointsFormat(); - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java new file mode 100644 index 000000000..70991dfe8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import org.opensearch.index.codec.CodecServiceConfig; +import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; +import org.apache.lucene.codecs.Codec; +import org.opensearch.index.codec.CodecService; + +/** + * KNNCodecService to inject the right KNNCodec version + */ +public class KNNCodecService extends CodecService { + + public KNNCodecService(CodecServiceConfig codecServiceConfig) { + super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger()); + } + + /** + * Return the custom k-NN codec that wraps another codec that a user wants to use for non k-NN related operations + * + * @param name of delegate codec. + * @return Latest KNN Codec built with delegate codec. + */ + @Override + public Codec codec(String name) { + return new KNN87Codec(super.codec(name)); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java b/src/main/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSub.java similarity index 95% rename from src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java rename to src/main/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSub.java index 67a476c8d..c47aa85d3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSub.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.codec; +package org.opensearch.knn.index.codec.util; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; @@ -19,10 +19,6 @@ public class BinaryDocValuesSub extends DocIDMerger.Sub { private final BinaryDocValues values; - public BinaryDocValues getValues() { - return values; - } - public BinaryDocValuesSub(MergeState.DocMap docMap, BinaryDocValues values) { super(docMap); if (values == null || (values.docID() != -1)) { @@ -35,4 +31,8 @@ public BinaryDocValuesSub(MergeState.DocMap docMap, BinaryDocValues values) { public int nextDoc() throws IOException { return values.nextDoc(); } -} \ No newline at end of file + + public BinaryDocValues getValues() { + return values; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java b/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java deleted file mode 100644 index c67f27923..000000000 --- a/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.plugin; - -import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; -import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.PostingsFormat; -import org.opensearch.index.codec.CodecService; - -/** - * KNNCodecService to inject the right KNNCodec version - */ -class KNNCodecService extends CodecService { - - KNNCodecService() { - super(null, null); - } - - /** - * If the index is of type KNN i.e index.knn = true, We always - * return the KNN Codec - * - * @param name dummy name - * @return Latest KNN Codec - */ - @Override - public Codec codec(String name) { - Codec codec = Codec.forName(KNN87Codec.KNN_87); - if (codec == null) { - throw new IllegalArgumentException("failed to find codec [" + name + "]"); - } - return codec; - } - - public void setPostingsFormat(PostingsFormat postingsFormat) { - ((KNN87Codec)codec("")).setPostingsFormat(postingsFormat); - } -} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java b/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java deleted file mode 100644 index c5da36864..000000000 --- a/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.plugin; - -import org.opensearch.index.engine.Engine; -import org.opensearch.index.engine.EngineConfig; -import org.opensearch.index.engine.EngineFactory; -import org.opensearch.index.engine.InternalEngine; - -/** - * EngineFactory to inject the KNNCodecService to help segments write using the KNNCodec. - */ -class KNNEngineFactory implements EngineFactory { - - private static KNNCodecService codecService = new KNNCodecService(); - - @Override - public Engine newReadWriteEngine(EngineConfig config) { - codecService.setPostingsFormat(config.getCodec().postingsFormat()); - EngineConfig engineConfig = new EngineConfig(config.getShardId(), - config.getThreadPool(), config.getIndexSettings(), config.getWarmer(), config.getStore(), - config.getMergePolicy(), config.getAnalyzer(), config.getSimilarity(), codecService, - config.getEventListener(), config.getQueryCache(), config.getQueryCachingPolicy(), - config.getTranslogConfig(), config.getFlushMergesAfter(), config.getExternalRefreshListener(), - config.getInternalRefreshListener(), config.getIndexSort(), config.getCircuitBreakerService(), - config.getGlobalCheckpointSupplier(), config.retentionLeasesSupplier(), config.getPrimaryTermSupplier(), - config.getTombstoneDocSupplier()); - return new InternalEngine(engineConfig); - } -} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 122b77929..9cf0696f2 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -5,12 +5,15 @@ package org.opensearch.knn.plugin; +import org.opensearch.index.codec.CodecServiceFactory; +import org.opensearch.index.engine.EngineFactory; import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorFieldMapper; import org.opensearch.knn.index.KNNWeight; +import org.opensearch.knn.index.codec.KNNCodecService; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; @@ -49,7 +52,6 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.index.IndexModule; import org.opensearch.index.IndexSettings; -import org.opensearch.index.engine.EngineFactory; import org.opensearch.index.mapper.Mapper; import org.opensearch.knn.plugin.stats.KNNStatsConfig; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; @@ -133,13 +135,14 @@ public class KNNPlugin extends Plugin implements MapperPlugin, SearchPlugin, Act public static final String KNN_BASE_URI = "/_plugins/_knn"; private KNNStats knnStats; - private ModelDao modelDao; private ClusterService clusterService; @Override public Map getMappers() { - return Collections.singletonMap(KNNVectorFieldMapper.CONTENT_TYPE, new KNNVectorFieldMapper.TypeParser( - ModelDao.OpenSearchKNNModelDao::getInstance)); + return Collections.singletonMap( + KNNVectorFieldMapper.CONTENT_TYPE, + new KNNVectorFieldMapper.TypeParser(ModelDao.OpenSearchKNNModelDao::getInstance) + ); } @Override @@ -148,12 +151,19 @@ public List> getQueries() { } @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, ScriptService scriptService, - NamedXContentRegistry xContentRegistry, Environment environment, - NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier) { + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { this.clusterService = clusterService; // Initialize Native Memory loading strategies @@ -178,25 +188,35 @@ public List> getSettings() { return KNNSettings.state().getSettings(); } - public List getRestHandlers(Settings settings, - RestController restController, - ClusterSettings clusterSettings, - IndexScopedSettings indexScopedSettings, - SettingsFilter settingsFilter, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier nodesInCluster) { + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { RestKNNStatsHandler restKNNStatsHandler = new RestKNNStatsHandler(settings, restController, knnStats); - RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(settings, restController, clusterService, - indexNameExpressionResolver); + RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler( + settings, + restController, + clusterService, + indexNameExpressionResolver + ); RestGetModelHandler restGetModelHandler = new RestGetModelHandler(); RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler(); RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler(); RestSearchModelHandler restSearchModelHandler = new RestSearchModelHandler(); return ImmutableList.of( - restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler, restDeleteModelHandler, - restTrainModelHandler, restSearchModelHandler + restKNNStatsHandler, + restKNNWarmupHandler, + restGetModelHandler, + restDeleteModelHandler, + restTrainModelHandler, + restSearchModelHandler ); } @@ -206,24 +226,28 @@ public List getRestHandlers(Settings settings, @Override public List> getActions() { return Arrays.asList( - new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class), - new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class), - new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), - new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, - TrainingJobRouteDecisionInfoTransportAction.class), - new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), - new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), - new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class), - new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), - new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), - new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class) + new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class), + new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class), + new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), + new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, TrainingJobRouteDecisionInfoTransportAction.class), + new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), + new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), + new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class), + new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), + new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), + new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class) ); } @Override public Optional getEngineFactory(IndexSettings indexSettings) { + return Optional.empty(); + } + + @Override + public Optional getCustomCodecServiceFactory(IndexSettings indexSettings) { if (indexSettings.getValue(KNNSettings.IS_KNN_INDEX_SETTING)) { - return Optional.of(new KNNEngineFactory()); + return Optional.of(KNNCodecService::new); } return Optional.empty(); } @@ -267,15 +291,6 @@ public ScriptEngine getScriptEngine(Settings settings, Collection> getExecutorBuilders(Settings settings) { - return ImmutableList.of( - new FixedExecutorBuilder( - settings, - TRAIN_THREAD_POOL, - 1, - 1, - KNN_THREAD_POOL_PREFIX, - false - ) - ); + return ImmutableList.of(new FixedExecutorBuilder(settings, TRAIN_THREAD_POOL, 1, 1, KNN_THREAD_POOL_PREFIX, false)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java new file mode 100644 index 000000000..620559867 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import com.google.common.collect.ImmutableList; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.MergeState; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; + +import java.io.IOException; + +public class KNN80BinaryDocValuesTests extends KNNTestCase { + + public void testDocId() { + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(null); + assertEquals(-1, knn80BinaryDocValues.docID()); + } + + public void testNextDoc() throws IOException { + final int expectedDoc = 12; + + BinaryDocValuesSub sub = new BinaryDocValuesSub(new MergeState.DocMap() { + @Override + public int get(int docID) { + return expectedDoc; + } + }, new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f)); + + DocIDMerger docIDMerger = DocIDMerger.of(ImmutableList.of(sub), false); + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(docIDMerger); + assertEquals(expectedDoc, knn80BinaryDocValues.nextDoc()); + } + + public void testAdvance() { + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, () -> knn80BinaryDocValues.advance(0)); + } + + public void testAdvanceExact() { + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, () -> knn80BinaryDocValues.advanceExact(0)); + } + + public void testCost() { + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, knn80BinaryDocValues::cost); + } + + public void testBinaryValue() throws IOException { + BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f); + BinaryDocValuesSub sub = new BinaryDocValuesSub(new MergeState.DocMap() { + @Override + public int get(int docID) { + return docID; + } + }, binaryDocValues); + + DocIDMerger docIDMerger = DocIDMerger.of(ImmutableList.of(sub), false); + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(docIDMerger); + knn80BinaryDocValues.nextDoc(); + assertEquals(binaryDocValues.binaryValue(), knn80BinaryDocValues.binaryValue()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java new file mode 100644 index 000000000..a2f18cdfb --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundDirectory; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.common.util.set.Sets; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Set; + +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KNN80CompoundFormatTests extends KNNTestCase { + + private static Directory directory; + private static Codec codec; + + @BeforeClass + public static void setStaticVariables() { + directory = newFSDirectory(createTempDir()); + codec = new KNN87Codec(); + } + + @AfterClass + public static void closeStaticVariables() throws IOException { + directory.close(); + } + + public void testGetCompoundReader() throws IOException { + CompoundDirectory dir = mock(CompoundDirectory.class); + CompoundFormat delegate = mock(CompoundFormat.class); + when(delegate.getCompoundReader(null, null, null)).thenReturn(dir); + KNN80CompoundFormat knn80CompoundFormat = new KNN80CompoundFormat(delegate); + assertEquals(dir, knn80CompoundFormat.getCompoundReader(null, null, null)); + } + + public void testWrite() throws IOException { + // Check that all normal engine files correctly get set to compound extension files after write + String segmentName = "_test"; + + Set segmentFiles = Sets.newHashSet( + String.format("%s_nmslib1%s", segmentName, KNNEngine.NMSLIB.getExtension()), + String.format("%s_nmslib2%s", segmentName, KNNEngine.NMSLIB.getExtension()), + String.format("%s_nmslib3%s", segmentName, KNNEngine.NMSLIB.getExtension()), + String.format("%s_faiss1%s", segmentName, KNNEngine.FAISS.getExtension()), + String.format("%s_faiss2%s", segmentName, KNNEngine.FAISS.getExtension()), + String.format("%s_faiss3%s", segmentName, KNNEngine.FAISS.getExtension()) + ); + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, segmentFiles.size(), codec).build(); + + for (String name : segmentFiles) { + IndexOutput indexOutput = directory.createOutput(name, IOContext.DEFAULT); + indexOutput.close(); + } + segmentInfo.setFiles(segmentFiles); + + CompoundFormat delegate = mock(CompoundFormat.class); + doNothing().when(delegate).write(directory, segmentInfo, IOContext.DEFAULT); + + KNN80CompoundFormat knn80CompoundFormat = new KNN80CompoundFormat(delegate); + knn80CompoundFormat.write(directory, segmentInfo, IOContext.DEFAULT); + + assertTrue(segmentInfo.files().isEmpty()); + + Arrays.stream(directory.listAll()).forEach(filename -> { + try { + directory.deleteFile(filename); + } catch (IOException e) { + fail(String.format("Failed to delete: %s", filename)); + } + }); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java new file mode 100644 index 000000000..b632b2f3f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -0,0 +1,412 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.DocValuesConsumer; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.KNNVectorFieldMapper; +import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.plugin.stats.KNNCounter; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertFileInCorrectLocation; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertLoadableByEngine; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertValidFooter; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.getRandomVectors; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.RandomVectorDocValuesProducer; + +public class KNN80DocValuesConsumerTests extends KNNTestCase { + + private static Directory directory; + private static Codec codec; + + @BeforeClass + public static void setStaticVariables() { + directory = newFSDirectory(createTempDir()); + codec = new KNN87Codec(); + } + + @AfterClass + public static void closeStaticVariables() throws IOException { + directory.close(); + } + + public void testAddBinaryField_withKNN() throws IOException { + // Confirm that addKNNBinaryField will get called if the k-NN parameter is true + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field") + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .build(); + DocValuesProducer docValuesProducer = mock(DocValuesProducer.class); + + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addBinaryField(fieldInfo, docValuesProducer); + + final boolean[] called = { false }; + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { + + @Override + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) { + called[0] = true; + } + }; + + knn80DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); + + verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); + assertTrue(called[0]); + } + + public void testAddBinaryField_withoutKNN() throws IOException { + // Confirm that the KNN80DocValuesConsumer will just call delegate AddBinaryField when k-NN parameter is + // not set + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").build(); + DocValuesProducer docValuesProducer = mock(DocValuesProducer.class); + + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addBinaryField(fieldInfo, docValuesProducer); + + final boolean[] called = { false }; + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { + + @Override + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) { + called[0] = true; + } + }; + + knn80DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); + + verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); + assertFalse(called[0]); + } + + public void testAddKNNBinaryField_noVectors() throws IOException { + // When there are no new vectors, no more graph index requests should be added + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(0, 128); + Long initialGraphIndexRequests = KNNCounter.GRAPH_INDEX_REQUESTS.getCount(); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, null); + knn80DocValuesConsumer.addKNNBinaryField(null, randomVectorDocValuesProducer); + assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); + } + + public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException { + // Set information about the segment and the fields + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + spaceType, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + + String parameterString = Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext))); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.KNN_ENGINE, knnEngine.getName()) + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { + // Set information about the segment and the fields + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") + .addAttribute(KNNConstants.HNSW_ALGO_M, "16") + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException { + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.FAISS; + SpaceType spaceType = SpaceType.INNER_PRODUCT; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + spaceType, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + + String parameterString = Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext))); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.KNN_ENGINE, knnEngine.getName()) + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by faiss + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testAddKNNBinaryField_fromModel_faiss() throws IOException, ExecutionException, InterruptedException { + // Generate a trained faiss model + KNNEngine knnEngine = KNNEngine.FAISS; + SpaceType spaceType = SpaceType.INNER_PRODUCT; + int dimension = 16; + String modelId = "test-model-id"; + + float[][] trainingData = getRandomVectors(200, dimension); + long trainingPtr = JNIService.transferVectors(0, trainingData); + + Map parameters = ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + "IVF4,Flat", + KNNConstants.SPACE_TYPE, + SpaceType.L2.getValue() + ); + + byte[] modelBytes = JNIService.trainIndex(parameters, dimension, trainingPtr, knnEngine.getName()); + Model model = new Model( + new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, "timestamp", "Empty description", ""), + modelBytes, + modelId + ); + JNIService.freeVectors(trainingPtr); + + // Setup the model cache to return the correct model + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(model); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getSettings()).thenReturn(Settings.EMPTY); + + ClusterSettings clusterSettings = new ClusterSettings( + Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10kb").build(), + ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING) + ); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ModelCache.initialize(modelDao, clusterService); + + // Build the segment and field info + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(MODEL_ID, modelId) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by faiss + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testMerge_exception() throws IOException { + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, null); + expectThrows(RuntimeException.class, () -> knn80DocValuesConsumer.merge(null)); + } + + public void testAddSortedSetField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addSortedSetField(null, null); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.addSortedSetField(null, null); + verify(delegate, times(1)).addSortedSetField(null, null); + } + + public void testAddSortedNumericField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addSortedNumericField(null, null); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.addSortedNumericField(null, null); + verify(delegate, times(1)).addSortedNumericField(null, null); + } + + public void testAddSortedField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addSortedField(null, null); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.addSortedField(null, null); + verify(delegate, times(1)).addSortedField(null, null); + } + + public void testAddNumericField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addNumericField(null, null); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.addNumericField(null, null); + verify(delegate, times(1)).addNumericField(null, null); + } + + public void testClose() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).close(); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.close(); + verify(delegate, times(1)).close(); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java index 968e1f4e7..b7f909ec1 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java @@ -12,10 +12,6 @@ public class KNN87CodecTests extends KNNCodecTestCase { - public void testFooter() throws Exception { - testFooter(new KNN87Codec()); - } - public void testMultiFieldsKnnIndex() throws Exception { testMultiFieldsKnnIndex(new KNN87Codec()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 0688f554e..3454c7d52 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -20,20 +20,14 @@ import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.RandomIndexWriter; -import org.apache.lucene.index.SegmentReader; import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; -import org.apache.lucene.store.IOContext; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -65,7 +59,7 @@ /** * Test used for testing Codecs */ -public class KNNCodecTestCase extends KNNTestCase { +public class KNNCodecTestCase extends KNNTestCase { private static FieldType sampleFieldType; static { @@ -86,54 +80,8 @@ protected void setUpMockClusterService() { } protected ResourceWatcherService createDisabledResourceWatcherService() { - final Settings settings = Settings.builder() - .put("resource.reload.enabled", false) - .build(); - return new ResourceWatcherService( - settings, - null - ); - } - - public void testFooter(Codec codec) throws Exception { - setUpMockClusterService(); - Directory dir = newFSDirectory(createTempDir()); - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergeScheduler(new SerialMergeScheduler()); - iwc.setCodec(codec); - - float[] array = {1.0f, 2.0f, 3.0f}; - VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); - RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); - Document doc = new Document(); - doc.add(vectorField); - writer.addDocument(doc); - - ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); - NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); - IndexReader reader = writer.getReader(); - LeafReaderContext lrc = reader.getContext().leaves().iterator().next(); // leaf reader context - SegmentReader segmentReader = (SegmentReader) FilterLeafReader.unwrap(lrc.reader()); - String hnswFileExtension = segmentReader.getSegmentInfo().info.getUseCompoundFile() - ? KNNEngine.NMSLIB.getCompoundExtension() : KNNEngine.NMSLIB.getExtension(); - String hnswSuffix = "test_vector" + hnswFileExtension; - List hnswFiles = segmentReader.getSegmentInfo().files().stream() - .filter(fileName -> fileName.endsWith(hnswSuffix)) - .collect(Collectors.toList()); - assertTrue(!hnswFiles.isEmpty()); - ChecksumIndexInput indexInput = dir.openChecksumInput(hnswFiles.get(0), IOContext.DEFAULT); - indexInput.seek(indexInput.length() - CodecUtil.footerLength()); - CodecUtil.checkFooter(indexInput); // If footer is not valid, it would throw exception and test fails - indexInput.close(); - - IndexSearcher searcher = new IndexSearcher(reader); - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 2.5f}, 1, "myindex"))); - - reader.close(); - writer.close(); - dir.close(); - resourceWatcherService.close(); - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); + final Settings settings = Settings.builder().put("resource.reload.enabled", false).build(); + return new ResourceWatcherService(settings, null); } public void testMultiFieldsKnnIndex(Codec codec) throws Exception { @@ -146,7 +94,7 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { /** * Add doc with field "test_vector" */ - float[] array = {1.0f, 3.0f, 4.0f}; + float[] array = { 1.0f, 3.0f, 4.0f }; VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); Document doc = new Document(); @@ -161,7 +109,7 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { iwc1.setMergeScheduler(new SerialMergeScheduler()); iwc1.setCodec(new KNN87Codec()); writer = new RandomIndexWriter(random(), dir, iwc1); - float[] array1 = {6.0f, 14.0f}; + float[] array1 = { 6.0f, 14.0f }; VectorField vectorField1 = new VectorField("my_vector", array1, sampleFieldType); Document doc1 = new Document(); doc1.add(vectorField1); @@ -179,14 +127,14 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { // query to verify distance for each of the field IndexSearcher searcher = new IndexSearcher(reader); - float score = searcher.search(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"), 10).scoreDocs[0].score; - float score1 = searcher.search(new KNNQuery("my_vector", new float[] {1.0f, 2.0f}, 1, "dummy"), 10).scoreDocs[0].score; - assertEquals(1.0f/(1 + 25), score, 0.01f); - assertEquals(1.0f/(1 + 169), score1, 0.01f); + float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"), 10).scoreDocs[0].score; + float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy"), 10).scoreDocs[0].score; + assertEquals(1.0f / (1 + 25), score, 0.01f); + assertEquals(1.0f / (1 + 169), score1, 0.01f); // query to determine the hits - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"))); - assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] {1.0f, 1.0f}, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy"))); reader.close(); dir.close(); @@ -203,25 +151,33 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio // "Train" a faiss flat index - this really just creates an empty index that does brute force k-NN long vectorsPointer = JNIService.transferVectors(0, new float[0][0]); - byte [] modelBlob = JNIService.trainIndex(ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, "Flat", - SPACE_TYPE, spaceType.getValue()), dimension, vectorsPointer, - KNNEngine.FAISS.getName()); + byte[] modelBlob = JNIService.trainIndex( + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "Flat", SPACE_TYPE, spaceType.getValue()), + dimension, + vectorsPointer, + KNNEngine.FAISS.getName() + ); // Setup model cache ModelDao modelDao = mock(ModelDao.class); // Set model state to created - ModelMetadata modelMetadata1 = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata1 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); Model mockModel = new Model(modelMetadata1, modelBlob, modelId); when(modelDao.get(modelId)).thenReturn(mockModel); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata1); Settings settings = settings(CURRENT).put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10%").build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getSettings()).thenReturn(settings); @@ -242,12 +198,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio fieldType.freeze(); // Add the documents to the index - float[][] arrays = { - {1.0f, 3.0f, 4.0f}, - {2.0f, 5.0f, 8.0f}, - {3.0f, 6.0f, 9.0f}, - {4.0f, 7.0f, 10.0f} - }; + float[][] arrays = { { 1.0f, 3.0f, 4.0f }, { 2.0f, 5.0f, 8.0f }, { 3.0f, 6.0f, 9.0f }, { 4.0f, 7.0f, 10.0f } }; RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); String fieldName = "test_vector"; @@ -265,7 +216,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio KNNWeight.initialize(modelDao); ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); - float [] query = {10.0f, 10.0f, 10.0f}; + float[] query = { 10.0f, 10.0f, 10.0f }; IndexSearcher searcher = new IndexSearcher(reader); TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy"), 10); @@ -280,4 +231,3 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); } } - diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java new file mode 100644 index 000000000..de335a115 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -0,0 +1,421 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.FilterDirectory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.StringHelper; +import org.apache.lucene.util.Version; +import org.opensearch.common.collect.Set; +import org.opensearch.knn.index.KNNQueryResult; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat; +import static org.junit.Assert.assertTrue; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.test.OpenSearchTestCase.randomByteArrayOfLength; + +public class KNNCodecTestUtil { + + // Utility class to help build SegmentInfo with reasonable defaults + public static class SegmentInfoBuilder { + + private final Directory directory; + private final String segmentName; + private final int docsInSegment; + private final Codec codec; + + private Version version; + private Version minVersion; + private boolean isCompoundFile; + private byte[] segmentId; + private final Map attributes; + private Sort indexSort; + + public static SegmentInfoBuilder builder(Directory directory, String segmentName, int docsInSegment, Codec codec) { + return new SegmentInfoBuilder(directory, segmentName, docsInSegment, codec); + } + + private SegmentInfoBuilder(Directory directory, String segmentName, int docsInSegment, Codec codec) { + this.directory = directory; + this.segmentName = segmentName; + this.docsInSegment = docsInSegment; + this.codec = codec; + + this.version = Version.LATEST; + this.minVersion = Version.LATEST; + this.isCompoundFile = false; + this.segmentId = randomByteArrayOfLength(StringHelper.ID_LENGTH); + this.attributes = new HashMap<>(); + this.indexSort = Sort.INDEXORDER; + } + + public SegmentInfoBuilder version(Version version) { + this.version = version; + return this; + } + + public SegmentInfoBuilder minVersion(Version minVersion) { + this.minVersion = minVersion; + return this; + } + + public SegmentInfoBuilder isCompoundFile(boolean isCompoundFile) { + this.isCompoundFile = isCompoundFile; + return this; + } + + public SegmentInfoBuilder segmentId(byte[] segmentId) { + this.segmentId = segmentId; + return this; + } + + public SegmentInfoBuilder addAttribute(String key, String value) { + this.attributes.put(key, value); + return this; + } + + public SegmentInfoBuilder indexSort(Sort indexSort) { + this.indexSort = indexSort; + return this; + } + + public SegmentInfo build() { + return new SegmentInfo( + directory, + version, + minVersion, + segmentName, + docsInSegment, + isCompoundFile, + codec, + Collections.emptyMap(), + segmentId, + attributes, + indexSort + ); + } + } + + // Utility class to help build FieldInfo + public static class FieldInfoBuilder { + private final String fieldName; + private int fieldNumber; + private boolean storeTermVector; + private boolean omitNorms; + private boolean storePayloads; + private IndexOptions indexOptions; + private DocValuesType docValuesType; + private long dvGen; + private final Map attributes; + private int pointDimensionCount; + private int pointIndexDimensionCount; + private int pointNumBytes; + private boolean softDeletes; + + public static FieldInfoBuilder builder(String fieldName) { + return new FieldInfoBuilder(fieldName); + } + + private FieldInfoBuilder(String fieldName) { + this.fieldName = fieldName; + this.fieldNumber = 0; + this.storeTermVector = false; + this.omitNorms = true; + this.storePayloads = true; + this.indexOptions = IndexOptions.NONE; + this.docValuesType = DocValuesType.BINARY; + this.dvGen = 0; + this.attributes = new HashMap<>(); + this.pointDimensionCount = 0; + this.pointIndexDimensionCount = 0; + this.pointNumBytes = 0; + this.softDeletes = false; + } + + public FieldInfoBuilder fieldNumber(int fieldNumber) { + this.fieldNumber = fieldNumber; + return this; + } + + public FieldInfoBuilder storeTermVector(boolean storeTermVector) { + this.storeTermVector = storeTermVector; + return this; + } + + public FieldInfoBuilder omitNorms(boolean omitNorms) { + this.omitNorms = omitNorms; + return this; + } + + public FieldInfoBuilder storePayloads(boolean storePayloads) { + this.storePayloads = storePayloads; + return this; + } + + public FieldInfoBuilder indexOptions(IndexOptions indexOptions) { + this.indexOptions = indexOptions; + return this; + } + + public FieldInfoBuilder docValuesType(DocValuesType docValuesType) { + this.docValuesType = docValuesType; + return this; + } + + public FieldInfoBuilder dvGen(long dvGen) { + this.dvGen = dvGen; + return this; + } + + public FieldInfoBuilder addAttribute(String key, String value) { + this.attributes.put(key, value); + return this; + } + + public FieldInfoBuilder pointDimensionCount(int pointDimensionCount) { + this.pointDimensionCount = pointDimensionCount; + return this; + } + + public FieldInfoBuilder pointIndexDimensionCount(int pointIndexDimensionCount) { + this.pointIndexDimensionCount = pointIndexDimensionCount; + return this; + } + + public FieldInfoBuilder pointNumBytes(int pointNumBytes) { + this.pointNumBytes = pointNumBytes; + return this; + } + + public FieldInfoBuilder softDeletes(boolean softDeletes) { + this.softDeletes = softDeletes; + return this; + } + + public FieldInfo build() { + return new FieldInfo( + fieldName, + fieldNumber, + storeTermVector, + omitNorms, + storePayloads, + indexOptions, + docValuesType, + dvGen, + attributes, + pointDimensionCount, + pointIndexDimensionCount, + pointNumBytes, + softDeletes + ); + } + } + + public static abstract class VectorDocValues extends BinaryDocValues { + + final int count; + final int dimension; + int current; + KNNVectorSerializer knnVectorSerializer; + + public VectorDocValues(int count, int dimension) { + this.count = count; + this.dimension = dimension; + this.current = -1; + this.knnVectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer(); + } + + @Override + public boolean advanceExact(int target) throws IOException { + return false; + } + + @Override + public int docID() { + if (this.current > this.count) { + return BinaryDocValues.NO_MORE_DOCS; + } + return this.current; + } + + @Override + public int nextDoc() throws IOException { + return advance(current + 1); + } + + @Override + public int advance(int target) throws IOException { + current = target; + if (current >= count) { + current = NO_MORE_DOCS; + } + return current; + } + + @Override + public long cost() { + return 0; + } + } + + public static class ConstantVectorBinaryDocValues extends VectorDocValues { + + private final BytesRef value; + + public ConstantVectorBinaryDocValues(int count, int dimension, float value) { + super(count, dimension); + float[] array = new float[dimension]; + Arrays.fill(array, value); + this.value = new BytesRef(knnVectorSerializer.floatToByteArray(array)); + } + + @Override + public BytesRef binaryValue() throws IOException { + return value; + } + } + + public static class RandomVectorBinaryDocValues extends VectorDocValues { + + public RandomVectorBinaryDocValues(int count, int dimension) { + super(count, dimension); + } + + @Override + public BytesRef binaryValue() throws IOException { + return new BytesRef(knnVectorSerializer.floatToByteArray(getRandomVector(dimension))); + } + } + + public static class RandomVectorDocValuesProducer extends DocValuesProducer { + + final RandomVectorBinaryDocValues randomBinaryDocValues; + + public RandomVectorDocValuesProducer(int count, int dimension) { + this.randomBinaryDocValues = new RandomVectorBinaryDocValues(count, dimension); + } + + @Override + public NumericDocValues getNumeric(FieldInfo field) { + return null; + } + + @Override + public BinaryDocValues getBinary(FieldInfo field) throws IOException { + return randomBinaryDocValues; + } + + @Override + public SortedDocValues getSorted(FieldInfo field) { + return null; + } + + @Override + public SortedNumericDocValues getSortedNumeric(FieldInfo field) { + return null; + } + + @Override + public SortedSetDocValues getSortedSet(FieldInfo field) { + return null; + } + + @Override + public void checkIntegrity() { + + } + + @Override + public void close() throws IOException { + + } + + @Override + public long ramBytesUsed() { + return 0; + } + } + + public static void assertFileInCorrectLocation(SegmentWriteState state, String expectedFile) throws IOException { + assertTrue(Set.of(state.directory.listAll()).contains(expectedFile)); + } + + public static void assertValidFooter(Directory dir, String filename) throws IOException { + ChecksumIndexInput indexInput = dir.openChecksumInput(filename, IOContext.DEFAULT); + indexInput.seek(indexInput.length() - CodecUtil.footerLength()); + CodecUtil.checkFooter(indexInput); + indexInput.close(); + } + + public static void assertLoadableByEngine( + SegmentWriteState state, + String fileName, + KNNEngine knnEngine, + SpaceType spaceType, + int dimension + ) { + String filePath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), fileName) + .toString(); + long indexPtr = JNIService.loadIndex( + filePath, + Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())), + knnEngine.getName() + ); + int k = 2; + float[] queryVector = new float[dimension]; + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName()); + assertTrue(results.length > 0); + JNIService.free(indexPtr, knnEngine.getName()); + } + + public static float[][] getRandomVectors(int count, int dimension) { + float[][] data = new float[count][dimension]; + for (int i = 0; i < count; i++) { + data[i] = getRandomVector(dimension); + } + return data; + } + + public static float[] getRandomVector(int dimension) { + float[] data = new float[dimension]; + for (int i = 0; i < dimension; i++) { + data[i] = randomFloat(); + } + return data; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java b/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java new file mode 100644 index 000000000..a2105af3a --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.codec.util; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.MergeState; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.io.IOException; + +public class BinaryDocValuesSubTests extends KNNTestCase { + + public void testNextDoc() throws IOException { + BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 2.0f); + MergeState.DocMap docMap = new MergeState.DocMap() { + @Override + public int get(int docID) { + return docID; + } + }; + + BinaryDocValuesSub binaryDocValuesSub = new BinaryDocValuesSub(docMap, binaryDocValues); + int expectedNextDoc = binaryDocValues.nextDoc() + 1; + assertEquals(expectedNextDoc, binaryDocValuesSub.nextDoc()); + } + + public void testGetValues() { + BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 2.0f); + MergeState.DocMap docMap = new MergeState.DocMap() { + @Override + public int get(int docID) { + return docID; + } + }; + + BinaryDocValuesSub binaryDocValuesSub = new BinaryDocValuesSub(docMap, binaryDocValues); + + assertEquals(binaryDocValues, binaryDocValuesSub.getValues()); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerTests.java similarity index 82% rename from src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java rename to src/test/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerTests.java index 4bc62bebf..1d08df6a0 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerTests.java @@ -3,12 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.codec; +package org.opensearch.knn.index.codec.util; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.index.codec.util.SerializationMode; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -22,8 +19,8 @@ public class KNNVectorSerializerTests extends KNNTestCase { Random random = new Random(); public void testVectorSerializerFactory() throws Exception { - //check that default serializer can work with array of floats - //setup + // check that default serializer can work with array of floats + // setup final float[] vector = getArrayOfRandomFloats(20); final ByteArrayOutputStream bas = new ByteArrayOutputStream(); final DataOutputStream ds = new DataOutputStream(bas); @@ -40,19 +37,18 @@ public void testVectorSerializerFactory() throws Exception { assertNotNull(actualDeserializedVector); assertArrayEquals(vector, actualDeserializedVector, 0.1f); - final KNNVectorSerializer arraySerializer = - KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.ARRAY); + final KNNVectorSerializer arraySerializer = KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.ARRAY); assertNotNull(arraySerializer); - final KNNVectorSerializer collectionOfFloatsSerializer = - KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.COLLECTION_OF_FLOATS); + final KNNVectorSerializer collectionOfFloatsSerializer = KNNVectorSerializerFactory.getSerializerBySerializationMode( + SerializationMode.COLLECTION_OF_FLOATS + ); assertNotNull(collectionOfFloatsSerializer); } - public void testVectorSerializerFactory_throwExceptionForStreamWithUnsupportedDataType() throws Exception { - //prepare array of chars that is not supported by serializer factory. expected behavior is to fail - final char[] arrayOfChars = new char[] {'a', 'b', 'c'}; + // prepare array of chars that is not supported by serializer factory. expected behavior is to fail + final char[] arrayOfChars = new char[] { 'a', 'b', 'c' }; final ByteArrayOutputStream bas = new ByteArrayOutputStream(); final DataOutputStream ds = new DataOutputStream(bas); for (char ch : arrayOfChars) @@ -75,14 +71,14 @@ public void testVectorAsArraySerializer() throws Exception { final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(bais); - //testing serialization + // testing serialization bais.reset(); final byte[] actualSerializedVector = vectorSerializer.floatToByteArray(vector); assertNotNull(actualSerializedVector); assertArrayEquals(serializedVector, actualSerializedVector); - //testing deserialization + // testing deserialization bais.reset(); final float[] actualDeserializedVector = vectorSerializer.byteToFloatArray(bais); @@ -91,7 +87,7 @@ public void testVectorAsArraySerializer() throws Exception { } public void testVectorAsCollectionOfFloatsSerializer() throws Exception { - //setup + // setup final float[] vector = getArrayOfRandomFloats(20); final ByteArrayOutputStream bas = new ByteArrayOutputStream(); @@ -103,14 +99,14 @@ public void testVectorAsCollectionOfFloatsSerializer() throws Exception { final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(bais); - //testing serialization + // testing serialization bais.reset(); final byte[] actualSerializedVector = vectorSerializer.floatToByteArray(vector); assertNotNull(actualSerializedVector); assertArrayEquals(vectorAsCollectionOfFloats, actualSerializedVector); - //testing deserialization + // testing deserialization bais.reset(); final float[] actualDeserializedVector = vectorSerializer.byteToFloatArray(bais);