Skip to content

Commit

Permalink
Manually add footer to engine files (opensearch-project#327)
Browse files Browse the repository at this point in the history
Manually adds Lucene footer to engine files to prevent an unnecessary
copy from one file to another.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 authored Mar 21, 2022
1 parent 38f1c23 commit f385aec
Showing 1 changed file with 62 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.codec.KNN80Codec;

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.store.ChecksumIndexInput;
import org.opensearch.common.xcontent.DeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentFactory;
Expand All @@ -20,7 +21,6 @@
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.DocValuesConsumer;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
Expand All @@ -30,20 +30,23 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.opensearch.knn.index.KNNVectorFieldMapper;
import org.opensearch.knn.common.KNNConstants;

import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.HashMap;
import java.util.Map;

import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
Expand All @@ -55,11 +58,12 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {

private final Logger logger = LogManager.getLogger(KNN80DocValuesConsumer.class);

private final String TEMP_SUFFIX = "tmp";
private DocValuesConsumer delegatee;
private SegmentWriteState state;
private final DocValuesConsumer delegatee;
private final SegmentWriteState state;

KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) throws IOException {
private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L;

/**
 * Creates a consumer that wraps another doc values consumer.
 *
 * @param delegatee wrapped consumer; {@code close()} on this class forwards to it
 * @param state     segment write state; its directory and IO context are used when engine files are
 *                  registered with the directory and when their footer checksum is read back
 */
KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) {
this.delegatee = delegatee;
this.state = state;
}
Expand All @@ -84,12 +88,10 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)

// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();

// Create library index either from model or from scratch
String engineFileName;
String indexPath;
String tmpEngineFileName;

NativeIndexCreator indexCreator;
if (field.attributes().containsKey(MODEL_ID)) {

String modelId = field.attributes().get(MODEL_ID);
Expand All @@ -105,14 +107,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
);
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
.toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}

createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath);
} else {

// Get engine to be used for indexing
Expand All @@ -127,35 +127,16 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
);
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
.toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}

/*
* Adds Footer to the serialized graph
* 1. Copies the serialized graph to new file.
* 2. Adds Footer to the new file.
*
* We had to create new file here because adding footer directly to the
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
// TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (
IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)
) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
KNNCounter.GRAPH_INDEX_ERRORS.increment();
throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
} finally {
IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
}
// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
state.directory.createOutput(engineFileName, state.context).close();
indexCreator.createIndex();
writeFooter(indexPath, engineFileName);
}

private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Expand Down Expand Up @@ -254,4 +235,45 @@ public void addNumericField(FieldInfo field, DocValuesProducer valuesProducer) t
public void close() throws IOException {
// Closing this consumer only closes the wrapped delegatee; nothing else is released here.
delegatee.close();
}

/**
 * Deferred action that builds the native library index file.
 *
 * addKNNBinaryField assigns an instance that wraps either createKNNIndexFromTemplate or
 * createKNNIndexFromScratch, and invokes it only after the engine file has been registered with the
 * directory. A dedicated interface is declared because the action may throw IOException.
 */
@FunctionalInterface
private interface NativeIndexCreator {
void createIndex() throws IOException;
}

/**
 * Appends a Lucene-style footer directly to an engine file that has already been written to disk.
 *
 * The footer consists of:
 * 1. A footer magic number (int - 4 bytes)
 * 2. A checksum algorithm id (int - 4 bytes, always 0)
 * 3. A checksum (long - 8 bytes)
 * All values are written big-endian. The checksum covers every byte in the file that precedes it,
 * including the magic number and the algorithm id.
 *
 * Logic where footer is written in Lucene can be found here:
 * https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412
 *
 * @param indexPath      file system path of the engine file; opened in append mode
 * @param engineFileName name of the same file relative to the segment directory; used to read the checksum
 * @throws IOException           if the file cannot be opened, written or read
 * @throws IllegalStateException if the computed checksum does not fit in 32 bits
 */
private void writeFooter(String indexPath, String engineFileName) throws IOException {
    // try-with-resources guarantees the append stream is released even when reading the checksum
    // or the sanity check below throws.
    try (OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND)) {
        ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN);
        byteBuffer.putInt(FOOTER_MAGIC);
        byteBuffer.putInt(0);
        os.write(byteBuffer.array());
        // Push the first 8 footer bytes out before the file is re-read for checksumming.
        os.flush();

        // Read the whole file (graph bytes + magic + algorithm id) through a checksumming input.
        long value;
        try (ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context)) {
            checksumIndexInput.seek(checksumIndexInput.length());
            value = checksumIndexInput.getChecksum();
        }

        if (!isChecksumValid(value)) {
            throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + indexPath + ")");
        }

        // Reuse the 8-byte buffer: overwrite it from position 0 with the checksum and append it.
        byteBuffer.putLong(0, value);
        os.write(byteBuffer.array());
    }
}

/**
 * Tells whether a checksum is a plausible CRC-32 value, i.e. none of its upper 32 bits are set.
 *
 * Check pulled from
 * https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647
 *
 * @param value checksum to inspect
 * @return true when the upper 32 bits of value are all zero
 */
private boolean isChecksumValid(long value) {
    return (value & CRC32_CHECKSUM_SANITY) == 0;
}
}

0 comments on commit f385aec

Please sign in to comment.