diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f2595ee2b3..9e838792b3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -35,7 +35,7 @@ jobs: - name: Run build run: | - ./gradlew build -Dopensearch.version=2.0.0-SNAPSHOT + ./gradlew build -Dopensearch.version=2.0.0-alpha1-SNAPSHOT - name: Run k-NN Backwards Compatibility Tests run: | diff --git a/build.gradle b/build.gradle index 1c7e6175b7..dc65c8c87b 100644 --- a/build.gradle +++ b/build.gradle @@ -11,9 +11,9 @@ buildscript { ext { // build.version_qualifier parameter applies to knn plugin artifacts only. OpenSearch version must be set // explicitly as 'opensearch.version' property, for instance opensearch.version=2.0.0-alpha1-SNAPSHOT - opensearch_version = System.getProperty("opensearch.version", "2.0.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "2.0.0-alpha1-SNAPSHOT") knn_bwc_version = System.getProperty("bwc.version", "1.2.0.0-SNAPSHOT") - version_qualifier = System.getProperty("build.version_qualifier", "") + version_qualifier = System.getProperty("build.version_qualifier", "alpha1") opensearch_bwc_version = "${knn_bwc_version}" - ".0-SNAPSHOT" opensearch_group = "org.opensearch" } diff --git a/src/main/java/org/opensearch/knn/index/KNNQuery.java b/src/main/java/org/opensearch/knn/index/KNNQuery.java index f3bf23aeef..631b36d7a4 100644 --- a/src/main/java/org/opensearch/knn/index/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/KNNQuery.java @@ -7,6 +7,7 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; @@ -41,7 +42,9 @@ public int getK() { return this.k; } - public String getIndexName() { return this.indexName; } + public String getIndexName() { + return this.indexName; + } /** * Constructs Weight implementation for this query @@ -59,6 +62,11 @@ public 
Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo return new KNNWeight(this, boost); } + @Override + public void visit(QueryVisitor visitor) { + + } + @Override public String toString(String field) { return field; @@ -71,8 +79,7 @@ public int hashCode() { @Override public boolean equals(Object other) { - return sameClassAs(other) && - equalsTo(getClass().cast(other)); + return sameClassAs(other) && equalsTo(getClass().cast(other)); } private boolean equalsTo(KNNQuery other) { diff --git a/src/main/java/org/opensearch/knn/index/KNNWeight.java b/src/main/java/org/opensearch/knn/index/KNNWeight.java index 6c8c41cb99..7defb1eeb2 100644 --- a/src/main/java/org/opensearch/knn/index/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/KNNWeight.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index; -import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -19,7 +18,6 @@ import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; -import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Scorer; @@ -36,10 +34,8 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -77,114 +73,118 @@ public Explanation explain(LeafReaderContext context, int doc) { return Explanation.match(1.0f, "No Explanation"); } - @Override - public void extractTerms(Set terms) { - } - @Override public Scorer scorer(LeafReaderContext context) throws IOException { - SegmentReader reader = (SegmentReader) 
FilterLeafReader.unwrap(context.reader()); - String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); - - FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); - - if (fieldInfo == null) { - logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), - reader.getSegmentName()); - return null; - } - - KNNEngine knnEngine; - SpaceType spaceType; - - // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's - // metadata. - String modelId = fieldInfo.getAttribute(MODEL_ID); - if (modelId != null) { - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (modelMetadata == null) { - throw new RuntimeException("Model \"" + modelId + "\" does not exist."); - } - - knnEngine = modelMetadata.getKnnEngine(); - spaceType = modelMetadata.getSpaceType(); - } else { - String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); - knnEngine = KNNEngine.getEngine(engineName); - String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); - spaceType = SpaceType.getSpace(spaceTypeName); - } - - /* - * In case of compound file, extension would be + c otherwise - */ - String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() - ? 
knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION : knnEngine.getExtension(); - String engineSuffix = knnQuery.getField() + engineExtension; - List engineFiles = reader.getSegmentInfo().files().stream() - .filter(fileName -> fileName.endsWith(engineSuffix)) - .collect(Collectors.toList()); - - if(engineFiles.isEmpty()) { - logger.debug("[KNN] No engine index found for field {} for segment {}", - knnQuery.getField(), reader.getSegmentName()); - return null; - } - - Path indexPath = PathUtils.get(directory, engineFiles.get(0)); - final KNNQueryResult[] results; - KNNCounter.GRAPH_QUERY_REQUESTS.increment(); - - // We need to first get index allocation - NativeMemoryAllocation indexAllocation; - try { - indexAllocation = nativeMemoryCacheManager.get( - new NativeMemoryEntryContext.IndexEntryContext( - indexPath.toString(), - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), - knnQuery.getIndexName() - ), true); - } catch (ExecutionException e) { - GRAPH_QUERY_ERRORS.increment(); - throw new RuntimeException(e); - } - - // Now that we have the allocation, we need to readLock it - indexAllocation.readLock(); - - try { - if (indexAllocation.isClosed()) { - throw new RuntimeException("Index has already been closed"); - } - - results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName()); - } catch (Exception e) { - GRAPH_QUERY_ERRORS.increment(); - throw new RuntimeException(e); - } finally { - indexAllocation.readUnlock(); + SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); + String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); + + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + + if (fieldInfo == null) { + logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), 
reader.getSegmentName()); + return null; + } + + KNNEngine knnEngine; + SpaceType spaceType; + + // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's + // metadata. + String modelId = fieldInfo.getAttribute(MODEL_ID); + if (modelId != null) { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (modelMetadata == null) { + throw new RuntimeException("Model \"" + modelId + "\" does not exist."); } - /* - * Scores represent the distance of the documents with respect to given query vector. - * Lesser the score, the closer the document is to the query vector. - * Since by default results are retrieved in the descending order of scores, to get the nearest - * neighbors we are inverting the scores. - */ - if (results.length == 0) { - logger.debug("[KNN] Query yielded 0 results"); - return null; + knnEngine = modelMetadata.getKnnEngine(); + spaceType = modelMetadata.getSpaceType(); + } else { + String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); + knnEngine = KNNEngine.getEngine(engineName); + String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); + spaceType = SpaceType.getSpace(spaceTypeName); + } + + /* + * In case of compound file, extension would be + c otherwise + */ + String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() + ? 
knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION + : knnEngine.getExtension(); + String engineSuffix = knnQuery.getField() + engineExtension; + List engineFiles = reader.getSegmentInfo() + .files() + .stream() + .filter(fileName -> fileName.endsWith(engineSuffix)) + .collect(Collectors.toList()); + + if (engineFiles.isEmpty()) { + logger.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); + return null; + } + + Path indexPath = PathUtils.get(directory, engineFiles.get(0)); + final KNNQueryResult[] results; + KNNCounter.GRAPH_QUERY_REQUESTS.increment(); + + // We need to first get index allocation + NativeMemoryAllocation indexAllocation; + try { + indexAllocation = nativeMemoryCacheManager.get( + new NativeMemoryEntryContext.IndexEntryContext( + indexPath.toString(), + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), + knnQuery.getIndexName() + ), + true + ); + } catch (ExecutionException e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } + + // Now that we have the allocation, we need to readLock it + indexAllocation.readLock(); + + try { + if (indexAllocation.isClosed()) { + throw new RuntimeException("Index has already been closed"); } - Map scores = Arrays.stream(results).collect( - Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); - int maxDoc = Collections.max(scores.keySet()) + 1; - DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); - DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc); - Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); - DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); - return new KNNScorer(this, docIdSetIter, scores, boost); + results = JNIService.queryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getQueryVector(), + knnQuery.getK(), + 
knnEngine.getName() + ); + } catch (Exception e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } finally { + indexAllocation.readUnlock(); + } + + /* + * Scores represent the distance of the documents with respect to given query vector. + * Lesser the score, the closer the document is to the query vector. + * Since by default results are retrieved in the descending order of scores, to get the nearest + * neighbors we are inverting the scores. + */ + if (results.length == 0) { + logger.debug("[KNN] Query yielded 0 results"); + return null; + } + + Map scores = Arrays.stream(results) + .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); + int maxDoc = Collections.max(scores.keySet()) + 1; + DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); + DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc); + Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); + DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); + return new KNNScorer(this, docIdSetIter, scores, boost); } @Override @@ -193,9 +193,7 @@ public boolean isCacheable(LeafReaderContext context) { } public static float normalizeScore(float score) { - if (score >= 0) - return 1 / (1 + score); + if (score >= 0) return 1 / (1 + score); return -score + 1; } } - diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java index 59655762e4..0064f49fef 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java @@ -11,6 +11,7 @@ import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FieldInfosFormat; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.LiveDocsFormat; import 
org.apache.lucene.codecs.NormsFormat; import org.apache.lucene.codecs.PointsFormat; @@ -52,8 +53,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene80 Codec. */ public Codec getDelegatee() { - if (lucene80Codec == null) - lucene80Codec = Codec.forName(LUCENE_80); + if (lucene80Codec == null) lucene80Codec = Codec.forName(LUCENE_80); return lucene80Codec; } @@ -112,4 +112,9 @@ public CompoundFormat compoundFormat() { public PointsFormat pointsFormat() { return getDelegatee().pointsFormat(); } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + throw new UnsupportedOperationException("Codec does not support knn vector format"); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java index 8a8ed558e1..d001cd1baa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java @@ -5,7 +5,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import org.apache.lucene.codecs.lucene50.Lucene50CompoundFormat; +import org.apache.lucene.backward_codecs.lucene50.Lucene50CompoundFormat; import org.opensearch.knn.common.KNNConstants; import org.apache.lucene.codecs.CompoundDirectory; import org.apache.lucene.codecs.CompoundFormat; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java index cd8362cf21..fe329eb1c8 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java @@ -8,7 +8,7 @@ import org.apache.lucene.codecs.DocValuesConsumer; import org.apache.lucene.codecs.DocValuesFormat; import 
org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.codecs.lucene80.Lucene80DocValuesFormat; +import org.apache.lucene.backward_codecs.lucene80.Lucene80DocValuesFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java index a50f396a48..b55365c094 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.KNN84Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.apache.logging.log4j.LogManager; @@ -42,7 +43,7 @@ public KNN84Codec() { super(KNN_84); // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); + this.docValuesFormat = new KNN80DocValuesFormat(); this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { @Override public DocValuesFormat getDocValuesFormatForField(String field) { @@ -56,8 +57,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene84 Codec. 
*/ public Codec getDelegatee() { - if (lucene84Codec == null) - lucene84Codec = Codec.forName(LUCENE_84); + if (lucene84Codec == null) lucene84Codec = Codec.forName(LUCENE_84); return lucene84Codec; } @@ -116,4 +116,9 @@ public CompoundFormat compoundFormat() { public PointsFormat pointsFormat() { return getDelegatee().pointsFormat(); } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return getDelegatee().knnVectorsFormat(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java index 70c75e09b5..a3b34559a5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.KNN86Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.apache.logging.log4j.LogManager; @@ -43,7 +44,7 @@ public KNN86Codec() { super(KNN_86); // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); + this.docValuesFormat = new KNN80DocValuesFormat(); this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { @Override public DocValuesFormat getDocValuesFormatForField(String field) { @@ -57,8 +58,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene84 Codec. 
*/ public Codec getDelegatee() { - if (lucene86Codec == null) - lucene86Codec = Codec.forName(LUCENE_86); + if (lucene86Codec == null) lucene86Codec = Codec.forName(LUCENE_86); return lucene86Codec; } @@ -73,7 +73,6 @@ public DocValuesFormat docValuesFormat() { * approach of manually overriding. */ - public void setPostingsFormat(PostingsFormat postingsFormat) { this.postingsFormat = postingsFormat; } @@ -125,4 +124,9 @@ public CompoundFormat compoundFormat() { public PointsFormat pointsFormat() { return getDelegatee().pointsFormat(); } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return getDelegatee().knnVectorsFormat(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java index 6ec5ec05c1..20799648c1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index.codec.KNN87Codec; import org.apache.lucene.codecs.FilterCodec; -import org.apache.lucene.codecs.lucene87.Lucene87Codec; +import org.apache.lucene.backward_codecs.lucene87.Lucene87Codec; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.apache.lucene.codecs.Codec; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java new file mode 100644 index 0000000000..0acaccfbf3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.codec.KNN910Codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundFormat; +import 
org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.FilterCodec; +import org.opensearch.knn.index.codec.KNNFormatFacade; +import org.opensearch.knn.index.codec.KNNFormatFactory; + +import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate; + +/** + * Extends the Codec to support a new file format for KNN index + * based on the mappings. + * + */ +public final class KNN910Codec extends FilterCodec { + + private static final String KNN910 = "KNN910Codec"; + private final KNNFormatFacade knnFormatFacade; + + /** + * No arg constructor that uses Lucene91 as the delegate + */ + public KNN910Codec() { + this(createKNN91DefaultDelegate()); + } + + /** + * Constructor that takes a Codec delegate to delegate all methods this codec does not implement to. + * + * @param delegate codec that will perform all operations this codec does not override + */ + public KNN910Codec(Codec delegate) { + super(KNN910, delegate); + knnFormatFacade = KNNFormatFactory.createKNN910Format(delegate); + } + + @Override + public DocValuesFormat docValuesFormat() { + return knnFormatFacade.docValuesFormat(); + } + + @Override + public CompoundFormat compoundFormat() { + return knnFormatFacade.compoundFormat(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java new file mode 100644 index 0000000000..9662e9edbf --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.codec; + +import com.google.common.collect.ImmutableMap; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene91.Lucene91Codec; +import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; + +import java.lang.reflect.Constructor; +import java.util.Map; + +/** + * Factory 
abstraction for KNN codec + */ +public class KNNCodecFactory { + + private static Map CODEC_BY_VERSION = ImmutableMap.of(KNNCodecVersion.KNN910, KNN910Codec.class); + + private static KNNCodecVersion LATEST_KNN_CODEC_VERSION = KNNCodecVersion.KNN910; + + public static Codec createKNNCodec(final Codec userCodec) { + return getCodec(LATEST_KNN_CODEC_VERSION, userCodec); + } + + public static Codec createKNNCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) { + return getCodec(knnCodecVersion, userCodec); + } + + private static Codec getCodec(final KNNCodecVersion knnCodecVersion, final Codec userCodec) { + try { + Constructor constructor = CODEC_BY_VERSION.getOrDefault(knnCodecVersion, CODEC_BY_VERSION.get(LATEST_KNN_CODEC_VERSION)) + .getConstructor(Codec.class); + return (Codec) constructor.newInstance(userCodec); + } catch (Exception ex) { + throw new RuntimeException("Cannot create instance of KNN codec", ex); + } + } + + /** + * Factory abstraction for codec delegate + */ + public static class CodecDelegateFactory { + + public static Codec createKNN91DefaultDelegate() { + return new Lucene91Codec(); + } + } + + /** + * Collection of supported codec versions + */ + enum KNNCodecVersion { + KNN910 + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java index 70991dfe8d..cae6f7fb82 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -6,7 +6,6 @@ package org.opensearch.knn.index.codec; import org.opensearch.index.codec.CodecServiceConfig; -import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; import org.apache.lucene.codecs.Codec; import org.opensearch.index.codec.CodecService; @@ -27,6 +26,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) { */ @Override public Codec codec(String name) { - return new 
KNN87Codec(super.codec(name)); + return KNNCodecFactory.createKNNCodec(super.codec(name)); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNFormatFacade.java b/src/main/java/org/opensearch/knn/index/codec/KNNFormatFacade.java new file mode 100644 index 0000000000..bf9b1ad7f1 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNFormatFacade.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.codec; + +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.codecs.DocValuesFormat; + +/** + * Class abstracts facade for plugin formats. + */ +public class KNNFormatFacade { + + private final DocValuesFormat docValuesFormat; + private final CompoundFormat compoundFormat; + + public KNNFormatFacade(final DocValuesFormat docValuesFormat, final CompoundFormat compoundFormat) { + this.docValuesFormat = docValuesFormat; + this.compoundFormat = compoundFormat; + } + + public DocValuesFormat docValuesFormat() { + return docValuesFormat; + } + + public CompoundFormat compoundFormat() { + return compoundFormat; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java b/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java new file mode 100644 index 0000000000..6742dc3f42 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.codec; + +import org.apache.lucene.codecs.Codec; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; + +/** + * Factory abstraction for KNN format facade creation + */ +public class KNNFormatFactory { + + public static KNNFormatFacade createKNN910Format(final Codec delegate) { + final KNNFormatFacade knnFormatFacade = new 
KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ); + return knnFormatFacade; + } +} diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 385b30225c..937e5f811d 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -40,7 +40,6 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.transport.DeleteModelResponse; @@ -199,8 +198,7 @@ public void create(ActionListener actionListener) throws IO if (isCreated()) { return; } - - CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping("_doc", getMapping(), XContentType.JSON) + CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping()) .settings( Settings.builder() .put("index.hidden", true) diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec index 8e64afa086..b897dc36aa 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -1,4 +1,5 @@ org.opensearch.knn.index.codec.KNN80Codec.KNN80Codec org.opensearch.knn.index.codec.KNN84Codec.KNN84Codec org.opensearch.knn.index.codec.KNN86Codec.KNN86Codec -org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec \ No newline at end of file +org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec +org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec \ No newline at end of file diff --git 
a/src/test/java/org/opensearch/knn/TestUtils.java b/src/test/java/org/opensearch/knn/TestUtils.java index f4968a2558..03b50280f1 100644 --- a/src/test/java/org/opensearch/knn/TestUtils.java +++ b/src/test/java/org/opensearch/knn/TestUtils.java @@ -21,7 +21,6 @@ import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.plugin.script.KNNScoringUtil; -import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier; import java.util.Comparator; import java.util.Random; import java.util.Set; @@ -30,13 +29,14 @@ import java.util.List; import java.util.HashSet; import java.util.Map; -import static org.apache.lucene.util.LuceneTestCase.random; + +import static org.apache.lucene.tests.util.LuceneTestCase.random; class DistVector { public float dist; public String docID; - public DistVector (float dist, String docID) { + public DistVector(float dist, String docID) { this.dist = dist; this.docID = docID; } @@ -117,10 +117,10 @@ public static List> computeGroundTruthValues(float[][] indexVectors, } if (pq.size() < k) { - pq.add(new DistVector(dist, String.valueOf(j+1))); + pq.add(new DistVector(dist, String.valueOf(j + 1))); } else if (pq.peek().getDist() > dist) { pq.poll(); - pq.add(new DistVector(dist, String.valueOf(j+1))); + pq.add(new DistVector(dist, String.valueOf(j + 1))); } } @@ -137,7 +137,7 @@ public static List> computeGroundTruthValues(float[][] indexVectors, public static float[][] getQueryVectors(int queryCount, int dimensions, int docCount, boolean isStandard) { if (isStandard) { - return randomlyGenerateStandardVectors(queryCount, dimensions, docCount+1); + return randomlyGenerateStandardVectors(queryCount, dimensions, docCount + 1); } else { return generateRandomVectors(queryCount, dimensions); } @@ -169,8 +169,8 @@ public static double calculateRecallValue(List> searchResults, List recalls.add(recallVal / k); } - double sum = recalls.stream().reduce((a,b)->a+b).get(); - return sum/recalls.size(); + double sum 
= recalls.stream().reduce((a, b) -> a + b).get(); + return sum / recalls.size(); } /** @@ -192,14 +192,15 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { BufferedReader reader = new BufferedReader(new FileReader(path)); String line = reader.readLine(); while (line != null) { - Map doc = XContentFactory.xContent(XContentType.JSON).createParser( - NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, line).map(); + Map doc = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, line) + .map(); idsList.add((Integer) doc.get("id")); @SuppressWarnings("unchecked") ArrayList vector = (ArrayList) doc.get("vector"); Float[] floatArray = new Float[vector.size()]; - for (int i =0; i< vector.size(); i++) { + for (int i = 0; i < vector.size(); i++) { floatArray[i] = vector.get(i).floatValue(); } vectorsList.add(floatArray); @@ -208,7 +209,7 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { } reader.close(); - int[] idsArray = new int [idsList.size()]; + int[] idsArray = new int[idsList.size()]; float[][] vectorsArray = new float[vectorsList.size()][vectorsList.get(0).length]; for (int i = 0; i < idsList.size(); i++) { idsArray[i] = idsList.get(i); diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index e1feb9f180..e0b2c05cfe 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -14,7 +14,9 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.action.ActionListener; import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; +import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; +import 
org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.util.KNNEngine; @@ -45,14 +47,23 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException // "Train" a faiss flat index - this really just creates an empty index that does brute force k-NN long vectorsPointer = JNIService.transferVectors(0, new float[0][0]); - byte [] modelBlob = JNIService.trainIndex(ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, "Flat", - SPACE_TYPE, spaceType.getValue()), dimension, vectorsPointer, - KNNEngine.FAISS.getName()); + byte[] modelBlob = JNIService.trainIndex( + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "Flat", SPACE_TYPE, spaceType.getValue()), + dimension, + vectorsPointer, + KNNEngine.FAISS.getName() + ); // Setup model - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -64,34 +75,31 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException String indexName = "test-index"; String fieldName = "test-field"; + final String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("model_id", modelId) + .endObject() + .endObject() + .endObject() + ); + modelDao.put(model, ActionListener.wrap(indexResponse -> { - CreateIndexRequestBuilder createIndexRequestBuilder = client().admin().indices().prepareCreate(indexName) - .setSettings(Settings.builder() - .put("number_of_shards", 1) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build() - 
).addMapping( - "_doc", ImmutableMap.of( - "properties", ImmutableMap.of( - fieldName, ImmutableMap.of( - "type", "knn_vector", - "model_id", modelId - ) - ) - ) - ); - - client().admin().indices().create(createIndexRequestBuilder.request(), - ActionListener.wrap( - createIndexResponse -> { - assertTrue(createIndexResponse.isAcknowledged()); - inProgressLatch.countDown(); - }, e -> fail("Unable to create index: " + e.getMessage()) - ) - ); - - }, e ->fail("Unable to put model: " + e.getMessage()))); + CreateIndexRequestBuilder createIndexRequestBuilder = client().admin() + .indices() + .prepareCreate(indexName) + .setSettings(Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", true).build()) + .setMapping(mapping); + + client().admin().indices().create(createIndexRequestBuilder.request(), ActionListener.wrap(createIndexResponse -> { + assertTrue(createIndexResponse.isAcknowledged()); + inProgressLatch.countDown(); + }, e -> fail("Unable to create index: " + e.getMessage()))); + + }, e -> fail("Unable to put model: " + e.getMessage()))); assertTrue(inProgressLatch.await(20, TimeUnit.SECONDS)); } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index 87d00e5547..8bda1aefc7 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index; import org.opensearch.knn.KNNTestCase; -import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; @@ -43,9 +43,11 @@ private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriter writer = new IndexWriter(directory, 
conf); Document knnDocument = new Document(); knnDocument.add( - new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, new float[]{1.0f, 2.0f}, new FieldType()).binaryValue())); + new BinaryDocValuesField( + MOCK_INDEX_FIELD_NAME, + new VectorField(MOCK_INDEX_FIELD_NAME, new float[] { 1.0f, 2.0f }, new FieldType()).binaryValue() + ) + ); knnDocument.add(new NumericDocValuesField(MOCK_NUMERIC_INDEX_FIELD_NAME, 1000)); writer.addDocument(knnDocument); writer.commit(); @@ -67,16 +69,14 @@ public void testGetScriptValues() { } public void testGetScriptValuesWrongFieldName() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( - leafReaderContext.reader(), "invalid"); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid"); ScriptDocValues scriptValues = leafFieldData.getScriptValues(); assertNotNull(scriptValues); } public void testGetScriptValuesWrongFieldType() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( - leafReaderContext.reader(), MOCK_NUMERIC_INDEX_FIELD_NAME); - expectThrows(IllegalStateException.class, ()->leafFieldData.getScriptValues()); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), MOCK_NUMERIC_INDEX_FIELD_NAME); + expectThrows(IllegalStateException.class, () -> leafFieldData.getScriptValues()); } public void testRamBytesUsed() { @@ -86,7 +86,6 @@ public void testRamBytesUsed() { public void testGetBytesValues() { KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), ""); - expectThrows(UnsupportedOperationException.class, - () -> leafFieldData.getBytesValues()); + expectThrows(UnsupportedOperationException.class, () -> leafFieldData.getBytesValues()); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java 
index 54435db992..8523c4146f 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index; import org.opensearch.knn.KNNTestCase; -import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; @@ -72,13 +72,14 @@ public void testLoadDirect() throws IOException { public void testSortField() { - expectThrows(UnsupportedOperationException.class, - () -> indexFieldData.sortField(null, null, null, false)); + expectThrows(UnsupportedOperationException.class, () -> indexFieldData.sortField(null, null, null, false)); } public void testNewBucketedSort() { - expectThrows(UnsupportedOperationException.class, - () -> indexFieldData.newBucketedSort(null, null, null, null, null, null, 0, null)); + expectThrows( + UnsupportedOperationException.class, + () -> indexFieldData.newBucketedSort(null, null, null, null, null, null, 0, null) + ); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 8883bf4ddd..8761179409 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index; import org.opensearch.knn.KNNTestCase; -import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; @@ -23,7 +23,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME 
= "test-index-field-name"; - private static final float[] SAMPLE_VECTOR_DATA = new float[]{1.0f, 2.0f}; + private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; private KNNVectorScriptDocValues scriptDocValues; private Directory directory; private DirectoryReader reader; @@ -36,7 +36,9 @@ public void setUp() throws Exception { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); scriptDocValues = new KNNVectorScriptDocValues( - leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), MOCK_INDEX_FIELD_NAME); + leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME + ); } private void createKNNVectorDocument(Directory directory) throws IOException { @@ -44,9 +46,11 @@ private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); knnDocument.add( - new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue())); + new BinaryDocValuesField( + MOCK_INDEX_FIELD_NAME, + new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() + ) + ); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -64,8 +68,7 @@ public void testGetValue() throws IOException { Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); } - - //Test getValue without calling setNextDocId + // Test getValue without calling setNextDocId public void testGetValueFails() throws IOException { expectThrows(IllegalStateException.class, () -> scriptDocValues.getValue()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java index b7f909ec11..ff93fd3a8d 100644 --- 
a/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87CodecTests.java @@ -7,16 +7,9 @@ import org.opensearch.knn.index.codec.KNNCodecTestCase; -import java.io.IOException; -import java.util.concurrent.ExecutionException; - public class KNN87CodecTests extends KNNCodecTestCase { - public void testMultiFieldsKnnIndex() throws Exception { - testMultiFieldsKnnIndex(new KNN87Codec()); - } - - public void testBuildFromModelTemplate() throws InterruptedException, ExecutionException, IOException { - testBuildFromModelTemplate(new KNN87Codec()); + public void testWriteByOldCodec() throws Exception { + testWriteByOldCodec(new KNN87Codec()); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910CodecTests.java new file mode 100644 index 0000000000..1ec28d6a44 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910CodecTests.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN910Codec; + +import org.opensearch.knn.index.codec.KNNCodecTestCase; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +public class KNN910CodecTests extends KNNCodecTestCase { + + public void testMultiFieldsKnnIndex() throws Exception { + testMultiFieldsKnnIndex(new KNN910Codec()); + } + + public void testBuildFromModelTemplate() throws InterruptedException, ExecutionException, IOException { + testBuildFromModelTemplate(new KNN910Codec()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java new file mode 100644 index 0000000000..c209312532 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java @@ -0,0 
+1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene91.Lucene91Codec; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; + +public class KNNCodecFactoryTests extends KNNTestCase { + + public void testKNN91DefaultDelegate() { + Codec knn91DefaultDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate(); + assertNotNull(knn91DefaultDelegate); + assertTrue(knn91DefaultDelegate instanceof Lucene91Codec); + } + + public void testKNN91DefaultCodec() { + Lucene91Codec lucene91CodecDelegate = new Lucene91Codec(); + Codec knnCodec = KNNCodecFactory.createKNNCodec(lucene91CodecDelegate); + assertNotNull(knnCodec); + assertTrue(knnCodec instanceof KNN910Codec); + } + + public void testKNN91CodecByVersion() { + Lucene91Codec lucene91CodecDelegate = new Lucene91Codec(); + Codec knnCodec = KNNCodecFactory.createKNNCodec(KNNCodecFactory.KNNCodecVersion.KNN910, lucene91CodecDelegate); + assertNotNull(knnCodec); + assertTrue(knnCodec instanceof KNN910Codec); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 3454c7d524..2251308a08 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -11,6 +11,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.KNNQuery; import org.opensearch.knn.index.KNNSettings; @@ -18,13 +19,12 @@ import org.opensearch.knn.index.KNNWeight; import org.opensearch.knn.index.SpaceType; import 
org.opensearch.knn.index.VectorField; -import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; import org.apache.lucene.codecs.Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.store.Directory; @@ -61,6 +61,7 @@ */ public class KNNCodecTestCase extends KNNTestCase { + private static final KNN910Codec ACTUAL_CODEC = new KNN910Codec(); private static FieldType sampleFieldType; static { sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); @@ -107,7 +108,7 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { */ IndexWriterConfig iwc1 = newIndexWriterConfig(); iwc1.setMergeScheduler(new SerialMergeScheduler()); - iwc1.setCodec(new KNN87Codec()); + iwc1.setCodec(ACTUAL_CODEC); writer = new RandomIndexWriter(random(), dir, iwc1); float[] array1 = { 6.0f, 14.0f }; VectorField vectorField1 = new VectorField("my_vector", array1, sampleFieldType); @@ -230,4 +231,26 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio resourceWatcherService.close(); NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); } + + public void testWriteByOldCodec(Codec codec) throws IOException { + setUpMockClusterService(); + Directory dir = newFSDirectory(createTempDir()); + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergeScheduler(new SerialMergeScheduler()); + iwc.setCodec(codec); + + /** + * Add doc with field "test_vector", expect it to fail + */ + float[] array = { 1.0f, 3.0f, 4.0f }; + VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); + try (RandomIndexWriter writer = new 
RandomIndexWriter(random(), dir, iwc)) { + Document doc = new Document(); + doc.add(vectorField); + expectThrows(UnsupportedOperationException.class, () -> writer.addDocument(doc)); + } + + dir.close(); + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index de335a115e..153fd2faf1 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -20,6 +20,7 @@ import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Sort; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; @@ -145,6 +146,8 @@ public static class FieldInfoBuilder { private int pointDimensionCount; private int pointIndexDimensionCount; private int pointNumBytes; + private int vectorDimension; + private VectorSimilarityFunction vectorSimilarityFunction; private boolean softDeletes; public static FieldInfoBuilder builder(String fieldName) { @@ -164,6 +167,8 @@ private FieldInfoBuilder(String fieldName) { this.pointDimensionCount = 0; this.pointIndexDimensionCount = 0; this.pointNumBytes = 0; + this.vectorDimension = 0; + this.vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN; this.softDeletes = false; } @@ -222,6 +227,16 @@ public FieldInfoBuilder pointNumBytes(int pointNumBytes) { return this; } + public FieldInfoBuilder vectorDimension(int vectorDimension) { + this.vectorDimension = vectorDimension; + return this; + } + + public FieldInfoBuilder vectorSimilarityFunction(VectorSimilarityFunction vectorSimilarityFunction) { + this.vectorSimilarityFunction = vectorSimilarityFunction; + return this; + } + 
public FieldInfoBuilder softDeletes(boolean softDeletes) { this.softDeletes = softDeletes; return this; @@ -241,6 +256,8 @@ public FieldInfo build() { pointDimensionCount, pointIndexDimensionCount, pointNumBytes, + vectorDimension, + vectorSimilarityFunction, softDeletes ); } @@ -364,11 +381,6 @@ public void checkIntegrity() { public void close() throws IOException { } - - @Override - public long ramBytesUsed() { - return 0; - } } public static void assertFileInCorrectLocation(SegmentWriteState state, String expectedFile) throws IOException { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java new file mode 100644 index 0000000000..545e2a1419 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import org.apache.lucene.codecs.Codec; +import org.opensearch.knn.KNNTestCase; + +public class KNNFormatFactoryTests extends KNNTestCase { + + public void testKNN91Format() { + final Codec lucene91CodecDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate(); + final Codec knnCodec = KNNCodecFactory.createKNNCodec(lucene91CodecDelegate); + KNNFormatFacade knnFormatFacade = KNNFormatFactory.createKNN910Format(knnCodec); + + assertNotNull(knnFormatFacade); + assertNotNull(knnFormatFacade.compoundFormat()); + assertNotNull(knnFormatFacade.docValuesFormat()); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 08cede77cb..7f66e909a8 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -8,7 +8,7 @@ import org.opensearch.knn.KNNTestCase; 
import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.opensearch.knn.index.VectorField; -import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; @@ -35,36 +35,35 @@ private List getTestQueryVector() { } public void testL2SquaredScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; Float distance = KNNScoringUtil.l2Squared(queryVector, inputVector); assertTrue(distance == 27.0f); } public void testWrongDimensionL2SquaredScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.l2Squared(queryVector, inputVector)); } public void testCosineSimilScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; float queryVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector); float inputVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(inputVector); float dotProduct = 12.0f; float expectedScore = (float) (dotProduct / (Math.sqrt(queryVectorMagnitude * inputVectorMagnitude))); - Float actualScore = KNNScoringUtil.cosinesimil(queryVector, inputVector); assertEquals(expectedScore, actualScore, 0.0001); } public void testCosineSimilOptimizedScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; float queryVectorMagnitude = 
KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector); float inputVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(inputVector); float dotProduct = 12.0f; @@ -86,26 +85,26 @@ public void testConvertInvalidVectorToPrimitive() { } public void testCosineSimilQueryVectorZeroMagnitude() { - float[] queryVector = {0, 0}; - float[] inputVector = {4.0f, 4.0f}; + float[] queryVector = { 0, 0 }; + float[] inputVector = { 4.0f, 4.0f }; assertEquals(0, KNNScoringUtil.cosinesimil(queryVector, inputVector), 0.00001); } public void testCosineSimilOptimizedQueryVectorZeroMagnitude() { - float[] inputVector = {4.0f, 4.0f}; - float[] queryVector = {0, 0}; + float[] inputVector = { 4.0f, 4.0f }; + float[] queryVector = { 0, 0 }; assertTrue(0 == KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 0.0f)); } public void testWrongDimensionCosineSimilScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimil(queryVector, inputVector)); } public void testWrongDimensionCosineSimilOPtimizedScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 1.0f)); } @@ -173,7 +172,7 @@ public void testBitHammingDistance_Long() { public void testL2SquaredWhitelistedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = 
dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); Float distance = KNNScoringUtil.l2Squared(queryVector, scriptDocValues); @@ -184,7 +183,7 @@ public void testL2SquaredWhitelistedScoringFunction() throws IOException { public void testScriptDocValuesFailsL2() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.l2Squared(queryVector, scriptDocValues)); dataset.close(); @@ -193,7 +192,7 @@ public void testScriptDocValuesFailsL2() throws IOException { public void testCosineSimilarityScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); @@ -205,7 +204,7 @@ public void testCosineSimilarityScoringFunction() throws IOException { public void testScriptDocValuesFailsCosineSimilarity() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> 
KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues)); dataset.close(); @@ -214,7 +213,7 @@ public void testScriptDocValuesFailsCosineSimilarity() throws IOException { public void testCosineSimilarityOptimizedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); Float actualScore = KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f); @@ -225,7 +224,7 @@ public void testCosineSimilarityOptimizedScoringFunction() throws IOException { public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f)); dataset.close(); @@ -244,16 +243,14 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues(leafReaderContext.reader().getBinaryDocValues(fieldName),fieldName ); + scriptDocValues = new KNNVectorScriptDocValues(leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName); } return scriptDocValues; } public void 
close() throws IOException { - if (reader != null) - reader.close(); - if (directory != null) - directory.close(); + if (reader != null) reader.close(); + if (directory != null) directory.close(); } public void createKNNVectorDocument(final float[] content, final String fieldName) throws IOException { @@ -261,10 +258,7 @@ public void createKNNVectorDocument(final float[] content, final String fieldNam IndexWriter writer = new IndexWriter(directory, conf); conf.setMergePolicy(NoMergePolicy.INSTANCE); // prevent merges for this test Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - fieldName, - new VectorField(fieldName, content, new FieldType()).binaryValue())); + knnDocument.add(new BinaryDocValuesField(fieldName, new VectorField(fieldName, content, new FieldType()).binaryValue())); writer.addDocument(knnDocument); writer.commit(); writer.close();