diff --git a/src/main/java/org/opensearch/knn/index/KNNQuery.java b/src/main/java/org/opensearch/knn/index/KNNQuery.java index 76709ae7e2..631b36d7a4 100644 --- a/src/main/java/org/opensearch/knn/index/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/KNNQuery.java @@ -42,7 +42,9 @@ public int getK() { return this.k; } - public String getIndexName() { return this.indexName; } + public String getIndexName() { + return this.indexName; + } /** * Constructs Weight implementation for this query @@ -77,8 +79,7 @@ public int hashCode() { @Override public boolean equals(Object other) { - return sameClassAs(other) && - equalsTo(getClass().cast(other)); + return sameClassAs(other) && equalsTo(getClass().cast(other)); } private boolean equalsTo(KNNQuery other) { diff --git a/src/main/java/org/opensearch/knn/index/KNNWeight.java b/src/main/java/org/opensearch/knn/index/KNNWeight.java index 3f777c9703..7defb1eeb2 100644 --- a/src/main/java/org/opensearch/knn/index/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/KNNWeight.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index; -import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -19,7 +18,6 @@ import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; -import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Scorer; @@ -36,10 +34,8 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -79,108 +75,116 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); - String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); - - FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); - - if (fieldInfo == null) { - logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), - reader.getSegmentName()); - return null; - } - - KNNEngine knnEngine; - SpaceType spaceType; - - // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's - // metadata. - String modelId = fieldInfo.getAttribute(MODEL_ID); - if (modelId != null) { - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (modelMetadata == null) { - throw new RuntimeException("Model \"" + modelId + "\" does not exist."); - } - - knnEngine = modelMetadata.getKnnEngine(); - spaceType = modelMetadata.getSpaceType(); - } else { - String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); - knnEngine = KNNEngine.getEngine(engineName); - String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); - spaceType = SpaceType.getSpace(spaceTypeName); + SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); + String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); + + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + + if (fieldInfo == null) { + logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); + return null; + } + + KNNEngine knnEngine; + SpaceType spaceType; + + // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's + // metadata. + String modelId = fieldInfo.getAttribute(MODEL_ID); + if (modelId != null) { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (modelMetadata == null) { + throw new RuntimeException("Model \"" + modelId + "\" does not exist."); } - /* - * In case of compound file, extension would be + c otherwise - */ - String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() - ? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION : knnEngine.getExtension(); - String engineSuffix = knnQuery.getField() + engineExtension; - List engineFiles = reader.getSegmentInfo().files().stream() - .filter(fileName -> fileName.endsWith(engineSuffix)) - .collect(Collectors.toList()); - - if(engineFiles.isEmpty()) { - logger.debug("[KNN] No engine index found for field {} for segment {}", - knnQuery.getField(), reader.getSegmentName()); - return null; + knnEngine = modelMetadata.getKnnEngine(); + spaceType = modelMetadata.getSpaceType(); + } else { + String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); + knnEngine = KNNEngine.getEngine(engineName); + String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); + spaceType = SpaceType.getSpace(spaceTypeName); + } + + /* + * In case of compound file, extension would be + c otherwise + */ + String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() + ? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION + : knnEngine.getExtension(); + String engineSuffix = knnQuery.getField() + engineExtension; + List engineFiles = reader.getSegmentInfo() + .files() + .stream() + .filter(fileName -> fileName.endsWith(engineSuffix)) + .collect(Collectors.toList()); + + if (engineFiles.isEmpty()) { + logger.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); + return null; + } + + Path indexPath = PathUtils.get(directory, engineFiles.get(0)); + final KNNQueryResult[] results; + KNNCounter.GRAPH_QUERY_REQUESTS.increment(); + + // We need to first get index allocation + NativeMemoryAllocation indexAllocation; + try { + indexAllocation = nativeMemoryCacheManager.get( + new NativeMemoryEntryContext.IndexEntryContext( + indexPath.toString(), + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), + knnQuery.getIndexName() + ), + true + ); + } catch (ExecutionException e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } + + // Now that we have the allocation, we need to readLock it + indexAllocation.readLock(); + + try { + if (indexAllocation.isClosed()) { + throw new RuntimeException("Index has already been closed"); } - Path indexPath = PathUtils.get(directory, engineFiles.get(0)); - final KNNQueryResult[] results; - KNNCounter.GRAPH_QUERY_REQUESTS.increment(); - - // We need to first get index allocation - NativeMemoryAllocation indexAllocation; - try { - indexAllocation = nativeMemoryCacheManager.get( - new NativeMemoryEntryContext.IndexEntryContext( - indexPath.toString(), - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), - knnQuery.getIndexName() - ), true); - } catch (ExecutionException e) { - GRAPH_QUERY_ERRORS.increment(); - throw new RuntimeException(e); - } - - // Now that we have the allocation, we need to readLock it - indexAllocation.readLock(); - - try { - if (indexAllocation.isClosed()) { - throw new RuntimeException("Index has already been closed"); - } - - results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName()); - } catch (Exception e) { - GRAPH_QUERY_ERRORS.increment(); - throw new RuntimeException(e); - } finally { - indexAllocation.readUnlock(); - } - - /* - * Scores represent the distance of the documents with respect to given query vector. - * Lesser the score, the closer the document is to the query vector. - * Since by default results are retrieved in the descending order of scores, to get the nearest - * neighbors we are inverting the scores. - */ - if (results.length == 0) { - logger.debug("[KNN] Query yielded 0 results"); - return null; - } - - Map scores = Arrays.stream(results).collect( - Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); - int maxDoc = Collections.max(scores.keySet()) + 1; - DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); - DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc); - Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); - DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); - return new KNNScorer(this, docIdSetIter, scores, boost); + results = JNIService.queryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getQueryVector(), + knnQuery.getK(), + knnEngine.getName() + ); + } catch (Exception e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } finally { + indexAllocation.readUnlock(); + } + + /* + * Scores represent the distance of the documents with respect to given query vector. + * Lesser the score, the closer the document is to the query vector. + * Since by default results are retrieved in the descending order of scores, to get the nearest + * neighbors we are inverting the scores. + */ + if (results.length == 0) { + logger.debug("[KNN] Query yielded 0 results"); + return null; + } + + Map scores = Arrays.stream(results) + .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); + int maxDoc = Collections.max(scores.keySet()) + 1; + DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); + DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc); + Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); + DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); + return new KNNScorer(this, docIdSetIter, scores, boost); } @Override @@ -189,9 +193,7 @@ public boolean isCacheable(LeafReaderContext context) { } public static float normalizeScore(float score) { - if (score >= 0) - return 1 / (1 + score); + if (score >= 0) return 1 / (1 + score); return -score + 1; } } - diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java index 1d3bb95801..e635855739 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java @@ -35,6 +35,7 @@ public final class KNN91Codec extends FilterCodec { public KNN91Codec() { this(new Lucene91Codec()); } + /** * Constructor that takes a Codec delegate to delegate all methods this code does not implement to. *