Skip to content

Commit

Permalink
Apply spotless
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Mar 21, 2022
1 parent 12f9bec commit 7711d48
Show file tree
Hide file tree
Showing 14 changed files with 216 additions and 219 deletions.
7 changes: 4 additions & 3 deletions src/main/java/org/opensearch/knn/index/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ public int getK() {
return this.k;
}

public String getIndexName() { return this.indexName; }
public String getIndexName() {
return this.indexName;
}

/**
* Constructs Weight implementation for this query
Expand Down Expand Up @@ -77,8 +79,7 @@ public int hashCode() {

@Override
public boolean equals(Object other) {
return sameClassAs(other) &&
equalsTo(getClass().cast(other));
return sameClassAs(other) && equalsTo(getClass().cast(other));
}

private boolean equalsTo(KNNQuery other) {
Expand Down
212 changes: 107 additions & 105 deletions src/main/java/org/opensearch/knn/index/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.knn.index;

import com.google.common.collect.ImmutableMap;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
Expand All @@ -19,7 +18,6 @@
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorer;
Expand All @@ -36,10 +34,8 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -79,108 +75,116 @@ public Explanation explain(LeafReaderContext context, int doc) {

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());

if (fieldInfo == null) {
logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(),
reader.getSegmentName());
return null;
}

KNNEngine knnEngine;
SpaceType spaceType;

// Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's
// metadata.
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
}

knnEngine = modelMetadata.getKnnEngine();
spaceType = modelMetadata.getSpaceType();
} else {
String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName());
knnEngine = KNNEngine.getEngine(engineName);
String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue());
spaceType = SpaceType.getSpace(spaceTypeName);
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());

if (fieldInfo == null) {
logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
return null;
}

KNNEngine knnEngine;
SpaceType spaceType;

// Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's
// metadata.
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
}

/*
* In case of compound file, extension would be <engine-extension> + c otherwise <engine-extension>
*/
String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile()
? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION : knnEngine.getExtension();
String engineSuffix = knnQuery.getField() + engineExtension;
List<String> engineFiles = reader.getSegmentInfo().files().stream()
.filter(fileName -> fileName.endsWith(engineSuffix))
.collect(Collectors.toList());

if(engineFiles.isEmpty()) {
logger.debug("[KNN] No engine index found for field {} for segment {}",
knnQuery.getField(), reader.getSegmentName());
return null;
knnEngine = modelMetadata.getKnnEngine();
spaceType = modelMetadata.getSpaceType();
} else {
String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName());
knnEngine = KNNEngine.getEngine(engineName);
String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue());
spaceType = SpaceType.getSpace(spaceTypeName);
}

/*
* In case of compound file, extension would be <engine-extension> + c otherwise <engine-extension>
*/
String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile()
? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION
: knnEngine.getExtension();
String engineSuffix = knnQuery.getField() + engineExtension;
List<String> engineFiles = reader.getSegmentInfo()
.files()
.stream()
.filter(fileName -> fileName.endsWith(engineSuffix))
.collect(Collectors.toList());

if (engineFiles.isEmpty()) {
logger.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName());
return null;
}

Path indexPath = PathUtils.get(directory, engineFiles.get(0));
final KNNQueryResult[] results;
KNNCounter.GRAPH_QUERY_REQUESTS.increment();

// We need to first get index allocation
NativeMemoryAllocation indexAllocation;
try {
indexAllocation = nativeMemoryCacheManager.get(
new NativeMemoryEntryContext.IndexEntryContext(
indexPath.toString(),
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()),
knnQuery.getIndexName()
),
true
);
} catch (ExecutionException e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
}

// Now that we have the allocation, we need to readLock it
indexAllocation.readLock();

try {
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
}

Path indexPath = PathUtils.get(directory, engineFiles.get(0));
final KNNQueryResult[] results;
KNNCounter.GRAPH_QUERY_REQUESTS.increment();

// We need to first get index allocation
NativeMemoryAllocation indexAllocation;
try {
indexAllocation = nativeMemoryCacheManager.get(
new NativeMemoryEntryContext.IndexEntryContext(
indexPath.toString(),
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()),
knnQuery.getIndexName()
), true);
} catch (ExecutionException e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
}

// Now that we have the allocation, we need to readLock it
indexAllocation.readLock();

try {
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
}

results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName());
} catch (Exception e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
} finally {
indexAllocation.readUnlock();
}

