Skip to content

Commit

Permalink
Spotless applied
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Mar 17, 2022
1 parent e03b574 commit 284eb2a
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 246 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ public class KNN80CompoundFormat extends CompoundFormat {

private final CompoundFormat delegate;


public KNN80CompoundFormat() {
this.delegate = new Lucene50CompoundFormat();
}
Expand All @@ -53,14 +52,12 @@ public void write(Directory dir, SegmentInfo si, IOContext context) throws IOExc
delegate.write(dir, si, context);
}

private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension)
throws IOException {
private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) throws IOException {
/*
* If engine file present, remove it from the compounding file list to avoid header/footer checks
* and create a new compounding file format with extension engine + c.
*/
Set<String> engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension))
.collect(Collectors.toSet());
Set<String> engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)).collect(Collectors.toSet());

Set<String> segmentFiles = new HashSet<>(si.files());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,14 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)

KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getLatestBuildVersion(),
field.name,
knnEngine.getExtension()
);
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
.toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

Expand All @@ -115,10 +119,14 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getLatestBuildVersion(),
field.name,
knnEngine.getExtension()
);
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName)
.toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

Expand All @@ -134,10 +142,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
//TODO: I think this can be refactored to avoid this copy and then write
// TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)) {
try (
IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)
) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
Expand All @@ -148,29 +158,26 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
}
}

private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine,
String indexPath) {
Map<String, Object> parameters = ImmutableMap.of(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(
KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
AccessController.doPrivileged(
(PrivilegedAction<Void>) () -> {
JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters,
knnEngine.getName());
return null;
}
private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.INDEX_THREAD_QTY,
KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
);
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName());
return null;
});
}

private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine,
String indexPath) throws IOException {
private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
throws IOException {
Map<String, Object> parameters = new HashMap<>();
Map<String, String> fieldAttributes = fieldInfo.attributes();
String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS);

// parametersString will be null when legacy mapper is used
if (parametersString == null) {
parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE,
SpaceType.DEFAULT.getValue()));
parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()));

String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION);
Map<String, Object> algoParams = new HashMap<>();
Expand All @@ -185,22 +192,20 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa
parameters.put(PARAMETERS, algoParams);
} else {
parameters.putAll(
XContentFactory.xContent(XContentType.JSON).createParser(NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString).map()
XContentFactory.xContent(XContentType.JSON)
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString)
.map()
);
}

// Used to determine how many threads to use when indexing
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(
KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));

// Pass the path for the nms library to save the file
AccessController.doPrivileged(
(PrivilegedAction<Void>) () -> {
JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName());
return null;
}
);
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName());
return null;
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public KNN87Codec(Codec delegate) {
super(KNN_87, delegate);
// Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80
// DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses
this.docValuesFormat = new KNN80DocValuesFormat(delegate.docValuesFormat());
this.docValuesFormat = new KNN80DocValuesFormat(delegate.docValuesFormat());
this.compoundFormat = new KNN80CompoundFormat(delegate.compoundFormat());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ public int nextDoc() throws IOException {
public BinaryDocValues getValues() {
return values;
}
}
}
89 changes: 49 additions & 40 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ public class KNNPlugin extends Plugin implements MapperPlugin, SearchPlugin, Act

@Override
public Map<String, Mapper.TypeParser> getMappers() {
return Collections.singletonMap(KNNVectorFieldMapper.CONTENT_TYPE, new KNNVectorFieldMapper.TypeParser(
ModelDao.OpenSearchKNNModelDao::getInstance));
return Collections.singletonMap(
KNNVectorFieldMapper.CONTENT_TYPE,
new KNNVectorFieldMapper.TypeParser(ModelDao.OpenSearchKNNModelDao::getInstance)
);
}

@Override
Expand All @@ -149,12 +151,19 @@ public List<QuerySpec<?>> getQueries() {
}

