From d531b3c54d581e9f9c2402a71c1d065c8cbc5f0b Mon Sep 17 00:00:00 2001
From: John Mazanec
Date: Thu, 17 Mar 2022 20:26:55 -0400
Subject: [PATCH] Refactor KNNCodec to use new extension point (#319)

Refactor plugin to return CodecServiceFactory as opposed to CodecService. This will allow the plugin to make decisions based on Mapper Service.

Refactors the KNN87Codec to implement FilterCodec. This allows the codec to automatically/flexibly delegate operations it does not override to an arbitrary Codec. Additionally cleans up some code around the Codec.

Adds unit tests that map to each codec component. Did not add tests for merging and codec utils. This can be undertaken later. Adds a utils folder for sharing testing functionality between codec tests.

Cleans up a few minor details around codec source code.

Signed-off-by: John Mazanec
---
 gradle.properties                             |   6 +
 .../KNN80Codec/KNN80BinaryDocValues.java      |   2 +-
 .../codec/KNN80Codec/KNN80CompoundFormat.java |  26 +-
 .../KNN80Codec/KNN80DocValuesConsumer.java    | 179 ++++----
 .../KNN80Codec/KNN80DocValuesFormat.java      |  18 +-
 .../KNN80Codec/KNN80DocValuesReader.java      |   4 +-
 .../index/codec/KNN87Codec/KNN87Codec.java    | 107 +----
 .../knn/index/codec/KNNCodecService.java      |  32 ++
 .../codec/{ => util}/BinaryDocValuesSub.java  |  12 +-
 .../knn/plugin/KNNCodecService.java           |  41 --
 .../knn/plugin/KNNEngineFactory.java          |  33 --
 .../org/opensearch/knn/plugin/KNNPlugin.java  | 101 +++--
 .../KNN80Codec/KNN80BinaryDocValuesTests.java |  69 +++
 .../KNN80Codec/KNN80CompoundFormatTests.java  |  93 ++++
 .../KNN80DocValuesConsumerTests.java          | 412 +++++++++++++++++
 .../codec/KNN87Codec/KNN87CodecTests.java     |   4 -
 .../knn/index/codec/KNNCodecTestCase.java     | 108 ++---
 .../knn/index/codec/KNNCodecTestUtil.java     | 421 ++++++++++++++++++
 .../codec/util/BinaryDocValuesSubTests.java   |  44 ++
 .../{ => util}/KNNVectorSerializerTests.java  |  32 +-
 20 files changed, 1327 insertions(+), 417 deletions(-)
 create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java
 rename src/main/java/org/opensearch/knn/index/codec/{ => util}/BinaryDocValuesSub.java (95%)
 delete mode 100644 src/main/java/org/opensearch/knn/plugin/KNNCodecService.java
 delete mode 100644 src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java
 create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java
 create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java
 create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java
 create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java
 create mode 100644 src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java
 rename src/test/java/org/opensearch/knn/index/codec/{ => util}/KNNVectorSerializerTests.java (82%)

diff --git a/gradle.properties b/gradle.properties
index 8c7ff252a..86e266f28 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -4,3 +4,9 @@
 #

 version=1.0.0
+
+org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \
+    --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \
+    --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \
+    --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \
+    --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java
b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java index 827f0a3d1..832737a6d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java @@ -5,7 +5,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import org.opensearch.knn.index.codec.BinaryDocValuesSub; +import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.util.BytesRef; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java index ca581b6ae..8a8ed558e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java @@ -5,10 +5,8 @@ package org.opensearch.knn.index.codec.KNN80Codec; +import org.apache.lucene.codecs.lucene50.Lucene50CompoundFormat; import org.opensearch.knn.common.KNNConstants; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CompoundDirectory; import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.index.SegmentInfo; @@ -26,14 +24,24 @@ */ public class KNN80CompoundFormat extends CompoundFormat { - private final Logger logger = LogManager.getLogger(KNN80CompoundFormat.class); + private final CompoundFormat delegate; public KNN80CompoundFormat() { + this.delegate = new Lucene50CompoundFormat(); + } + + /** + * Constructor that takes a delegate to handle non-overridden methods + * + * @param delegate CompoundFormat that will handle non-overridden methods + */ + public KNN80CompoundFormat(CompoundFormat delegate) { + this.delegate = delegate; } @Override public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException { - return Codec.getDefault().compoundFormat().getCompoundReader(dir, si, context); + return delegate.getCompoundReader(dir, si, context); } @Override @@ -41,17 +49,15 @@ public void write(Directory dir, SegmentInfo si, IOContext context) throws IOExc for (KNNEngine knnEngine : KNNEngine.values()) { writeEngineFiles(dir, si, context, knnEngine.getExtension()); } - Codec.getDefault().compoundFormat().write(dir, si, context); + delegate.write(dir, si, context); } - private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) - throws IOException { + private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) throws IOException { /* * If engine file present, remove it from the compounding file list to avoid header/footer checks * and create a new compounding file format with extension engine + c. 
*/ - Set engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)) - .collect(Collectors.toSet()); + Set engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)).collect(Collectors.toSet()); Set segmentFiles = new HashSet<>(si.files()); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index d4bc662c5..0bbd6e9ac 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -67,108 +67,117 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable { @Override public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { delegatee.addBinaryField(field, valuesProducer); - addKNNBinaryField(field, valuesProducer); + if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { + addKNNBinaryField(field, valuesProducer); + } } public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { - // Get values to be indexed - BinaryDocValues values = valuesProducer.getBinary(field); - KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); - if (pair.vectors.length == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the documents"); - return; - } + // Get values to be indexed + BinaryDocValues values = valuesProducer.getBinary(field); + KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); + if (pair.vectors.length == 0 || pair.docs.length == 0) { + logger.info("Skipping engine index creation as there are no vectors or docs in the documents"); + return; + } - // Create library index either from model or from scratch - String engineFileName; - String indexPath; - String tmpEngineFileName; + // Increment counter for number of graph index requests + KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - if (field.attributes().containsKey(MODEL_ID)) { + // Create library index either from model or from scratch + String engineFileName; + String indexPath; + String tmpEngineFileName; - String modelId = field.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); + if (field.attributes().containsKey(MODEL_ID)) { - KNNEngine knnEngine = model.getModelMetadata().getKnnEngine(); + String modelId = field.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(), - field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName).toString(); - tmpEngineFileName = engineFileName + TEMP_SUFFIX; - String tempIndexPath = indexPath + TEMP_SUFFIX; + KNNEngine knnEngine = model.getModelMetadata().getKnnEngine(); - if (model.getModelBlob() == null) { - throw new RuntimeException("There is no trained model with id \"" + modelId + "\""); - } + engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getLatestBuildVersion(), + field.name, + knnEngine.getExtension() + ); + indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), 
engineFileName) + .toString(); + tmpEngineFileName = engineFileName + TEMP_SUFFIX; + String tempIndexPath = indexPath + TEMP_SUFFIX; - createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath); - } else { + if (model.getModelBlob() == null) { + throw new RuntimeException("There is no trained model with id \"" + modelId + "\""); + } - // Get engine to be used for indexing - String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); - KNNEngine knnEngine = KNNEngine.getEngine(engineName); + createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath); + } else { - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(), - field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName).toString(); - tmpEngineFileName = engineFileName + TEMP_SUFFIX; - String tempIndexPath = indexPath + TEMP_SUFFIX; + // Get engine to be used for indexing + String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); + KNNEngine knnEngine = KNNEngine.getEngine(engineName); - createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath); - } + engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getLatestBuildVersion(), + field.name, + knnEngine.getExtension() + ); + indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) + .toString(); + tmpEngineFileName = engineFileName + TEMP_SUFFIX; + String tempIndexPath = indexPath + TEMP_SUFFIX; - /* - * Adds Footer to the serialized graph - * 1. Copies the serialized graph to new file. - * 2. Adds Footer to the new file. - * - * We had to create new file here because adding footer directly to the - * existing file will miss calculating checksum for the serialized graph - * bytes and result in index corruption issues. - */ - //TODO: I think this can be refactored to avoid this copy and then write - // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330 - try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context); - IndexOutput os = state.directory.createOutput(engineFileName, state.context)) { - os.copyBytes(is, is.length()); - CodecUtil.writeFooter(os); - } catch (Exception ex) { - KNNCounter.GRAPH_INDEX_ERRORS.increment(); - throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex); - } finally { - IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName); - } + createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath); + } + + /* + * Adds Footer to the serialized graph + * 1. Copies the serialized graph to new file. + * 2. Adds Footer to the new file. + * + * We had to create new file here because adding footer directly to the + * existing file will miss calculating checksum for the serialized graph + * bytes and result in index corruption issues. 
+ */ + // TODO: I think this can be refactored to avoid this copy and then write + // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330 + try ( + IndexInput is = state.directory.openInput(tmpEngineFileName, state.context); + IndexOutput os = state.directory.createOutput(engineFileName, state.context) + ) { + os.copyBytes(is, is.length()); + CodecUtil.writeFooter(os); + } catch (Exception ex) { + KNNCounter.GRAPH_INDEX_ERRORS.increment(); + throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex); + } finally { + IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName); } } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, - String indexPath) { - Map parameters = ImmutableMap.of(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue( - KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - AccessController.doPrivileged( - (PrivilegedAction) () -> { - JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, - knnEngine.getName()); - return null; - } + private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + Map parameters = ImmutableMap.of( + KNNConstants.INDEX_THREAD_QTY, + KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName()); + return null; + }); } - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, - String indexPath) throws IOException { + private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) + throws IOException { Map parameters = new HashMap<>(); Map fieldAttributes = fieldInfo.attributes(); String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); // parametersString will be null when legacy mapper is used if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, - SpaceType.DEFAULT.getValue())); + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); Map algoParams = new HashMap<>(); @@ -183,22 +192,20 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa parameters.put(PARAMETERS, algoParams); } else { parameters.putAll( - XContentFactory.xContent(XContentType.JSON).createParser(NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString).map() + XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString) + .map() ); } // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue( - KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); // Pass the path for the nms library to save the file - AccessController.doPrivileged( - (PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName()); - return 
null; - } - ); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName()); + return null; + }); } /** @@ -214,7 +221,7 @@ public void merge(MergeState mergeState) { assert mergeState.mergeFieldInfos != null; for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) { DocValuesType type = fieldInfo.getDocValuesType(); - if (type == DocValuesType.BINARY) { + if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState)); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java index 9114f38ee..cd8362cf2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java @@ -5,11 +5,10 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesConsumer; import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.lucene80.Lucene80DocValuesFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -19,11 +18,20 @@ * Encodes/Decodes per document values */ public class KNN80DocValuesFormat extends DocValuesFormat { - private final Logger logger = LogManager.getLogger(KNN80DocValuesFormat.class); - private final DocValuesFormat delegate = DocValuesFormat.forName(KNN80Codec.LUCENE_80); + private final DocValuesFormat delegate; public KNN80DocValuesFormat() { - super(KNN80Codec.LUCENE_80); + this(new Lucene80DocValuesFormat()); + } + + /** + * Constructor that takes delegate in order to handle non-overridden methods + * + * @param delegate DocValuesFormat to handle non-overridden methods + */ + public KNN80DocValuesFormat(DocValuesFormat delegate) { + super(delegate.getName()); + this.delegate = delegate; } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java index 943b39abc..ccfaa68fc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java @@ -5,7 +5,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import org.opensearch.knn.index.codec.BinaryDocValuesSub; +import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; @@ -22,7 +22,7 @@ */ class KNN80DocValuesReader extends EmptyDocValuesProducer { - private MergeState mergeState; + private final MergeState mergeState; KNN80DocValuesReader(MergeState mergeState) { this.mergeState = mergeState; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java index 6e2e897e0..6ec5ec05c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java @@ -5,124 +5,53 @@ package 
org.opensearch.knn.index.codec.KNN87Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.lucene87.Lucene87Codec; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.codecs.DocValuesFormat; -import org.apache.lucene.codecs.FieldInfosFormat; -import org.apache.lucene.codecs.LiveDocsFormat; -import org.apache.lucene.codecs.NormsFormat; -import org.apache.lucene.codecs.PointsFormat; -import org.apache.lucene.codecs.PostingsFormat; -import org.apache.lucene.codecs.SegmentInfoFormat; -import org.apache.lucene.codecs.StoredFieldsFormat; -import org.apache.lucene.codecs.TermVectorsFormat; -import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat; /** * Extends the Codec to support a new file format for KNN index * based on the mappings. * */ -public final class KNN87Codec extends Codec { +public final class KNN87Codec extends FilterCodec { - private static final Logger logger = LogManager.getLogger(KNN87Codec.class); private final DocValuesFormat docValuesFormat; - private final DocValuesFormat perFieldDocValuesFormat; private final CompoundFormat compoundFormat; - private Codec lucene87Codec; - private PostingsFormat postingsFormat = null; public static final String KNN_87 = "KNN87Codec"; - public static final String LUCENE_87 = "Lucene87"; // Lucene Codec to be used + /** + * No arg constructor that uses Lucene87 as the delegate + */ public KNN87Codec() { - super(KNN_87); - // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 - // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); - this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { - @Override - public DocValuesFormat getDocValuesFormatForField(String field) { - return docValuesFormat; - } - }; - this.compoundFormat = new KNN80CompoundFormat(); + this(new Lucene87Codec()); } - /* - * This function returns the Codec. + /** + * Constructor that takes a Codec delegate to delegate all methods this code does not implement to. + * + * @param delegate codec that will perform all operations this codec does not override */ - public Codec getDelegatee() { - if (lucene87Codec == null) - lucene87Codec = Codec.forName(LUCENE_87); - return lucene87Codec; + public KNN87Codec(Codec delegate) { + super(KNN_87, delegate); + // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 + // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses + this.docValuesFormat = new KNN80DocValuesFormat(delegate.docValuesFormat()); + this.compoundFormat = new KNN80CompoundFormat(delegate.compoundFormat()); } @Override public DocValuesFormat docValuesFormat() { - return this.perFieldDocValuesFormat; - } - - /* - * For all the below functions, we could have extended FilterCodec, but this brings - * SPI related issues while loading Codec in the tests. So fall back to traditional - * approach of manually overriding. 
- */ - - - public void setPostingsFormat(PostingsFormat postingsFormat) { - this.postingsFormat = postingsFormat; - } - - @Override - public PostingsFormat postingsFormat() { - if (this.postingsFormat == null) { - return getDelegatee().postingsFormat(); - } - return this.postingsFormat; - } - - @Override - public StoredFieldsFormat storedFieldsFormat() { - return getDelegatee().storedFieldsFormat(); - } - - @Override - public TermVectorsFormat termVectorsFormat() { - return getDelegatee().termVectorsFormat(); - } - - @Override - public FieldInfosFormat fieldInfosFormat() { - return getDelegatee().fieldInfosFormat(); - } - - @Override - public SegmentInfoFormat segmentInfoFormat() { - return getDelegatee().segmentInfoFormat(); - } - - @Override - public NormsFormat normsFormat() { - return getDelegatee().normsFormat(); - } - - @Override - public LiveDocsFormat liveDocsFormat() { - return getDelegatee().liveDocsFormat(); + return this.docValuesFormat; } @Override public CompoundFormat compoundFormat() { return this.compoundFormat; } - - @Override - public PointsFormat pointsFormat() { - return getDelegatee().pointsFormat(); - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java new file mode 100644 index 000000000..70991dfe8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import org.opensearch.index.codec.CodecServiceConfig; +import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; +import org.apache.lucene.codecs.Codec; +import org.opensearch.index.codec.CodecService; + +/** + * KNNCodecService to inject the right KNNCodec version + */ +public class KNNCodecService extends CodecService { + + public KNNCodecService(CodecServiceConfig codecServiceConfig) { + super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger()); + } + + /** + * Return the custom k-NN codec that wraps another codec that a user wants to use for non k-NN related operations + * + * @param name of delegate codec. + * @return Latest KNN Codec built with delegate codec. 
+ */ + @Override + public Codec codec(String name) { + return new KNN87Codec(super.codec(name)); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java b/src/main/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSub.java similarity index 95% rename from src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java rename to src/main/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSub.java index 67a476c8d..c47aa85d3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSub.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.codec; +package org.opensearch.knn.index.codec.util; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; @@ -19,10 +19,6 @@ public class BinaryDocValuesSub extends DocIDMerger.Sub { private final BinaryDocValues values; - public BinaryDocValues getValues() { - return values; - } - public BinaryDocValuesSub(MergeState.DocMap docMap, BinaryDocValues values) { super(docMap); if (values == null || (values.docID() != -1)) { @@ -35,4 +31,8 @@ public BinaryDocValuesSub(MergeState.DocMap docMap, BinaryDocValues values) { public int nextDoc() throws IOException { return values.nextDoc(); } -} \ No newline at end of file + + public BinaryDocValues getValues() { + return values; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java b/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java deleted file mode 100644 index c67f27923..000000000 --- a/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.plugin; - -import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; -import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.PostingsFormat; -import org.opensearch.index.codec.CodecService; - -/** - * KNNCodecService to inject the right KNNCodec version - */ -class KNNCodecService extends CodecService { - - KNNCodecService() { - super(null, null); - } - - /** - * If the index is of type KNN i.e index.knn = true, We always - * return the KNN Codec - * - * @param name dummy name - * @return Latest KNN Codec - */ - @Override - public Codec codec(String name) { - Codec codec = Codec.forName(KNN87Codec.KNN_87); - if (codec == null) { - throw new IllegalArgumentException("failed to find codec [" + name + "]"); - } - return codec; - } - - public void setPostingsFormat(PostingsFormat postingsFormat) { - ((KNN87Codec)codec("")).setPostingsFormat(postingsFormat); - } -} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java b/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java deleted file mode 100644 index c5da36864..000000000 --- a/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.plugin; - -import org.opensearch.index.engine.Engine; -import org.opensearch.index.engine.EngineConfig; -import org.opensearch.index.engine.EngineFactory; -import org.opensearch.index.engine.InternalEngine; - -/** - * EngineFactory to inject the KNNCodecService to help segments write using the KNNCodec. 
- */ -class KNNEngineFactory implements EngineFactory { - - private static KNNCodecService codecService = new KNNCodecService(); - - @Override - public Engine newReadWriteEngine(EngineConfig config) { - codecService.setPostingsFormat(config.getCodec().postingsFormat()); - EngineConfig engineConfig = new EngineConfig(config.getShardId(), - config.getThreadPool(), config.getIndexSettings(), config.getWarmer(), config.getStore(), - config.getMergePolicy(), config.getAnalyzer(), config.getSimilarity(), codecService, - config.getEventListener(), config.getQueryCache(), config.getQueryCachingPolicy(), - config.getTranslogConfig(), config.getFlushMergesAfter(), config.getExternalRefreshListener(), - config.getInternalRefreshListener(), config.getIndexSort(), config.getCircuitBreakerService(), - config.getGlobalCheckpointSupplier(), config.retentionLeasesSupplier(), config.getPrimaryTermSupplier(), - config.getTombstoneDocSupplier()); - return new InternalEngine(engineConfig); - } -} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 122b77929..9cf0696f2 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -5,12 +5,15 @@ package org.opensearch.knn.plugin; +import org.opensearch.index.codec.CodecServiceFactory; +import org.opensearch.index.engine.EngineFactory; import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorFieldMapper; import org.opensearch.knn.index.KNNWeight; +import org.opensearch.knn.index.codec.KNNCodecService; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; @@ -49,7 +52,6 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.index.IndexModule; import org.opensearch.index.IndexSettings; -import org.opensearch.index.engine.EngineFactory; import org.opensearch.index.mapper.Mapper; import org.opensearch.knn.plugin.stats.KNNStatsConfig; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; @@ -133,13 +135,14 @@ public class KNNPlugin extends Plugin implements MapperPlugin, SearchPlugin, Act public static final String KNN_BASE_URI = "/_plugins/_knn"; private KNNStats knnStats; - private ModelDao modelDao; private ClusterService clusterService; @Override public Map getMappers() { - return Collections.singletonMap(KNNVectorFieldMapper.CONTENT_TYPE, new KNNVectorFieldMapper.TypeParser( - ModelDao.OpenSearchKNNModelDao::getInstance)); + return Collections.singletonMap( + KNNVectorFieldMapper.CONTENT_TYPE, + new KNNVectorFieldMapper.TypeParser(ModelDao.OpenSearchKNNModelDao::getInstance) + ); } @Override @@ -148,12 +151,19 @@ public List> getQueries() { } @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, ScriptService scriptService, - NamedXContentRegistry xContentRegistry, Environment environment, - NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier) { + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService 
scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { this.clusterService = clusterService; // Initialize Native Memory loading strategies @@ -178,25 +188,35 @@ public List> getSettings() { return KNNSettings.state().getSettings(); } - public List getRestHandlers(Settings settings, - RestController restController, - ClusterSettings clusterSettings, - IndexScopedSettings indexScopedSettings, - SettingsFilter settingsFilter, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier nodesInCluster) { + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { RestKNNStatsHandler restKNNStatsHandler = new RestKNNStatsHandler(settings, restController, knnStats); - RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(settings, restController, clusterService, - indexNameExpressionResolver); + RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler( + settings, + restController, + clusterService, + indexNameExpressionResolver + ); RestGetModelHandler restGetModelHandler = new RestGetModelHandler(); RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler(); RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler(); RestSearchModelHandler restSearchModelHandler = new RestSearchModelHandler(); return ImmutableList.of( - restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler, restDeleteModelHandler, - restTrainModelHandler, restSearchModelHandler + restKNNStatsHandler, + restKNNWarmupHandler, + restGetModelHandler, + restDeleteModelHandler, + restTrainModelHandler, + restSearchModelHandler ); } @@ -206,24 +226,28 @@ public List getRestHandlers(Settings settings, @Override public List> getActions() { return Arrays.asList( - new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class), - new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class), - new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), - new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, - TrainingJobRouteDecisionInfoTransportAction.class), - new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), - new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), - new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class), - new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), - new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), - new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class) + new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class), + new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class), + new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), + new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, TrainingJobRouteDecisionInfoTransportAction.class), + new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), + new 
ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), + new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class), + new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), + new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), + new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class) ); } @Override public Optional getEngineFactory(IndexSettings indexSettings) { + return Optional.empty(); + } + + @Override + public Optional getCustomCodecServiceFactory(IndexSettings indexSettings) { if (indexSettings.getValue(KNNSettings.IS_KNN_INDEX_SETTING)) { - return Optional.of(new KNNEngineFactory()); + return Optional.of(KNNCodecService::new); } return Optional.empty(); } @@ -267,15 +291,6 @@ public ScriptEngine getScriptEngine(Settings settings, Collection> getExecutorBuilders(Settings settings) { - return ImmutableList.of( - new FixedExecutorBuilder( - settings, - TRAIN_THREAD_POOL, - 1, - 1, - KNN_THREAD_POOL_PREFIX, - false - ) - ); + return ImmutableList.of(new FixedExecutorBuilder(settings, TRAIN_THREAD_POOL, 1, 1, KNN_THREAD_POOL_PREFIX, false)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java new file mode 100644 index 000000000..620559867 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import com.google.common.collect.ImmutableList; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.MergeState; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; + +import java.io.IOException; + +public class KNN80BinaryDocValuesTests extends KNNTestCase { + + public void testDocId() { + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(null); + assertEquals(-1, knn80BinaryDocValues.docID()); + } + + public void testNextDoc() throws IOException { + final int expectedDoc = 12; + + BinaryDocValuesSub sub = new BinaryDocValuesSub(new MergeState.DocMap() { + @Override + public int get(int docID) { + return expectedDoc; + } + }, new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f)); + + DocIDMerger docIDMerger = DocIDMerger.of(ImmutableList.of(sub), false); + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(docIDMerger); + assertEquals(expectedDoc, knn80BinaryDocValues.nextDoc()); + } + + public void testAdvance() { + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, () -> knn80BinaryDocValues.advance(0)); + } + + public void testAdvanceExact() { + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, () -> knn80BinaryDocValues.advanceExact(0)); + } + + public void testCost() { + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(null); + expectThrows(UnsupportedOperationException.class, knn80BinaryDocValues::cost); + } + + public void testBinaryValue() throws IOException { + BinaryDocValues 
binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f); + BinaryDocValuesSub sub = new BinaryDocValuesSub(new MergeState.DocMap() { + @Override + public int get(int docID) { + return docID; + } + }, binaryDocValues); + + DocIDMerger docIDMerger = DocIDMerger.of(ImmutableList.of(sub), false); + KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(docIDMerger); + knn80BinaryDocValues.nextDoc(); + assertEquals(binaryDocValues.binaryValue(), knn80BinaryDocValues.binaryValue()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java new file mode 100644 index 000000000..a2f18cdfb --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundDirectory; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.common.util.set.Sets; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Set; + +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KNN80CompoundFormatTests extends KNNTestCase { + + private static Directory directory; + private static Codec codec; + + @BeforeClass + public static void setStaticVariables() { + directory = newFSDirectory(createTempDir()); + codec = new KNN87Codec(); + } + + @AfterClass + public static void closeStaticVariables() throws IOException { + directory.close(); + } + + public void testGetCompoundReader() throws IOException { + CompoundDirectory dir = mock(CompoundDirectory.class); + CompoundFormat delegate = mock(CompoundFormat.class); + when(delegate.getCompoundReader(null, null, null)).thenReturn(dir); + KNN80CompoundFormat knn80CompoundFormat = new KNN80CompoundFormat(delegate); + assertEquals(dir, knn80CompoundFormat.getCompoundReader(null, null, null)); + } + + public void testWrite() throws IOException { + // Check that all normal engine files correctly get set to compound extension files after write + String segmentName = "_test"; + + Set segmentFiles = Sets.newHashSet( + String.format("%s_nmslib1%s", segmentName, KNNEngine.NMSLIB.getExtension()), + String.format("%s_nmslib2%s", segmentName, KNNEngine.NMSLIB.getExtension()), + String.format("%s_nmslib3%s", segmentName, KNNEngine.NMSLIB.getExtension()), + String.format("%s_faiss1%s", segmentName, KNNEngine.FAISS.getExtension()), + String.format("%s_faiss2%s", segmentName, KNNEngine.FAISS.getExtension()), + String.format("%s_faiss3%s", segmentName, KNNEngine.FAISS.getExtension()) + ); + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, segmentFiles.size(), codec).build(); + + for (String name : segmentFiles) { + IndexOutput indexOutput = 
directory.createOutput(name, IOContext.DEFAULT); + indexOutput.close(); + } + segmentInfo.setFiles(segmentFiles); + + CompoundFormat delegate = mock(CompoundFormat.class); + doNothing().when(delegate).write(directory, segmentInfo, IOContext.DEFAULT); + + KNN80CompoundFormat knn80CompoundFormat = new KNN80CompoundFormat(delegate); + knn80CompoundFormat.write(directory, segmentInfo, IOContext.DEFAULT); + + assertTrue(segmentInfo.files().isEmpty()); + + Arrays.stream(directory.listAll()).forEach(filename -> { + try { + directory.deleteFile(filename); + } catch (IOException e) { + fail(String.format("Failed to delete: %s", filename)); + } + }); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java new file mode 100644 index 000000000..b632b2f3f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -0,0 +1,412 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.DocValuesConsumer; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.KNNVectorFieldMapper; +import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.plugin.stats.KNNCounter; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static 
org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertFileInCorrectLocation; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertLoadableByEngine; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertValidFooter; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.getRandomVectors; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.RandomVectorDocValuesProducer; + +public class KNN80DocValuesConsumerTests extends KNNTestCase { + + private static Directory directory; + private static Codec codec; + + @BeforeClass + public static void setStaticVariables() { + directory = newFSDirectory(createTempDir()); + codec = new KNN87Codec(); + } + + @AfterClass + public static void closeStaticVariables() throws IOException { + directory.close(); + } + + public void testAddBinaryField_withKNN() throws IOException { + // Confirm that addKNNBinaryField will get called if the k-NN parameter is true + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field") + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .build(); + DocValuesProducer docValuesProducer = mock(DocValuesProducer.class); + + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addBinaryField(fieldInfo, docValuesProducer); + + final boolean[] called = { false }; + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { + + @Override + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) { + called[0] = true; + } + }; + + knn80DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); + + verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); + assertTrue(called[0]); + } + + public void testAddBinaryField_withoutKNN() throws IOException { + // Confirm that the KNN80DocValuesConsumer will just call delegate AddBinaryField when k-NN parameter is + // not set + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").build(); + DocValuesProducer docValuesProducer = mock(DocValuesProducer.class); + + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addBinaryField(fieldInfo, docValuesProducer); + + final boolean[] called = { false }; + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { + + @Override + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) { + called[0] = true; + } + }; + + knn80DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); + + verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); + assertFalse(called[0]); + } + + public void testAddKNNBinaryField_noVectors() throws IOException { + // When there are no new vectors, no more graph index requests should be added + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(0, 128); + Long initialGraphIndexRequests = KNNCounter.GRAPH_INDEX_REQUESTS.getCount(); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, null); + knn80DocValuesConsumer.addKNNBinaryField(null, randomVectorDocValuesProducer); + assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); + } + + public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException { + // Set information about the segment and the fields + String segmentName = 
String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + spaceType, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + + String parameterString = Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext))); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.KNN_ENGINE, knnEngine.getName()) + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { + // Set information about the segment and the fields + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") + .addAttribute(KNNConstants.HNSW_ALGO_M, "16") + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + 
// The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException { + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.FAISS; + SpaceType spaceType = SpaceType.INNER_PRODUCT; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + spaceType, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + + String parameterString = Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext))); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.KNN_ENGINE, knnEngine.getName()) + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by faiss + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testAddKNNBinaryField_fromModel_faiss() throws IOException, ExecutionException, InterruptedException { + // Generate a trained faiss model + KNNEngine knnEngine = KNNEngine.FAISS; + SpaceType spaceType = SpaceType.INNER_PRODUCT; + int dimension = 16; + String modelId = "test-model-id"; + + float[][] trainingData = getRandomVectors(200, dimension); + long trainingPtr = JNIService.transferVectors(0, trainingData); + + Map parameters = ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + "IVF4,Flat", + KNNConstants.SPACE_TYPE, + SpaceType.L2.getValue() + ); + + byte[] modelBytes = JNIService.trainIndex(parameters, dimension, trainingPtr, knnEngine.getName()); + Model model = new Model( + new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, "timestamp", "Empty description", ""), + modelBytes, + modelId + ); + 
JNIService.freeVectors(trainingPtr); + + // Setup the model cache to return the correct model + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(model); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getSettings()).thenReturn(Settings.EMPTY); + + ClusterSettings clusterSettings = new ClusterSettings( + Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10kb").build(), + ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING) + ); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ModelCache.initialize(modelDao, clusterService); + + // Build the segment and field info + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, docsInSegment, codec).build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(MODEL_ID, modelId) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getLatestBuildVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by faiss + assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + } + + public void testMerge_exception() throws IOException { + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, null); + expectThrows(RuntimeException.class, () -> knn80DocValuesConsumer.merge(null)); + } + + public void testAddSortedSetField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addSortedSetField(null, null); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.addSortedSetField(null, null); + verify(delegate, times(1)).addSortedSetField(null, null); + } + + public void testAddSortedNumericField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addSortedNumericField(null, null); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.addSortedNumericField(null, null); + verify(delegate, times(1)).addSortedNumericField(null, null); + } + + public void testAddSortedField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + 
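+        // doNothing() is an explicit no-op stub for the void delegate method; the test only checks that the call is forwarded.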
doNothing().when(delegate).addSortedField(null, null); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.addSortedField(null, null); + verify(delegate, times(1)).addSortedField(null, null); + } + + public void testAddNumericField() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).addNumericField(null, null); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.addNumericField(null, null); + verify(delegate, times(1)).addNumericField(null, null); + } + + public void testClose() throws IOException { + // Verify that the delegate will be called + DocValuesConsumer delegate = mock(DocValuesConsumer.class); + doNothing().when(delegate).close(); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null); + knn80DocValuesConsumer.close(); + verify(delegate, times(1)).close(); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java index 968e1f4e7..b7f909ec1 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java @@ -12,10 +12,6 @@ public class KNN87CodecTests extends KNNCodecTestCase { - public void testFooter() throws Exception { - testFooter(new KNN87Codec()); - } - public void testMultiFieldsKnnIndex() throws Exception { testMultiFieldsKnnIndex(new KNN87Codec()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 0688f554e..3454c7d52 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -20,20 +20,14 @@ import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.RandomIndexWriter; -import org.apache.lucene.index.SegmentReader; import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; -import org.apache.lucene.store.IOContext; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -65,7 +59,7 @@ /** * Test used for testing Codecs */ -public class KNNCodecTestCase extends KNNTestCase { +public class KNNCodecTestCase extends KNNTestCase { private static FieldType sampleFieldType; static { @@ -86,54 +80,8 @@ protected void setUpMockClusterService() { } protected ResourceWatcherService createDisabledResourceWatcherService() { - final Settings settings = Settings.builder() - .put("resource.reload.enabled", false) - .build(); - return new ResourceWatcherService( - settings, - null - ); - } - - public void 
testFooter(Codec codec) throws Exception { - setUpMockClusterService(); - Directory dir = newFSDirectory(createTempDir()); - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergeScheduler(new SerialMergeScheduler()); - iwc.setCodec(codec); - - float[] array = {1.0f, 2.0f, 3.0f}; - VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); - RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); - Document doc = new Document(); - doc.add(vectorField); - writer.addDocument(doc); - - ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); - NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); - IndexReader reader = writer.getReader(); - LeafReaderContext lrc = reader.getContext().leaves().iterator().next(); // leaf reader context - SegmentReader segmentReader = (SegmentReader) FilterLeafReader.unwrap(lrc.reader()); - String hnswFileExtension = segmentReader.getSegmentInfo().info.getUseCompoundFile() - ? KNNEngine.NMSLIB.getCompoundExtension() : KNNEngine.NMSLIB.getExtension(); - String hnswSuffix = "test_vector" + hnswFileExtension; - List hnswFiles = segmentReader.getSegmentInfo().files().stream() - .filter(fileName -> fileName.endsWith(hnswSuffix)) - .collect(Collectors.toList()); - assertTrue(!hnswFiles.isEmpty()); - ChecksumIndexInput indexInput = dir.openChecksumInput(hnswFiles.get(0), IOContext.DEFAULT); - indexInput.seek(indexInput.length() - CodecUtil.footerLength()); - CodecUtil.checkFooter(indexInput); // If footer is not valid, it would throw exception and test fails - indexInput.close(); - - IndexSearcher searcher = new IndexSearcher(reader); - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 2.5f}, 1, "myindex"))); - - reader.close(); - writer.close(); - dir.close(); - resourceWatcherService.close(); - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); + final Settings settings = Settings.builder().put("resource.reload.enabled", false).build(); + return new ResourceWatcherService(settings, null); } public void testMultiFieldsKnnIndex(Codec codec) throws Exception { @@ -146,7 +94,7 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { /** * Add doc with field "test_vector" */ - float[] array = {1.0f, 3.0f, 4.0f}; + float[] array = { 1.0f, 3.0f, 4.0f }; VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); Document doc = new Document(); @@ -161,7 +109,7 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { iwc1.setMergeScheduler(new SerialMergeScheduler()); iwc1.setCodec(new KNN87Codec()); writer = new RandomIndexWriter(random(), dir, iwc1); - float[] array1 = {6.0f, 14.0f}; + float[] array1 = { 6.0f, 14.0f }; VectorField vectorField1 = new VectorField("my_vector", array1, sampleFieldType); Document doc1 = new Document(); doc1.add(vectorField1); @@ -179,14 +127,14 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { // query to verify distance for each of the field IndexSearcher searcher = new IndexSearcher(reader); - float score = searcher.search(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"), 10).scoreDocs[0].score; - float score1 = searcher.search(new KNNQuery("my_vector", new float[] {1.0f, 2.0f}, 1, "dummy"), 10).scoreDocs[0].score; - assertEquals(1.0f/(1 + 25), score, 0.01f); - assertEquals(1.0f/(1 + 169), score1, 0.01f); + float score = searcher.search(new 
KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"), 10).scoreDocs[0].score; + float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy"), 10).scoreDocs[0].score; + assertEquals(1.0f / (1 + 25), score, 0.01f); + assertEquals(1.0f / (1 + 169), score1, 0.01f); // query to determine the hits - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"))); - assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] {1.0f, 1.0f}, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy"))); reader.close(); dir.close(); @@ -203,25 +151,33 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio // "Train" a faiss flat index - this really just creates an empty index that does brute force k-NN long vectorsPointer = JNIService.transferVectors(0, new float[0][0]); - byte [] modelBlob = JNIService.trainIndex(ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, "Flat", - SPACE_TYPE, spaceType.getValue()), dimension, vectorsPointer, - KNNEngine.FAISS.getName()); + byte[] modelBlob = JNIService.trainIndex( + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "Flat", SPACE_TYPE, spaceType.getValue()), + dimension, + vectorsPointer, + KNNEngine.FAISS.getName() + ); // Setup model cache ModelDao modelDao = mock(ModelDao.class); // Set model state to created - ModelMetadata modelMetadata1 = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata1 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); Model mockModel = new Model(modelMetadata1, modelBlob, modelId); when(modelDao.get(modelId)).thenReturn(mockModel); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata1); Settings settings = settings(CURRENT).put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10%").build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getSettings()).thenReturn(settings); @@ -242,12 +198,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio fieldType.freeze(); // Add the documents to the index - float[][] arrays = { - {1.0f, 3.0f, 4.0f}, - {2.0f, 5.0f, 8.0f}, - {3.0f, 6.0f, 9.0f}, - {4.0f, 7.0f, 10.0f} - }; + float[][] arrays = { { 1.0f, 3.0f, 4.0f }, { 2.0f, 5.0f, 8.0f }, { 3.0f, 6.0f, 9.0f }, { 4.0f, 7.0f, 10.0f } }; RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); String fieldName = "test_vector"; @@ -265,7 +216,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio KNNWeight.initialize(modelDao); ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); - float [] query = {10.0f, 10.0f, 10.0f}; + float[] query = { 10.0f, 10.0f, 10.0f }; IndexSearcher searcher = new IndexSearcher(reader); TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy"), 10); @@ -280,4 +231,3 @@ 
public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); } } - diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java new file mode 100644 index 000000000..de335a115 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -0,0 +1,421 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.FilterDirectory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.StringHelper; +import org.apache.lucene.util.Version; +import org.opensearch.common.collect.Set; +import org.opensearch.knn.index.KNNQueryResult; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat; +import static org.junit.Assert.assertTrue; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.test.OpenSearchTestCase.randomByteArrayOfLength; + +public class KNNCodecTestUtil { + + // Utility class to help build SegmentInfo with reasonable defaults + public static class SegmentInfoBuilder { + + private final Directory directory; + private final String segmentName; + private final int docsInSegment; + private final Codec codec; + + private Version version; + private Version minVersion; + private boolean isCompoundFile; + private byte[] segmentId; + private final Map attributes; + private Sort indexSort; + + public static SegmentInfoBuilder builder(Directory directory, String segmentName, int docsInSegment, Codec codec) { + return new SegmentInfoBuilder(directory, segmentName, docsInSegment, codec); + } + + private SegmentInfoBuilder(Directory directory, String segmentName, int docsInSegment, Codec codec) { + this.directory = directory; + this.segmentName = segmentName; + this.docsInSegment = docsInSegment; + this.codec = codec; + + this.version = Version.LATEST; + this.minVersion = Version.LATEST; + this.isCompoundFile = false; + this.segmentId = 
randomByteArrayOfLength(StringHelper.ID_LENGTH); + this.attributes = new HashMap<>(); + this.indexSort = Sort.INDEXORDER; + } + + public SegmentInfoBuilder version(Version version) { + this.version = version; + return this; + } + + public SegmentInfoBuilder minVersion(Version minVersion) { + this.minVersion = minVersion; + return this; + } + + public SegmentInfoBuilder isCompoundFile(boolean isCompoundFile) { + this.isCompoundFile = isCompoundFile; + return this; + } + + public SegmentInfoBuilder segmentId(byte[] segmentId) { + this.segmentId = segmentId; + return this; + } + + public SegmentInfoBuilder addAttribute(String key, String value) { + this.attributes.put(key, value); + return this; + } + + public SegmentInfoBuilder indexSort(Sort indexSort) { + this.indexSort = indexSort; + return this; + } + + public SegmentInfo build() { + return new SegmentInfo( + directory, + version, + minVersion, + segmentName, + docsInSegment, + isCompoundFile, + codec, + Collections.emptyMap(), + segmentId, + attributes, + indexSort + ); + } + } + + // Utility class to help build FieldInfo + public static class FieldInfoBuilder { + private final String fieldName; + private int fieldNumber; + private boolean storeTermVector; + private boolean omitNorms; + private boolean storePayloads; + private IndexOptions indexOptions; + private DocValuesType docValuesType; + private long dvGen; + private final Map attributes; + private int pointDimensionCount; + private int pointIndexDimensionCount; + private int pointNumBytes; + private boolean softDeletes; + + public static FieldInfoBuilder builder(String fieldName) { + return new FieldInfoBuilder(fieldName); + } + + private FieldInfoBuilder(String fieldName) { + this.fieldName = fieldName; + this.fieldNumber = 0; + this.storeTermVector = false; + this.omitNorms = true; + this.storePayloads = true; + this.indexOptions = IndexOptions.NONE; + this.docValuesType = DocValuesType.BINARY; + this.dvGen = 0; + this.attributes = new HashMap<>(); + this.pointDimensionCount = 0; + this.pointIndexDimensionCount = 0; + this.pointNumBytes = 0; + this.softDeletes = false; + } + + public FieldInfoBuilder fieldNumber(int fieldNumber) { + this.fieldNumber = fieldNumber; + return this; + } + + public FieldInfoBuilder storeTermVector(boolean storeTermVector) { + this.storeTermVector = storeTermVector; + return this; + } + + public FieldInfoBuilder omitNorms(boolean omitNorms) { + this.omitNorms = omitNorms; + return this; + } + + public FieldInfoBuilder storePayloads(boolean storePayloads) { + this.storePayloads = storePayloads; + return this; + } + + public FieldInfoBuilder indexOptions(IndexOptions indexOptions) { + this.indexOptions = indexOptions; + return this; + } + + public FieldInfoBuilder docValuesType(DocValuesType docValuesType) { + this.docValuesType = docValuesType; + return this; + } + + public FieldInfoBuilder dvGen(long dvGen) { + this.dvGen = dvGen; + return this; + } + + public FieldInfoBuilder addAttribute(String key, String value) { + this.attributes.put(key, value); + return this; + } + + public FieldInfoBuilder pointDimensionCount(int pointDimensionCount) { + this.pointDimensionCount = pointDimensionCount; + return this; + } + + public FieldInfoBuilder pointIndexDimensionCount(int pointIndexDimensionCount) { + this.pointIndexDimensionCount = pointIndexDimensionCount; + return this; + } + + public FieldInfoBuilder pointNumBytes(int pointNumBytes) { + this.pointNumBytes = pointNumBytes; + return this; + } + + public FieldInfoBuilder softDeletes(boolean softDeletes) 
{ + this.softDeletes = softDeletes; + return this; + } + + public FieldInfo build() { + return new FieldInfo( + fieldName, + fieldNumber, + storeTermVector, + omitNorms, + storePayloads, + indexOptions, + docValuesType, + dvGen, + attributes, + pointDimensionCount, + pointIndexDimensionCount, + pointNumBytes, + softDeletes + ); + } + } + + public static abstract class VectorDocValues extends BinaryDocValues { + + final int count; + final int dimension; + int current; + KNNVectorSerializer knnVectorSerializer; + + public VectorDocValues(int count, int dimension) { + this.count = count; + this.dimension = dimension; + this.current = -1; + this.knnVectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer(); + } + + @Override + public boolean advanceExact(int target) throws IOException { + return false; + } + + @Override + public int docID() { + if (this.current > this.count) { + return BinaryDocValues.NO_MORE_DOCS; + } + return this.current; + } + + @Override + public int nextDoc() throws IOException { + return advance(current + 1); + } + + @Override + public int advance(int target) throws IOException { + current = target; + if (current >= count) { + current = NO_MORE_DOCS; + } + return current; + } + + @Override + public long cost() { + return 0; + } + } + + public static class ConstantVectorBinaryDocValues extends VectorDocValues { + + private final BytesRef value; + + public ConstantVectorBinaryDocValues(int count, int dimension, float value) { + super(count, dimension); + float[] array = new float[dimension]; + Arrays.fill(array, value); + this.value = new BytesRef(knnVectorSerializer.floatToByteArray(array)); + } + + @Override + public BytesRef binaryValue() throws IOException { + return value; + } + } + + public static class RandomVectorBinaryDocValues extends VectorDocValues { + + public RandomVectorBinaryDocValues(int count, int dimension) { + super(count, dimension); + } + + @Override + public BytesRef binaryValue() throws IOException { + return new BytesRef(knnVectorSerializer.floatToByteArray(getRandomVector(dimension))); + } + } + + public static class RandomVectorDocValuesProducer extends DocValuesProducer { + + final RandomVectorBinaryDocValues randomBinaryDocValues; + + public RandomVectorDocValuesProducer(int count, int dimension) { + this.randomBinaryDocValues = new RandomVectorBinaryDocValues(count, dimension); + } + + @Override + public NumericDocValues getNumeric(FieldInfo field) { + return null; + } + + @Override + public BinaryDocValues getBinary(FieldInfo field) throws IOException { + return randomBinaryDocValues; + } + + @Override + public SortedDocValues getSorted(FieldInfo field) { + return null; + } + + @Override + public SortedNumericDocValues getSortedNumeric(FieldInfo field) { + return null; + } + + @Override + public SortedSetDocValues getSortedSet(FieldInfo field) { + return null; + } + + @Override + public void checkIntegrity() { + + } + + @Override + public void close() throws IOException { + + } + + @Override + public long ramBytesUsed() { + return 0; + } + } + + public static void assertFileInCorrectLocation(SegmentWriteState state, String expectedFile) throws IOException { + assertTrue(Set.of(state.directory.listAll()).contains(expectedFile)); + } + + public static void assertValidFooter(Directory dir, String filename) throws IOException { + ChecksumIndexInput indexInput = dir.openChecksumInput(filename, IOContext.DEFAULT); + indexInput.seek(indexInput.length() - CodecUtil.footerLength()); + CodecUtil.checkFooter(indexInput); + indexInput.close(); 
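+        // If the footer were corrupt, CodecUtil.checkFooter above would have thrown a CorruptIndexException and failed the calling test.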
+ } + + public static void assertLoadableByEngine( + SegmentWriteState state, + String fileName, + KNNEngine knnEngine, + SpaceType spaceType, + int dimension + ) { + String filePath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), fileName) + .toString(); + long indexPtr = JNIService.loadIndex( + filePath, + Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())), + knnEngine.getName() + ); + int k = 2; + float[] queryVector = new float[dimension]; + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName()); + assertTrue(results.length > 0); + JNIService.free(indexPtr, knnEngine.getName()); + } + + public static float[][] getRandomVectors(int count, int dimension) { + float[][] data = new float[count][dimension]; + for (int i = 0; i < count; i++) { + data[i] = getRandomVector(dimension); + } + return data; + } + + public static float[] getRandomVector(int dimension) { + float[] data = new float[dimension]; + for (int i = 0; i < dimension; i++) { + data[i] = randomFloat(); + } + return data; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java b/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java new file mode 100644 index 000000000..a2105af3a --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.codec.util; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.MergeState; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; + +import java.io.IOException; + +public class BinaryDocValuesSubTests extends KNNTestCase { + + public void testNextDoc() throws IOException { + BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 2.0f); + MergeState.DocMap docMap = new MergeState.DocMap() { + @Override + public int get(int docID) { + return docID; + } + }; + + BinaryDocValuesSub binaryDocValuesSub = new BinaryDocValuesSub(docMap, binaryDocValues); + int expectedNextDoc = binaryDocValues.nextDoc() + 1; + assertEquals(expectedNextDoc, binaryDocValuesSub.nextDoc()); + } + + public void testGetValues() { + BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 2.0f); + MergeState.DocMap docMap = new MergeState.DocMap() { + @Override + public int get(int docID) { + return docID; + } + }; + + BinaryDocValuesSub binaryDocValuesSub = new BinaryDocValuesSub(docMap, binaryDocValues); + + assertEquals(binaryDocValues, binaryDocValuesSub.getValues()); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerTests.java similarity index 82% rename from src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java rename to src/test/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerTests.java index 4bc62bebf..1d08df6a0 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerTests.java @@ -3,12 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.codec; +package org.opensearch.knn.index.codec.util; import org.opensearch.knn.KNNTestCase; -import 
org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.index.codec.util.SerializationMode; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -22,8 +19,8 @@ public class KNNVectorSerializerTests extends KNNTestCase { Random random = new Random(); public void testVectorSerializerFactory() throws Exception { - //check that default serializer can work with array of floats - //setup + // check that default serializer can work with array of floats + // setup final float[] vector = getArrayOfRandomFloats(20); final ByteArrayOutputStream bas = new ByteArrayOutputStream(); final DataOutputStream ds = new DataOutputStream(bas); @@ -40,19 +37,18 @@ public void testVectorSerializerFactory() throws Exception { assertNotNull(actualDeserializedVector); assertArrayEquals(vector, actualDeserializedVector, 0.1f); - final KNNVectorSerializer arraySerializer = - KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.ARRAY); + final KNNVectorSerializer arraySerializer = KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.ARRAY); assertNotNull(arraySerializer); - final KNNVectorSerializer collectionOfFloatsSerializer = - KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.COLLECTION_OF_FLOATS); + final KNNVectorSerializer collectionOfFloatsSerializer = KNNVectorSerializerFactory.getSerializerBySerializationMode( + SerializationMode.COLLECTION_OF_FLOATS + ); assertNotNull(collectionOfFloatsSerializer); } - public void testVectorSerializerFactory_throwExceptionForStreamWithUnsupportedDataType() throws Exception { - //prepare array of chars that is not supported by serializer factory. expected behavior is to fail - final char[] arrayOfChars = new char[] {'a', 'b', 'c'}; + // prepare array of chars that is not supported by serializer factory. 
expected behavior is to fail + final char[] arrayOfChars = new char[] { 'a', 'b', 'c' }; final ByteArrayOutputStream bas = new ByteArrayOutputStream(); final DataOutputStream ds = new DataOutputStream(bas); for (char ch : arrayOfChars) @@ -75,14 +71,14 @@ public void testVectorAsArraySerializer() throws Exception { final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(bais); - //testing serialization + // testing serialization bais.reset(); final byte[] actualSerializedVector = vectorSerializer.floatToByteArray(vector); assertNotNull(actualSerializedVector); assertArrayEquals(serializedVector, actualSerializedVector); - //testing deserialization + // testing deserialization bais.reset(); final float[] actualDeserializedVector = vectorSerializer.byteToFloatArray(bais); @@ -91,7 +87,7 @@ public void testVectorAsArraySerializer() throws Exception { } public void testVectorAsCollectionOfFloatsSerializer() throws Exception { - //setup + // setup final float[] vector = getArrayOfRandomFloats(20); final ByteArrayOutputStream bas = new ByteArrayOutputStream(); @@ -103,14 +99,14 @@ public void testVectorAsCollectionOfFloatsSerializer() throws Exception { final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(bais); - //testing serialization + // testing serialization bais.reset(); final byte[] actualSerializedVector = vectorSerializer.floatToByteArray(vector); assertNotNull(actualSerializedVector); assertArrayEquals(vectorAsCollectionOfFloats, actualSerializedVector); - //testing deserialization + // testing deserialization bais.reset(); final float[] actualDeserializedVector = vectorSerializer.byteToFloatArray(bais);
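+        // Deserializing the collection-of-floats layout should reproduce the original vector.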