Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor KNNCodec to use new extension point #319

Merged
merged 12 commits into from
Mar 18, 2022
6 changes: 6 additions & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@
#

version=1.0.0

org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.opensearch.knn.index.codec.BinaryDocValuesSub;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.util.BytesRef;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.apache.lucene.codecs.lucene50.Lucene50CompoundFormat;
import org.opensearch.knn.common.KNNConstants;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundDirectory;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.index.SegmentInfo;
Expand All @@ -26,32 +24,40 @@
*/
public class KNN80CompoundFormat extends CompoundFormat {

private final Logger logger = LogManager.getLogger(KNN80CompoundFormat.class);
private final CompoundFormat delegate;

public KNN80CompoundFormat() {
this.delegate = new Lucene50CompoundFormat();
}

/**
* Constructor that takes a delegate to handle non-overridden methods
*
* @param delegate CompoundFormat that will handle non-overridden methods
*/
public KNN80CompoundFormat(CompoundFormat delegate) {
this.delegate = delegate;
}

@Override
public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException {
return Codec.getDefault().compoundFormat().getCompoundReader(dir, si, context);
return delegate.getCompoundReader(dir, si, context);
}

@Override
public void write(Directory dir, SegmentInfo si, IOContext context) throws IOException {
for (KNNEngine knnEngine : KNNEngine.values()) {
writeEngineFiles(dir, si, context, knnEngine.getExtension());
}
Codec.getDefault().compoundFormat().write(dir, si, context);
delegate.write(dir, si, context);
}

private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension)
throws IOException {
private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) throws IOException {
/*
* If engine file present, remove it from the compounding file list to avoid header/footer checks
* and create a new compounding file format with extension engine + c.
*/
Set<String> engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension))
.collect(Collectors.toSet());
Set<String> engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)).collect(Collectors.toSet());

Set<String> segmentFiles = new HashSet<>(si.files());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,108 +67,117 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
@Override
public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
delegatee.addBinaryField(field, valuesProducer);
addKNNBinaryField(field, valuesProducer);
if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
addKNNBinaryField(field, valuesProducer);
}
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {

// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}

// Create library index either from model or from scratch
String engineFileName;
String indexPath;
String tmpEngineFileName;
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();

if (field.attributes().containsKey(MODEL_ID)) {
// Create library index either from model or from scratch
String engineFileName;
String indexPath;
String tmpEngineFileName;

String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (field.attributes().containsKey(MODEL_ID)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a null check for attributes before calling contains?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;
KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();

if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}
engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getLatestBuildVersion(),
field.name,
knnEngine.getExtension()
);
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
.toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
} else {
if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}

// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);
createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
} else {

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;
// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);

createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
}
engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getLatestBuildVersion(),
field.name,
knnEngine.getExtension()
);
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
.toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

/*
* Adds Footer to the serialized graph
* 1. Copies the serialized graph to new file.
* 2. Adds Footer to the new file.
*
* We had to create new file here because adding footer directly to the
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
//TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
KNNCounter.GRAPH_INDEX_ERRORS.increment();
throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
} finally {
IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
}
createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
}

/*
* Adds Footer to the serialized graph
* 1. Copies the serialized graph to new file.
* 2. Adds Footer to the new file.
*
* We had to create new file here because adding footer directly to the
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
// TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (
IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)
) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
KNNCounter.GRAPH_INDEX_ERRORS.increment();
throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
} finally {
IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
}
}

private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine,
String indexPath) {
Map<String, Object> parameters = ImmutableMap.of(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(
KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
AccessController.doPrivileged(
(PrivilegedAction<Void>) () -> {
JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters,
knnEngine.getName());
return null;
}
private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.INDEX_THREAD_QTY,
KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
);
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName());
return null;
});
}

private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine,
String indexPath) throws IOException {
private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
throws IOException {
Map<String, Object> parameters = new HashMap<>();
Map<String, String> fieldAttributes = fieldInfo.attributes();
String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS);

// parametersString will be null when legacy mapper is used
if (parametersString == null) {
parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE,
SpaceType.DEFAULT.getValue()));
parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()));

String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION);
Map<String, Object> algoParams = new HashMap<>();
Expand All @@ -183,22 +192,20 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa
parameters.put(PARAMETERS, algoParams);
} else {
parameters.putAll(
XContentFactory.xContent(XContentType.JSON).createParser(NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString).map()
XContentFactory.xContent(XContentType.JSON)
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString)
.map()
);
}

// Used to determine how many threads to use when indexing
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(
KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));

// Pass the path for the nms library to save the file
AccessController.doPrivileged(
(PrivilegedAction<Void>) () -> {
JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName());
return null;
}
);
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName());
return null;
});
}

/**
Expand All @@ -214,7 +221,7 @@ public void merge(MergeState mergeState) {
assert mergeState.mergeFieldInfos != null;
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
DocValuesType type = fieldInfo.getDocValuesType();
if (type == DocValuesType.BINARY) {
if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.DocValuesConsumer;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.lucene80.Lucene80DocValuesFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

Expand All @@ -19,11 +18,20 @@
* Encodes/Decodes per document values
*/
public class KNN80DocValuesFormat extends DocValuesFormat {
private final Logger logger = LogManager.getLogger(KNN80DocValuesFormat.class);
private final DocValuesFormat delegate = DocValuesFormat.forName(KNN80Codec.LUCENE_80);
private final DocValuesFormat delegate;

public KNN80DocValuesFormat() {
super(KNN80Codec.LUCENE_80);
this(new Lucene80DocValuesFormat());
}

/**
* Constructor that takes delegate in order to handle non-overridden methods
*
* @param delegate DocValuesFormat to handle non-overridden methods
*/
public KNN80DocValuesFormat(DocValuesFormat delegate) {
super(delegate.getName());
this.delegate = delegate;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.opensearch.knn.index.codec.BinaryDocValuesSub;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
Expand All @@ -22,7 +22,7 @@
*/
class KNN80DocValuesReader extends EmptyDocValuesProducer {

private MergeState mergeState;
private final MergeState mergeState;

KNN80DocValuesReader(MergeState mergeState) {
this.mergeState = mergeState;
Expand Down
Loading