@Override
public Collection<Object> createComponents(Client client, ClusterService clusterService, ThreadPool threadPool,
ResourceWatcherService resourceWatcherService, ScriptService scriptService,
NamedXContentRegistry xContentRegistry, Environment environment,
NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier) {
public Collection<Object> createComponents(
Client client,
ClusterService clusterService,
ThreadPool threadPool,
ResourceWatcherService resourceWatcherService,
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
Environment environment,
NodeEnvironment nodeEnvironment,
NamedWriteableRegistry namedWriteableRegistry,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
this.clusterService = clusterService;

// Initialize Native Memory loading strategies
Expand All @@ -179,25 +188,35 @@ public List<Setting<?>> getSettings() {
return KNNSettings.state().getSettings();
}

public List<RestHandler> getRestHandlers(Settings settings,
RestController restController,
ClusterSettings clusterSettings,
IndexScopedSettings indexScopedSettings,
SettingsFilter settingsFilter,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<DiscoveryNodes> nodesInCluster) {
public List<RestHandler> getRestHandlers(
Settings settings,
RestController restController,
ClusterSettings clusterSettings,
IndexScopedSettings indexScopedSettings,
SettingsFilter settingsFilter,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<DiscoveryNodes> nodesInCluster
) {

RestKNNStatsHandler restKNNStatsHandler = new RestKNNStatsHandler(settings, restController, knnStats);
RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(settings, restController, clusterService,
indexNameExpressionResolver);
RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(
settings,
restController,
clusterService,
indexNameExpressionResolver
);
RestGetModelHandler restGetModelHandler = new RestGetModelHandler();
RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler();
RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler();
RestSearchModelHandler restSearchModelHandler = new RestSearchModelHandler();

return ImmutableList.of(
restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler, restDeleteModelHandler,
restTrainModelHandler, restSearchModelHandler
restKNNStatsHandler,
restKNNWarmupHandler,
restGetModelHandler,
restDeleteModelHandler,
restTrainModelHandler,
restSearchModelHandler
);
}

Expand All @@ -207,17 +226,16 @@ public List<RestHandler> getRestHandlers(Settings settings,
@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return Arrays.asList(
new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class),
new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class),
new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class),
new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE,
TrainingJobRouteDecisionInfoTransportAction.class),
new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class),
new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class),
new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class),
new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class),
new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class),
new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class)
new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class),
new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class),
new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class),
new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, TrainingJobRouteDecisionInfoTransportAction.class),
new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class),
new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class),
new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class),
new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class),
new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class),
new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class)
);
}

Expand Down Expand Up @@ -273,15 +291,6 @@ public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return ImmutableList.of(
new FixedExecutorBuilder(
settings,
TRAIN_THREAD_POOL,
1,
1,
KNN_THREAD_POOL_PREFIX,
false
)
);
return ImmutableList.of(new FixedExecutorBuilder(settings, TRAIN_THREAD_POOL, 1, 1, KNN_THREAD_POOL_PREFIX, false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

public class KNN80CompoundFormatTests extends KNNTestCase {


private static Directory directory;
private static Codec codec;

Expand All @@ -45,7 +44,6 @@ public static void closeStaticVariables() throws IOException {
directory.close();
}


public void testGetCompoundReader() throws IOException {
CompoundDirectory dir = mock(CompoundDirectory.class);
CompoundFormat delegate = mock(CompoundFormat.class);
Expand All @@ -59,16 +57,15 @@ public void testWrite() throws IOException {
String segmentName = "_test";

Set<String> segmentFiles = Sets.newHashSet(
String.format("%s_nmslib1%s", segmentName, KNNEngine.NMSLIB.getExtension()),
String.format("%s_nmslib2%s", segmentName, KNNEngine.NMSLIB.getExtension()),
String.format("%s_nmslib3%s", segmentName, KNNEngine.NMSLIB.getExtension()),
String.format("%s_faiss1%s", segmentName, KNNEngine.FAISS.getExtension()),
String.format("%s_faiss2%s", segmentName, KNNEngine.FAISS.getExtension()),
String.format("%s_faiss3%s", segmentName, KNNEngine.FAISS.getExtension())
String.format("%s_nmslib1%s", segmentName, KNNEngine.NMSLIB.getExtension()),
String.format("%s_nmslib2%s", segmentName, KNNEngine.NMSLIB.getExtension()),
String.format("%s_nmslib3%s", segmentName, KNNEngine.NMSLIB.getExtension()),
String.format("%s_faiss1%s", segmentName, KNNEngine.FAISS.getExtension()),
String.format("%s_faiss2%s", segmentName, KNNEngine.FAISS.getExtension()),
String.format("%s_faiss3%s", segmentName, KNNEngine.FAISS.getExtension())
);

SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName,
segmentFiles.size(), codec).build();
SegmentInfo segmentInfo = KNNCodecTestUtil.SegmentInfoBuilder.builder(directory, segmentName, segmentFiles.size(), codec).build();

for (String name : segmentFiles) {
IndexOutput indexOutput = directory.createOutput(name, IOContext.DEFAULT);
Expand Down
Loading

0 comments on commit 284eb2a

Please sign in to comment.