Skip to content

Commit

Permalink
support uploading prebuilt model (#655)
Browse files Browse the repository at this point in the history
  • Loading branch information
ylwu-amzn authored Dec 30, 2022
1 parent 26435de commit 6fb7970
Show file tree
Hide file tree
Showing 26 changed files with 674 additions and 273 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ public enum MLModelState {
LOADING,
LOADED,
PARTIALLY_LOADED,
UNLOADED;
UNLOADED,
LOAD_FAILED;

public static MLModelState from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,13 @@ public MLUploadInput(FunctionName functionName,
if (version == null) {
throw new IllegalArgumentException("model version is null");
}
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
if (modelConfig == null) {
throw new IllegalArgumentException("model config is null");
}
if (url == null) {
throw new IllegalArgumentException("model file url is null");
if (url != null) {
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
if (modelConfig == null) {
throw new IllegalArgumentException("model config is null");
}
}
this.modelName = modelName;
this.version = version;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.opensearch.ml.common.transport.upload;

import org.hamcrest.Matcher;
import org.junit.Rule;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -14,22 +13,17 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.*;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.TestHelper;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.search.SearchModule;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;


import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.function.Consumer;

import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify;

@RunWith(MockitoJUnitRunner.class)
public class MLUploadInputTest {
Expand Down Expand Up @@ -108,6 +102,7 @@ public void constructor_NullModelFormat() {
.modelName(modelName)
.version(version)
.modelFormat(null)
.url(url)
.build();
}

Expand All @@ -121,20 +116,7 @@ public void constructor_NullModelConfig() {
.version(version)
.modelFormat(MLModelFormat.ONNX)
.modelConfig(null)
.build();
}

@Test
public void constructor_NullModelFileUrl() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model file url is null");
MLUploadInput.builder()
.functionName(functionName)
.modelName(modelName)
.version(version)
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.url(null)
.url(url)
.build();
}

Expand Down
4 changes: 2 additions & 2 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ dependencies {
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '4.4.0'
implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre'
implementation platform("ai.djl:bom:0.20.0")
implementation platform("ai.djl:bom:0.19.0")
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo'
implementation group: 'ai.djl', name: 'api'
implementation group: 'ai.djl.huggingface', name: 'tokenizers'
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.20.0") {
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.19.0") {
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
}
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.13.1"
Expand Down
10 changes: 10 additions & 0 deletions ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.ml.common.output.Output;

import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

/**
Expand All @@ -33,6 +34,15 @@ public MLEngine(Path opensearchDataFolder) {
djlModelsCachePath = djlCachePath.resolve("models_cache");
}

public String getCIPrebuiltModelConfigPath(String modelName, String version) {
return String.format("https://ci.opensearch.org/ci/dbc/models/ml-models/%s/%s/config.json", modelName, version, Locale.ROOT);
}

public String getCIPrebuiltModelPath(String modelName, String version) {
int index = modelName.lastIndexOf("/") + 1;
return String.format("https://ci.opensearch.org/ci/dbc/models/ml-models/%s/%s/%s.zip", modelName, version, modelName.substring(index), Locale.ROOT);
}

public Path getUploadModelPath(String modelId, String modelName, String version) {
return getUploadModelPath(modelId).resolve(version).resolve(modelName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import com.google.gson.Gson;
import com.google.gson.stream.JsonReader;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.upload.MLUploadInput;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Path;
import java.security.AccessController;
Expand Down Expand Up @@ -38,22 +45,105 @@ public class ModelHelper {
public static final String PYTORCH_ENGINE = "PyTorch";
public static final String ONNX_ENGINE = "OnnxRuntime";
private final MLEngine mlEngine;
private Gson gson;

public ModelHelper(MLEngine mlEngine) {
this.mlEngine = mlEngine;
gson = new Gson();
}

public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput, ActionListener<MLUploadInput> listener) {
String modelName = uploadInput.getModelName();
String version = uploadInput.getVersion();
boolean loadModel = uploadInput.isLoadModel();
String[] modelNodeIds = uploadInput.getModelNodeIds();
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {

Path modelUploadPath = mlEngine.getUploadModelPath(taskId, modelName, version);
String configCacheFilePath = modelUploadPath.resolve("config.json").toString();

String configFileUrl = mlEngine.getCIPrebuiltModelConfigPath(modelName, version);
String modelZipFileUrl = mlEngine.getCIPrebuiltModelPath(modelName, version);
DownloadUtils.download(configFileUrl, configCacheFilePath, new ProgressBar());

Map<?, ?> config = null;
try (JsonReader reader = new JsonReader(new FileReader(configCacheFilePath))) {
config = gson.fromJson(reader, Map.class);
}

if (config == null) {
listener.onFailure(new IllegalArgumentException("model config not found"));
return null;
}

MLUploadInput.MLUploadInputBuilder builder = MLUploadInput.builder();

builder.modelName(modelName).version(version).url(modelZipFileUrl).loadModel(loadModel).modelNodeIds(modelNodeIds);
config.entrySet().forEach(entry -> {
switch (entry.getKey().toString()) {
case MLUploadInput.MODEL_FORMAT_FIELD:
builder.modelFormat(MLModelFormat.from(entry.getValue().toString()));
break;
case MLUploadInput.MODEL_CONFIG_FIELD:
TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder configBuilder = TextEmbeddingModelConfig.builder();
Map<?, ?> configMap = (Map<?, ?>) entry.getValue();
for (Map.Entry<?, ?> configEntry : configMap.entrySet()) {
switch (configEntry.getKey().toString()) {
case MLModelConfig.MODEL_TYPE_FIELD:
configBuilder.modelType(configEntry.getValue().toString());
break;
case MLModelConfig.ALL_CONFIG_FIELD:
configBuilder.allConfig(configEntry.getValue().toString());
break;
case TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD:
configBuilder.embeddingDimension(((Double)configEntry.getValue()).intValue());
break;
case TextEmbeddingModelConfig.FRAMEWORK_TYPE_FIELD:
configBuilder.frameworkType(TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString()));
break;
case TextEmbeddingModelConfig.POOLING_METHOD_FIELD:
configBuilder.poolingMethod(TextEmbeddingModelConfig.PoolingMethod.from(configEntry.getValue().toString()));
break;
case TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD:
configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString()));
break;
case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD:
configBuilder.modelMaxLength(((Double)configEntry.getValue()).intValue());
break;
default:
break;
}
}
builder.modelConfig(configBuilder.build());
break;
default:
break;
}
});
MLUploadInput mlUploadInput = builder.build();
listener.onResponse(mlUploadInput);
return null;
});
} catch (Exception e) {
listener.onFailure(e);
} finally {
deleteFileQuietly(mlEngine.getUploadModelPath(taskId));
}
}

/**
* Download model from URL and split it into smaller chunks.
* @param modelId model id
* @param taskId task id
* @param modelName model name
* @param version model version
* @param url model file URL
* @param listener action listener
*/
public void downloadAndSplit(String modelId, String modelName, String version, String url, ActionListener<Map<String, Object>> listener) {
public void downloadAndSplit(String taskId, String modelName, String version, String url, ActionListener<Map<String, Object>> listener) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Path modelUploadPath = mlEngine.getUploadModelPath(modelId, modelName, version);
Path modelUploadPath = mlEngine.getUploadModelPath(taskId, modelName, version);
String modelPath = modelUploadPath +".zip";
Path modelPartsPath = modelUploadPath.resolve("chunks");
File modelZipFile = new File(modelPath);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.transport.upload.MLUploadInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;

Expand All @@ -38,6 +40,9 @@ public class ModelHelperTest {
@Mock
ActionListener<Map<String, Object>> actionListener;

@Mock
ActionListener<MLUploadInput> uploadInputListener;

@Before
public void setup() throws URISyntaxException {
MockitoAnnotations.openMocks(this);
Expand All @@ -63,4 +68,37 @@ public void testDownloadAndSplit() throws URISyntaxException {
assertNotNull(argumentCaptor.getValue());
assertNotEquals(0, argumentCaptor.getValue().size());
}

@Test
public void testDownloadPrebuiltModelConfig_WrongModelName() {
String taskId = "test_task_id";
MLUploadInput unloadInput = MLUploadInput.builder()
.modelName("test_model_name")
.version("1.0.0")
.loadModel(false)
.modelNodeIds(new String[]{"node_id1"})
.build();
modelHelper.downloadPrebuiltModelConfig(taskId, unloadInput, uploadInputListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(uploadInputListener).onFailure(argumentCaptor.capture());
assertEquals(PrivilegedActionException.class, argumentCaptor.getValue().getClass());
}

@Test
public void testDownloadPrebuiltModelConfig() {
String taskId = "test_task_id";
MLUploadInput unloadInput = MLUploadInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("1.0.0")
.loadModel(false)
.modelNodeIds(new String[]{"node_id1"})
.build();
modelHelper.downloadPrebuiltModelConfig(taskId, unloadInput, uploadInputListener);
ArgumentCaptor<MLUploadInput> argumentCaptor = ArgumentCaptor.forClass(MLUploadInput.class);
verify(uploadInputListener).onResponse(argumentCaptor.capture());
assertNotNull(argumentCaptor.getValue());
MLModelConfig modelConfig = argumentCaptor.getValue().getModelConfig();
assertNotNull(modelConfig);
assertEquals("mpnet", modelConfig.getModelType());
}
}
2 changes: 1 addition & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dependencies {
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.10'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.9.1'
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
implementation group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1'
implementation "org.apache.logging.log4j:log4j-slf4j-impl:2.19.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.action.forward;

import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString;

import java.time.Instant;
import java.util.Arrays;
Expand Down Expand Up @@ -107,27 +109,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put(MLTask.STATE_FIELD, taskState);
if (mlTaskCache.hasError()) {
builder.put(MLTask.ERROR_FIELD, mlTaskCache.getErrors().toString());
builder.put(MLTask.ERROR_FIELD, toJsonString(mlTaskCache.getErrors()));
}
mlTaskManager.updateMLTask(taskId, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);

MLModelState modelState;
if (!mlTaskCache.allNodeFailed()) {
MLModelState modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_LOADED : MLModelState.LOADED;
log.info("load model done with state: {}, model id: {}", modelState, modelId);
mlModelManager
.updateModel(
modelId,
ImmutableMap
.of(
MLModel.MODEL_STATE_FIELD,
modelState,
MLModel.LAST_LOADED_TIME_FIELD,
Instant.now().toEpochMilli()
)
);
modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_LOADED : MLModelState.LOADED;
} else {
modelState = MLModelState.LOAD_FAILED;
log.error("load model failed on all nodes, model id: {}", modelId);
}
log.info("load model done with state: {}, model id: {}", modelState, modelId);
mlModelManager
.updateModel(
modelId,
ImmutableMap
.of(MLModel.MODEL_STATE_FIELD, modelState, MLModel.LAST_LOADED_TIME_FIELD, Instant.now().toEpochMilli())
);
}
listener.onResponse(new MLForwardResponse("ok", null));
break;
Expand All @@ -139,15 +138,15 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
throw new IllegalArgumentException("unsupported request type");
}
} catch (Exception e) {
log.error("Failed to execute forward action", e);
logException("Failed to execute forward action " + forwardInput.getRequestType(), e, log);
listener.onFailure(e);
}
}

private void syncModelWorkerNodes(String modelId) {
DiscoveryNode[] allNodes = nodeHelper.getAllNodes();
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
if (allNodes.length > 1 && workerNodes.length > 0) {
if (allNodes.length > 1 && workerNodes != null && workerNodes.length > 0) {
log.debug("Sync to other nodes about worker nodes of model {}: {}", modelId, Arrays.toString(workerNodes));
MLSyncUpInput syncUpInput = MLSyncUpInput.builder().addedWorkerNodes(ImmutableMap.of(modelId, workerNodes)).build();
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput);
Expand Down
Loading

0 comments on commit 6fb7970

Please sign in to comment.