Skip to content

Commit

Permalink
Clean up codec and augment testing
Browse files Browse the repository at this point in the history
Adds unit tests that map to each codec component. Did not add tests for
merging and codec utils. This can be undertaken later. Adds a utils
folder for sharing testing functionality between codec tests. Cleans up
a few minor details around codec source code.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Mar 11, 2022
1 parent 1cc9f24 commit c5ba1dc
Show file tree
Hide file tree
Showing 13 changed files with 1,112 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.opensearch.knn.index.codec.BinaryDocValuesSub;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.util.BytesRef;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,82 +67,84 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
@Override
public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
delegatee.addBinaryField(field, valuesProducer);
addKNNBinaryField(field, valuesProducer);
if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
addKNNBinaryField(field, valuesProducer);
}
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {

// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}

// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();

// Create library index either from model or from scratch
String engineFileName;
String indexPath;
String tmpEngineFileName;
// Create library index either from model or from scratch
String engineFileName;
String indexPath;
String tmpEngineFileName;

if (field.attributes().containsKey(MODEL_ID)) {
if (field.attributes().containsKey(MODEL_ID)) {

String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);

KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();
KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;
engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}
if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}

createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
} else {
createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
} else {

// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);
// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;
engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
}
createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
}

/*
* Adds Footer to the serialized graph
* 1. Copies the serialized graph to new file.
* 2. Adds Footer to the new file.
*
* We had to create new file here because adding footer directly to the
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
//TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
KNNCounter.GRAPH_INDEX_ERRORS.increment();
throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
} finally {
IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
}
/*
* Adds Footer to the serialized graph
* 1. Copies the serialized graph to new file.
* 2. Adds Footer to the new file.
*
* We had to create new file here because adding footer directly to the
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
//TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
KNNCounter.GRAPH_INDEX_ERRORS.increment();
throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
} finally {
IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
}
}

Expand Down Expand Up @@ -214,7 +216,7 @@ public void merge(MergeState mergeState) {
assert mergeState.mergeFieldInfos != null;
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
DocValuesType type = fieldInfo.getDocValuesType();
if (type == DocValuesType.BINARY) {
if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.opensearch.knn.index.codec.BinaryDocValuesSub;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
Expand All @@ -22,7 +22,7 @@
*/
class KNN80DocValuesReader extends EmptyDocValuesProducer {

private MergeState mergeState;
private final MergeState mergeState;

KNN80DocValuesReader(MergeState mergeState) {
this.mergeState = mergeState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
/**
* KNNCodecService to inject the right KNNCodec version
*/
class KNNCodecService extends CodecService {
public class KNNCodecService extends CodecService {

public KNNCodecService(CodecServiceConfig codecServiceConfig) {
super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec;
package org.opensearch.knn.index.codec.util;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
Expand All @@ -19,10 +25,6 @@ public class BinaryDocValuesSub extends DocIDMerger.Sub {

private final BinaryDocValues values;

public BinaryDocValues getValues() {
return values;
}

public BinaryDocValuesSub(MergeState.DocMap docMap, BinaryDocValues values) {
super(docMap);
if (values == null || (values.docID() != -1)) {
Expand All @@ -35,4 +37,8 @@ public BinaryDocValuesSub(MergeState.DocMap docMap, BinaryDocValues values) {
public int nextDoc() throws IOException {
return values.nextDoc();
}

public BinaryDocValues getValues() {
return values;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN80Codec;

import com.google.common.collect.ImmutableList;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.MergeState;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;

import java.io.IOException;

public class KNN80BinaryDocValuesTests extends KNNTestCase {

    public void testDocId() {
        // A freshly constructed instance has not been advanced, so docID() reports -1
        final KNN80BinaryDocValues docValues = new KNN80BinaryDocValues(null);
        assertEquals(-1, docValues.docID());
    }

    public void testNextDoc() throws IOException {
        final int expectedDoc = 12;

        // Doc map that remaps every incoming doc id to the single expected id
        final MergeState.DocMap constantDocMap = new MergeState.DocMap() {
            @Override
            public int get(int docID) {
                return expectedDoc;
            }
        };
        final BinaryDocValuesSub sub = new BinaryDocValuesSub(
                constantDocMap, new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f));

        final DocIDMerger<BinaryDocValuesSub> merger = DocIDMerger.of(ImmutableList.of(sub), false);
        final KNN80BinaryDocValues docValues = new KNN80BinaryDocValues(merger);
        assertEquals(expectedDoc, docValues.nextDoc());
    }

    public void testAdvance() {
        // advance is not supported by this merge-only implementation
        final KNN80BinaryDocValues docValues = new KNN80BinaryDocValues(null);
        expectThrows(UnsupportedOperationException.class, () -> docValues.advance(0));
    }

    public void testAdvanceExact() {
        // advanceExact is not supported by this merge-only implementation
        final KNN80BinaryDocValues docValues = new KNN80BinaryDocValues(null);
        expectThrows(UnsupportedOperationException.class, () -> docValues.advanceExact(0));
    }

    public void testCost() {
        // cost is not supported by this merge-only implementation
        final KNN80BinaryDocValues docValues = new KNN80BinaryDocValues(null);
        expectThrows(UnsupportedOperationException.class, docValues::cost);
    }

    public void testBinaryValue() throws IOException {
        final BinaryDocValues delegateValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f);

        // Identity mapping: doc ids pass through unchanged
        final MergeState.DocMap identityDocMap = new MergeState.DocMap() {
            @Override
            public int get(int docID) {
                return docID;
            }
        };
        final BinaryDocValuesSub sub = new BinaryDocValuesSub(identityDocMap, delegateValues);

        final DocIDMerger<BinaryDocValuesSub> merger = DocIDMerger.of(ImmutableList.of(sub), false);
        final KNN80BinaryDocValues docValues = new KNN80BinaryDocValues(merger);

        // After advancing, the merged view must expose the delegate's binary payload
        docValues.nextDoc();
        assertEquals(delegateValues.binaryValue(), docValues.binaryValue());
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN80Codec;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundDirectory;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexOutput;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.opensearch.common.util.set.Sets;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Arrays;
import java.util.Set;

import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class KNN80CompoundFormatTests extends KNNTestCase {

    private static Directory directory;
    private static Codec codec;

    @BeforeClass
    public static void setStaticVariables() {
        directory = newFSDirectory(createTempDir());
        codec = new KNN87Codec();
    }

    @AfterClass
    public static void closeStaticVariables() throws IOException {
        directory.close();
    }

    public void testGetCompoundReader() throws IOException {
        // The format should delegate reader creation untouched
        final CompoundDirectory expectedReader = mock(CompoundDirectory.class);
        final CompoundFormat delegate = mock(CompoundFormat.class);
        when(delegate.getCompoundReader(null, null, null)).thenReturn(expectedReader);

        final KNN80CompoundFormat compoundFormat = new KNN80CompoundFormat(delegate);
        assertEquals(expectedReader, compoundFormat.getCompoundReader(null, null, null));
    }

    public void testWrite() throws IOException {
        // Check that all normal engine files correctly get set to compound extension files after write
        final String segmentName = "_test";

        final Set<String> segmentFiles = Sets.newHashSet(
            String.format("%s_nmslib1%s", segmentName, KNNEngine.NMSLIB.getExtension()),
            String.format("%s_nmslib2%s", segmentName, KNNEngine.NMSLIB.getExtension()),
            String.format("%s_nmslib3%s", segmentName, KNNEngine.NMSLIB.getExtension()),
            String.format("%s_faiss1%s", segmentName, KNNEngine.FAISS.getExtension()),
            String.format("%s_faiss2%s", segmentName, KNNEngine.FAISS.getExtension()),
            String.format("%s_faiss3%s", segmentName, KNNEngine.FAISS.getExtension())
        );

        final SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder
            .builder(directory, segmentName, segmentFiles.size(), codec)
            .build();

        // Materialize each engine file on disk so the format has something to process
        for (String fileName : segmentFiles) {
            directory.createOutput(fileName, IOContext.DEFAULT).close();
        }
        segmentInfo.setFiles(segmentFiles);

        final CompoundFormat delegate = mock(CompoundFormat.class);
        doNothing().when(delegate).write(directory, segmentInfo, IOContext.DEFAULT);

        final KNN80CompoundFormat compoundFormat = new KNN80CompoundFormat(delegate);
        compoundFormat.write(directory, segmentInfo, IOContext.DEFAULT);

        // All engine files should have been swapped out of the segment's file set
        assertTrue(segmentInfo.files().isEmpty());

        // Clean up everything left behind in the shared directory
        for (String fileName : directory.listAll()) {
            try {
                directory.deleteFile(fileName);
            } catch (IOException e) {
                fail(String.format("Failed to delete: %s", fileName));
            }
        }
    }

}
Loading

0 comments on commit c5ba1dc

Please sign in to comment.