Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manually add footer to engine files #327

Merged
merged 3 commits into from
Mar 21, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.codec.KNN80Codec;

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.store.ChecksumIndexInput;
import org.opensearch.common.xcontent.DeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentFactory;
Expand All @@ -20,7 +21,6 @@
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.DocValuesConsumer;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
Expand All @@ -30,20 +30,23 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.opensearch.knn.index.KNNVectorFieldMapper;
import org.opensearch.knn.common.KNNConstants;

import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.HashMap;
import java.util.Map;

import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
Expand All @@ -55,11 +58,10 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {

private final Logger logger = LogManager.getLogger(KNN80DocValuesConsumer.class);

private final String TEMP_SUFFIX = "tmp";
private DocValuesConsumer delegatee;
private SegmentWriteState state;
private final DocValuesConsumer delegatee;
private final SegmentWriteState state;

KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) throws IOException {
KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) {
this.delegatee = delegatee;
this.state = state;
}
Expand All @@ -84,12 +86,10 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)

// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();

// Create library index either from model or from scratch
String engineFileName;
String indexPath;
String tmpEngineFileName;

NativeIndexCreator indexCreator;
if (field.attributes().containsKey(MODEL_ID)) {

String modelId = field.attributes().get(MODEL_ID);
Expand All @@ -105,14 +105,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
);
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
.toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}

createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath);
} else {

// Get engine to be used for indexing
Expand All @@ -127,35 +125,16 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
);
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
.toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}

/*
* Adds Footer to the serialized graph
* 1. Copies the serialized graph to new file.
* 2. Adds Footer to the new file.
*
* We had to create new file here because adding footer directly to the
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
// TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (
IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)
) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
KNNCounter.GRAPH_INDEX_ERRORS.increment();
throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
} finally {
IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
}
// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
state.directory.createOutput(engineFileName, state.context).close();
indexCreator.createIndex();
writeFooter(indexPath, engineFileName);
}

private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Expand Down Expand Up @@ -254,4 +233,39 @@ public void addNumericField(FieldInfo field, DocValuesProducer valuesProducer) t
public void close() throws IOException {
delegatee.close();
}

@FunctionalInterface
private interface NativeIndexCreator {
void createIndex() throws IOException;
}

private void writeFooter(String indexPath, String engineFileName) throws IOException {
// Opens the engine file that was created and appends a footer to it. The footer consists of
// 1. A Footer magic number (int - 4 bytes)
// 2. A checksum algorithm id (int - 4 bytes)
// 3. A checksum (long - bytes)
// The checksum is computed on all the bytes written to the file up to that point.
// Logic where footer is written in Lucene can be found here:
// https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412
OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND);
ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN);
byteBuffer.putInt(FOOTER_MAGIC);
byteBuffer.putInt(0);
os.write(byteBuffer.array());
os.flush();

ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context);
checksumIndexInput.seek(checksumIndexInput.length());
long value = checksumIndexInput.getChecksum();
checksumIndexInput.close();

if ((value & 0xFFFFFFFF00000000L) != 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we use a constant here? also it would make code more readable if condition will be in a separate method with name that describes the idea of the check, e.g. invalidChecksum()

throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")");
}

// Write the CRC checksum to the end of the OutputStream and close the stream
byteBuffer.putLong(0, value);
os.write(byteBuffer.array());
os.close();
}
}