/*
* Scores represent the distance of the documents with respect to given query vector.
* Lesser the score, the closer the document is to the query vector.
* Since by default results are retrieved in the descending order of scores, to get the nearest
* neighbors we are inverting the scores.
*/
if (results.length == 0) {
logger.debug("[KNN] Query yielded 0 results");
return null;
}

Map<Integer, Float> scores = Arrays.stream(results).collect(
Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
int maxDoc = Collections.max(scores.keySet()) + 1;
DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc);
Arrays.stream(results).forEach(result -> setAdder.add(result.getId()));
DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
return new KNNScorer(this, docIdSetIter, scores, boost);
results = JNIService.queryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getK(),
knnEngine.getName()
);
} catch (Exception e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
} finally {
indexAllocation.readUnlock();
}

/*
* Scores represent the distance of the documents with respect to given query vector.
* Lesser the score, the closer the document is to the query vector.
* Since by default results are retrieved in the descending order of scores, to get the nearest
* neighbors we are inverting the scores.
*/
if (results.length == 0) {
logger.debug("[KNN] Query yielded 0 results");
return null;
}

Map<Integer, Float> scores = Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
int maxDoc = Collections.max(scores.keySet()) + 1;
DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc);
Arrays.stream(results).forEach(result -> setAdder.add(result.getId()));
DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
return new KNNScorer(this, docIdSetIter, scores, boost);
}

@Override
Expand All @@ -189,9 +193,7 @@ public boolean isCacheable(LeafReaderContext context) {
}

public static float normalizeScore(float score) {
if (score >= 0)
return 1 / (1 + score);
if (score >= 0) return 1 / (1 + score);
return -score + 1;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) {
* This function returns the Lucene80 Codec.
*/
public Codec getDelegatee() {
if (lucene80Codec == null)
lucene80Codec = Codec.forName(LUCENE_80);
if (lucene80Codec == null) lucene80Codec = Codec.forName(LUCENE_80);
return lucene80Codec;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public KNN84Codec() {
super(KNN_84);
// Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80
// DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses
this.docValuesFormat = new KNN80DocValuesFormat();
this.docValuesFormat = new KNN80DocValuesFormat();
this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() {
@Override
public DocValuesFormat getDocValuesFormatForField(String field) {
Expand All @@ -57,8 +57,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) {
* This function returns the Lucene84 Codec.
*/
public Codec getDelegatee() {
if (lucene84Codec == null)
lucene84Codec = Codec.forName(LUCENE_84);
if (lucene84Codec == null) lucene84Codec = Codec.forName(LUCENE_84);
return lucene84Codec;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public KNN86Codec() {
super(KNN_86);
// Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80
// DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses
this.docValuesFormat = new KNN80DocValuesFormat();
this.docValuesFormat = new KNN80DocValuesFormat();
this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() {
@Override
public DocValuesFormat getDocValuesFormatForField(String field) {
Expand All @@ -58,8 +58,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) {
* This function returns the Lucene84 Codec.
*/
public Codec getDelegatee() {
if (lucene86Codec == null)
lucene86Codec = Codec.forName(LUCENE_86);
if (lucene86Codec == null) lucene86Codec = Codec.forName(LUCENE_86);
return lucene86Codec;
}

Expand All @@ -74,7 +73,6 @@ public DocValuesFormat docValuesFormat() {
* approach of manually overriding.
*/


public void setPostingsFormat(PostingsFormat postingsFormat) {
this.postingsFormat = postingsFormat;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public final class KNN91Codec extends FilterCodec {
public KNN91Codec() {
this(new Lucene91Codec());
}

/**
* Constructor that takes a Codec delegate to delegate all methods this code does not implement to.
*
Expand Down
5 changes: 1 addition & 4 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Strings;
import org.opensearch.common.bytes.BytesArray;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.mapper.MapperService;
Expand Down Expand Up @@ -205,7 +202,7 @@ public void create(ActionListener<CreateIndexResponse> actionListener) throws IO
return;
}
String mapping = Strings.toString(
JsonXContent.contentBuilder().startObject().startObject(MapperService.SINGLE_MAPPING_NAME).endObject().endObject()
JsonXContent.contentBuilder().startObject().startObject(MapperService.SINGLE_MAPPING_NAME).endObject().endObject()
);
CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(mapping)
.settings(
Expand Down
Loading

0 comments on commit 7711d48

Please sign in to comment.