diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 0bbd6e9ac..36ffe16de 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; import com.google.common.collect.ImmutableMap; +import org.apache.lucene.store.ChecksumIndexInput; import org.opensearch.common.xcontent.DeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentFactory; @@ -20,7 +21,6 @@ import org.opensearch.knn.plugin.stats.KNNCounter; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.DocValuesConsumer; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.index.BinaryDocValues; @@ -30,20 +30,23 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; -import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.IOUtils; import org.opensearch.knn.index.KNNVectorFieldMapper; import org.opensearch.knn.common.KNNConstants; import java.io.Closeable; import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.HashMap; import java.util.Map; +import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; @@ -55,11 +58,12 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable { private final Logger logger = LogManager.getLogger(KNN80DocValuesConsumer.class); - private final String TEMP_SUFFIX = "tmp"; - private DocValuesConsumer delegatee; - private SegmentWriteState state; + private final DocValuesConsumer delegatee; + private final SegmentWriteState state; - KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) throws IOException { + private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; + + KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) { this.delegatee = delegatee; this.state = state; } @@ -84,12 +88,10 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) // Increment counter for number of graph index requests KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - // Create library index either from model or from scratch String engineFileName; String indexPath; - String tmpEngineFileName; - + NativeIndexCreator indexCreator; if (field.attributes().containsKey(MODEL_ID)) { String modelId = field.attributes().get(MODEL_ID); @@ -105,14 +107,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) ); indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) .toString(); - tmpEngineFileName = engineFileName + TEMP_SUFFIX; - String tempIndexPath = indexPath + TEMP_SUFFIX; if (model.getModelBlob() == null) { throw new RuntimeException("There is no trained model with id \"" + modelId + "\""); } - createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath); + indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); } else { // Get engine to be used for indexing @@ -127,35 +127,16 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) ); indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) .toString(); - tmpEngineFileName = engineFileName + TEMP_SUFFIX; - String tempIndexPath = indexPath + TEMP_SUFFIX; - createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath); + indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } - /* - * Adds Footer to the serialized graph - * 1. Copies the serialized graph to new file. - * 2. Adds Footer to the new file. - * - * We had to create new file here because adding footer directly to the - * existing file will miss calculating checksum for the serialized graph - * bytes and result in index corruption issues. - */ - // TODO: I think this can be refactored to avoid this copy and then write - // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330 - try ( - IndexInput is = state.directory.openInput(tmpEngineFileName, state.context); - IndexOutput os = state.directory.createOutput(engineFileName, state.context) - ) { - os.copyBytes(is, is.length()); - CodecUtil.writeFooter(os); - } catch (Exception ex) { - KNNCounter.GRAPH_INDEX_ERRORS.increment(); - throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex); - } finally { - IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName); - } + // This is a bit of a hack. We have to create an output here and then immediately close it to ensure that + // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will + // not be marked as added to the directory. + state.directory.createOutput(engineFileName, state.context).close(); + indexCreator.createIndex(); + writeFooter(indexPath, engineFileName); } private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { @@ -254,4 +235,45 @@ public void addNumericField(FieldInfo field, DocValuesProducer valuesProducer) t public void close() throws IOException { delegatee.close(); } + + @FunctionalInterface + private interface NativeIndexCreator { + void createIndex() throws IOException; + } + + private void writeFooter(String indexPath, String engineFileName) throws IOException { + // Opens the engine file that was created and appends a footer to it. The footer consists of + // 1. A Footer magic number (int - 4 bytes) + // 2. A checksum algorithm id (int - 4 bytes) + // 3. A checksum (long - bytes) + // The checksum is computed on all the bytes written to the file up to that point. + // Logic where footer is written in Lucene can be found here: + // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412 + OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND); + ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); + byteBuffer.putInt(FOOTER_MAGIC); + byteBuffer.putInt(0); + os.write(byteBuffer.array()); + os.flush(); + + ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context); + checksumIndexInput.seek(checksumIndexInput.length()); + long value = checksumIndexInput.getChecksum(); + checksumIndexInput.close(); + + if (isChecksumValid(value)) { + throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")"); + } + + // Write the CRC checksum to the end of the OutputStream and close the stream + byteBuffer.putLong(0, value); + os.write(byteBuffer.array()); + os.close(); + } + + private boolean isChecksumValid(long value) { + // Check pulled from + // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 + return (value & CRC32_CHECKSUM_SANITY) != 0; + } }