Refactor KNNCodec to use new extension point (#319)
Refactors the plugin to return a CodecServiceFactory rather than a
CodecService. This lets the plugin make codec decisions based on the
MapperService.

Refactors KNN87Codec to extend FilterCodec, which lets the codec
automatically delegate any operation it does not override to an arbitrary
delegate Codec. Also cleans up some code around the codec.
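The FilterCodec pattern in question, sketched below (illustrative; close to, but not verbatim from, the commit): the subclass names itself, hands Lucene a delegate, and overrides only the formats it customizes, while everything else falls through to the delegate.

```java
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.lucene87.Lucene87Codec;

public class KNN87Codec extends FilterCodec {

    public KNN87Codec() {
        this(new Lucene87Codec()); // default delegate
    }

    public KNN87Codec(Codec delegate) {
        // FilterCodec stores the delegate and forwards every format getter to
        // it; only the overrides below change behavior.
        super("KNN87Codec", delegate);
    }

    @Override
    public DocValuesFormat docValuesFormat() {
        return new KNN80DocValuesFormat(delegate.docValuesFormat());
    }

    @Override
    public CompoundFormat compoundFormat() {
        return new KNN80CompoundFormat(delegate.compoundFormat());
    }
}
```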

Adds unit tests that map to each codec component. Tests for merging and the
codec utils were not added; that can be undertaken later. Also adds a utils
folder for sharing test functionality between codec tests, and cleans up a
few minor details in the codec source code.
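As a rough idea of the per-component test shape (illustrative only; the actual tests in this commit exercise the k-NN field attributes through the plugin's mappers, and the class and field names here are hypothetical):

```java
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;

public class KNN87CodecTests {

    public void testBinaryDocValuesRoundTrip() throws Exception {
        try (Directory dir = new ByteBuffersDirectory()) {
            // Write one document through the codec under test
            IndexWriterConfig config = new IndexWriterConfig().setCodec(new KNN87Codec());
            try (IndexWriter writer = new IndexWriter(dir, config)) {
                Document doc = new Document();
                doc.add(new BinaryDocValuesField("my_vector", new BytesRef(new byte[] { 1, 2, 3 })));
                writer.addDocument(doc);
            }
            // Read it back to confirm the delegated formats round-trip
            try (DirectoryReader reader = DirectoryReader.open(dir)) {
                assert reader.numDocs() == 1;
            }
        }
    }
}
```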

Signed-off-by: John Mazanec <[email protected]>
jmazanec15 authored Mar 18, 2022
1 parent ccbc0db commit d531b3c
Showing 20 changed files with 1,327 additions and 417 deletions.
6 changes: 6 additions & 0 deletions gradle.properties
@@ -4,3 +4,9 @@
 #
 
 version=1.0.0
+
+org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \
+    --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \
+    --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \
+    --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \
+    --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
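These jdk.compiler `--add-exports` flags match the set google-java-format requires on newer JDKs, so they were presumably added to support the formatter (e.g. via Spotless) responsible for the formatting-only hunks elsewhere in this diff; that is an inference, not something the commit message states.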
@@ -5,7 +5,7 @@

 package org.opensearch.knn.index.codec.KNN80Codec;
 
-import org.opensearch.knn.index.codec.BinaryDocValuesSub;
+import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.index.DocIDMerger;
 import org.apache.lucene.util.BytesRef;
src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java
@@ -5,10 +5,8 @@

 package org.opensearch.knn.index.codec.KNN80Codec;
 
+import org.apache.lucene.codecs.lucene50.Lucene50CompoundFormat;
 import org.opensearch.knn.common.KNNConstants;
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.CompoundDirectory;
 import org.apache.lucene.codecs.CompoundFormat;
 import org.apache.lucene.index.SegmentInfo;
@@ -26,32 +24,40 @@
  */
 public class KNN80CompoundFormat extends CompoundFormat {
 
-    private final Logger logger = LogManager.getLogger(KNN80CompoundFormat.class);
+    private final CompoundFormat delegate;
+
+    public KNN80CompoundFormat() {
+        this.delegate = new Lucene50CompoundFormat();
+    }
+
+    /**
+     * Constructor that takes a delegate to handle non-overridden methods
+     *
+     * @param delegate CompoundFormat that will handle non-overridden methods
+     */
+    public KNN80CompoundFormat(CompoundFormat delegate) {
+        this.delegate = delegate;
+    }
 
     @Override
     public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException {
-        return Codec.getDefault().compoundFormat().getCompoundReader(dir, si, context);
+        return delegate.getCompoundReader(dir, si, context);
     }
 
     @Override
     public void write(Directory dir, SegmentInfo si, IOContext context) throws IOException {
         for (KNNEngine knnEngine : KNNEngine.values()) {
             writeEngineFiles(dir, si, context, knnEngine.getExtension());
         }
-        Codec.getDefault().compoundFormat().write(dir, si, context);
+        delegate.write(dir, si, context);
     }
 
-    private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension)
-        throws IOException {
+    private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) throws IOException {
         /*
         * If engine file present, remove it from the compounding file list to avoid header/footer checks
         * and create a new compounding file format with extension engine + c.
         */
-        Set<String> engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension))
-            .collect(Collectors.toSet());
+        Set<String> engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)).collect(Collectors.toSet());
 
        Set<String> segmentFiles = new HashSet<>(si.files());
 
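A note on the design choice above: `Codec.getDefault()` resolves a JVM-wide default that need not match the codec the segment is actually being written with, so delegating through a constructor-injected `CompoundFormat` ties the fallback behavior to the owning codec rather than to global state.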
src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
@@ -67,108 +67,117 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
     @Override
     public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
         delegatee.addBinaryField(field, valuesProducer);
-        addKNNBinaryField(field, valuesProducer);
+        if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
+            addKNNBinaryField(field, valuesProducer);
+        }
     }
 
     public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
-        KNNCounter.GRAPH_INDEX_REQUESTS.increment();
-        if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
 
-            // Get values to be indexed
-            BinaryDocValues values = valuesProducer.getBinary(field);
-            KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
-            if (pair.vectors.length == 0 || pair.docs.length == 0) {
-                logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
-                return;
-            }
+        // Get values to be indexed
+        BinaryDocValues values = valuesProducer.getBinary(field);
+        KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
+        if (pair.vectors.length == 0 || pair.docs.length == 0) {
+            logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
+            return;
+        }
 
-            // Create library index either from model or from scratch
-            String engineFileName;
-            String indexPath;
-            String tmpEngineFileName;
+        // Increment counter for number of graph index requests
+        KNNCounter.GRAPH_INDEX_REQUESTS.increment();
 
-            if (field.attributes().containsKey(MODEL_ID)) {
+        // Create library index either from model or from scratch
+        String engineFileName;
+        String indexPath;
+        String tmpEngineFileName;
 
-                String modelId = field.attributes().get(MODEL_ID);
-                Model model = ModelCache.getInstance().get(modelId);
+        if (field.attributes().containsKey(MODEL_ID)) {
 
-                KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();
+            String modelId = field.attributes().get(MODEL_ID);
+            Model model = ModelCache.getInstance().get(modelId);
 
-                engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
-                    field.name, knnEngine.getExtension());
-                indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
-                    engineFileName).toString();
-                tmpEngineFileName = engineFileName + TEMP_SUFFIX;
-                String tempIndexPath = indexPath + TEMP_SUFFIX;
+            KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();
 
-                if (model.getModelBlob() == null) {
-                    throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
-                }
+            engineFileName = buildEngineFileName(
+                state.segmentInfo.name,
+                knnEngine.getLatestBuildVersion(),
+                field.name,
+                knnEngine.getExtension()
+            );
+            indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
+                .toString();
+            tmpEngineFileName = engineFileName + TEMP_SUFFIX;
+            String tempIndexPath = indexPath + TEMP_SUFFIX;
 
-                createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
-            } else {
+            if (model.getModelBlob() == null) {
+                throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
+            }
 
-                // Get engine to be used for indexing
-                String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
-                KNNEngine knnEngine = KNNEngine.getEngine(engineName);
+            createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
+        } else {
 
-                engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
-                    field.name, knnEngine.getExtension());
-                indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
-                    engineFileName).toString();
-                tmpEngineFileName = engineFileName + TEMP_SUFFIX;
-                String tempIndexPath = indexPath + TEMP_SUFFIX;
+            // Get engine to be used for indexing
+            String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
+            KNNEngine knnEngine = KNNEngine.getEngine(engineName);
 
-                createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
-            }
+            engineFileName = buildEngineFileName(
+                state.segmentInfo.name,
+                knnEngine.getLatestBuildVersion(),
+                field.name,
+                knnEngine.getExtension()
+            );
+            indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
+                .toString();
+            tmpEngineFileName = engineFileName + TEMP_SUFFIX;
+            String tempIndexPath = indexPath + TEMP_SUFFIX;
 
-            /*
-             * Adds Footer to the serialized graph
-             * 1. Copies the serialized graph to new file.
-             * 2. Adds Footer to the new file.
-             *
-             * We had to create new file here because adding footer directly to the
-             * existing file will miss calculating checksum for the serialized graph
-             * bytes and result in index corruption issues.
-             */
-            //TODO: I think this can be refactored to avoid this copy and then write
-            // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
-            try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
-                IndexOutput os = state.directory.createOutput(engineFileName, state.context)) {
-                os.copyBytes(is, is.length());
-                CodecUtil.writeFooter(os);
-            } catch (Exception ex) {
-                KNNCounter.GRAPH_INDEX_ERRORS.increment();
-                throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
-            } finally {
-                IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
-            }
+            createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
         }
 
+        /*
+         * Adds Footer to the serialized graph
+         * 1. Copies the serialized graph to new file.
+         * 2. Adds Footer to the new file.
+         *
+         * We had to create new file here because adding footer directly to the
+         * existing file will miss calculating checksum for the serialized graph
+         * bytes and result in index corruption issues.
+         */
+        // TODO: I think this can be refactored to avoid this copy and then write
+        // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
+        try (
+            IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
+            IndexOutput os = state.directory.createOutput(engineFileName, state.context)
+        ) {
+            os.copyBytes(is, is.length());
+            CodecUtil.writeFooter(os);
+        } catch (Exception ex) {
+            KNNCounter.GRAPH_INDEX_ERRORS.increment();
+            throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
+        } finally {
+            IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
+        }
     }

-    private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine,
-                                            String indexPath) {
-        Map<String, Object> parameters = ImmutableMap.of(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(
-            KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
-        AccessController.doPrivileged(
-            (PrivilegedAction<Void>) () -> {
-                JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters,
-                    knnEngine.getName());
-                return null;
-            }
+    private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
+        Map<String, Object> parameters = ImmutableMap.of(
+            KNNConstants.INDEX_THREAD_QTY,
+            KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
+        );
+        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
+            JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName());
+            return null;
+        });
     }

-    private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine,
-                                           String indexPath) throws IOException {
+    private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
+        throws IOException {
         Map<String, Object> parameters = new HashMap<>();
         Map<String, String> fieldAttributes = fieldInfo.attributes();
         String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS);
 
         // parametersString will be null when legacy mapper is used
         if (parametersString == null) {
-            parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE,
-                SpaceType.DEFAULT.getValue()));
+            parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()));
 
             String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION);
             Map<String, Object> algoParams = new HashMap<>();
@@ -183,22 +192,20 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa
             parameters.put(PARAMETERS, algoParams);
         } else {
             parameters.putAll(
-                XContentFactory.xContent(XContentType.JSON).createParser(NamedXContentRegistry.EMPTY,
-                    DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString).map()
+                XContentFactory.xContent(XContentType.JSON)
+                    .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString)
+                    .map()
             );
         }
 
         // Used to determine how many threads to use when indexing
-        parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(
-            KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
+        parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
 
         // Pass the path for the nms library to save the file
-        AccessController.doPrivileged(
-            (PrivilegedAction<Void>) () -> {
-                JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName());
-                return null;
-            }
-        );
+        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
+            JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName());
+            return null;
+        });
     }
 
     /**
@@ -214,7 +221,7 @@ public void merge(MergeState mergeState) {
         assert mergeState.mergeFieldInfos != null;
         for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
             DocValuesType type = fieldInfo.getDocValuesType();
-            if (type == DocValuesType.BINARY) {
+            if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
                 addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState));
             }
         }
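The copy-then-footer dance above exists so Lucene's checksum machinery covers the serialized graph bytes. To make that concrete, here is a hedged sketch of how such a file can be verified later; the helper class and method names are illustrative, while `CodecUtil.checksumEntireFile` is the standard Lucene call:

```java
import java.io.IOException;

import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;

final class EngineFileChecks {

    // Recomputes the checksum over the entire file and compares it against the
    // value CodecUtil.writeFooter stored; throws CorruptIndexException on mismatch.
    static void verify(Directory dir, String engineFileName) throws IOException {
        try (IndexInput in = dir.openInput(engineFileName, IOContext.READONCE)) {
            CodecUtil.checksumEntireFile(in);
        }
    }
}
```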
src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java
@@ -5,11 +5,10 @@

 package org.opensearch.knn.index.codec.KNN80Codec;
 
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
 import org.apache.lucene.codecs.DocValuesConsumer;
 import org.apache.lucene.codecs.DocValuesFormat;
 import org.apache.lucene.codecs.DocValuesProducer;
+import org.apache.lucene.codecs.lucene80.Lucene80DocValuesFormat;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;

@@ -19,11 +18,20 @@
  * Encodes/Decodes per document values
  */
 public class KNN80DocValuesFormat extends DocValuesFormat {
-    private final Logger logger = LogManager.getLogger(KNN80DocValuesFormat.class);
-    private final DocValuesFormat delegate = DocValuesFormat.forName(KNN80Codec.LUCENE_80);
+    private final DocValuesFormat delegate;
 
     public KNN80DocValuesFormat() {
-        super(KNN80Codec.LUCENE_80);
+        this(new Lucene80DocValuesFormat());
     }
+
+    /**
+     * Constructor that takes delegate in order to handle non-overridden methods
+     *
+     * @param delegate DocValuesFormat to handle non-overridden methods
+     */
+    public KNN80DocValuesFormat(DocValuesFormat delegate) {
+        super(delegate.getName());
+        this.delegate = delegate;
+    }
 
     @Override
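A note on `super(delegate.getName())` in the new constructor: the format registers under the delegate's own SPI name (e.g. "Lucene80"), so segments written through this wrapper are read back by resolving that name to the stock Lucene format. Presumably by design, the wrapper only adds write-time behavior and introduces no new on-disk format name.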
src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java
@@ -5,7 +5,7 @@

 package org.opensearch.knn.index.codec.KNN80Codec;
 
-import org.opensearch.knn.index.codec.BinaryDocValuesSub;
+import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
 import org.apache.lucene.codecs.DocValuesProducer;
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.index.DocIDMerger;
@@ -22,7 +22,7 @@
  */
 class KNN80DocValuesReader extends EmptyDocValuesProducer {
 
-    private MergeState mergeState;
+    private final MergeState mergeState;
 
     KNN80DocValuesReader(MergeState mergeState) {
         this.mergeState = mergeState;