Skip to content

Commit

Permalink
check model format when uploading model (#795)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Mar 10, 2023
1 parent e547452 commit bff96ec
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import com.google.common.annotations.VisibleForTesting;
import com.google.gson.Gson;
import com.google.gson.stream.JsonReader;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
Expand Down Expand Up @@ -136,13 +138,14 @@ public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput

/**
* Download model from URL and split it into smaller chunks.
* @param modelFormat model format
* @param taskId task id
* @param modelName model name
* @param version model version
* @param url model file URL
* @param listener action listener
*/
public void downloadAndSplit(String taskId, String modelName, String version, String url, ActionListener<Map<String, Object>> listener) {
public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, ActionListener<Map<String, Object>> listener) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Path modelUploadPath = mlEngine.getUploadModelPath(taskId, modelName, version);
Expand All @@ -151,7 +154,7 @@ public void downloadAndSplit(String taskId, String modelName, String version, St
File modelZipFile = new File(modelPath);
log.debug("download model to file {}", modelZipFile.getAbsolutePath());
DownloadUtils.download(url, modelPath, new ProgressBar());
verifyModelZipFile(modelPath);
verifyModelZipFile(modelFormat, modelPath);

List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
Map<String, Object> result = new HashMap<>();
Expand All @@ -168,30 +171,40 @@ public void downloadAndSplit(String taskId, String modelName, String version, St
}
}

private void verifyModelZipFile(String modelZipFilePath) throws IOException {
boolean hasModelFile = false;
public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath) throws IOException {
boolean hasPtFile = false;
boolean hasOnnxFile = false;
boolean hasTokenizerFile = false;
try (ZipFile zipFile = new ZipFile(modelZipFilePath)) {
Enumeration zipEntries = zipFile.entries();
while (zipEntries.hasMoreElements()) {
String fileName = ((ZipEntry) zipEntries.nextElement()).getName();
if (fileName.endsWith(PYTORCH_FILE_EXTENSION) || fileName.endsWith(ONNX_FILE_EXTENSION)) {
if (hasModelFile) {
throw new IllegalArgumentException("Find multiple model files, but expected only one");
}
hasModelFile = true;
}
hasPtFile = hasModelFile(modelFormat, MLModelFormat.TORCH_SCRIPT, PYTORCH_FILE_EXTENSION, hasPtFile, fileName);
hasOnnxFile = hasModelFile(modelFormat, MLModelFormat.ONNX, ONNX_FILE_EXTENSION, hasOnnxFile, fileName);
if (fileName.equals(TOKENIZER_FILE_NAME)) {
hasTokenizerFile = true;
}
}
}
if (!hasModelFile) {
if (!hasPtFile && !hasOnnxFile) {
throw new IllegalArgumentException("Can't find model file");
}
if (!hasTokenizerFile) {
throw new IllegalArgumentException("Can't find tokenizer file");
throw new IllegalArgumentException("No tokenizer file");
}
}

/**
 * Examine one zip entry name against a candidate model file extension.
 *
 * @param modelFormat format declared by the caller for this upload
 * @param targetModelFormat format that {@code fileExtension} corresponds to
 * @param fileExtension model file extension to test (e.g. ".pt" or ".onnx")
 * @param hasModelFile whether a model file with this extension was already seen
 * @param fileName name of the current zip entry
 * @return true once a model file with this extension has been seen, otherwise the
 *         incoming {@code hasModelFile} value unchanged
 * @throws IllegalArgumentException if the entry's format contradicts the declared
 *         format, or if a second model file of the same extension is found
 */
private static boolean hasModelFile(MLModelFormat modelFormat, MLModelFormat targetModelFormat, String fileExtension, boolean hasModelFile, String fileName) {
    // Entries without this extension leave the flag untouched.
    if (!fileName.endsWith(fileExtension)) {
        return hasModelFile;
    }
    if (modelFormat != targetModelFormat) {
        throw new IllegalArgumentException("Model format is " + modelFormat + ", but find " + fileExtension + " file");
    }
    if (hasModelFile) {
        throw new IllegalArgumentException("Find multiple model files, but expected only one");
    }
    return true;
}

public void deleteFileCache(String modelId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.ml.engine.algorithms.text_embedding;

import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
Expand All @@ -15,15 +14,16 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.transport.upload.MLUploadInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.security.PrivilegedActionException;
import java.util.Map;
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
Expand All @@ -35,6 +35,7 @@ public class ModelHelperTest {
public ExpectedException exceptionRule = ExpectedException.none();

private ModelHelper modelHelper;
private MLModelFormat modelFormat;
private String modelId;
private MLEngine mlEngine;

Expand All @@ -47,14 +48,15 @@ public class ModelHelperTest {
@Before
public void setup() throws URISyntaxException {
    // Initialize the @Mock-annotated fields before each test.
    MockitoAnnotations.openMocks(this);
    // Default format for most tests; individual tests pass a different format explicitly.
    modelFormat = MLModelFormat.TORCH_SCRIPT;
    modelId = "model_id";
    // Use a per-model temp directory so test runs don't collide on disk.
    mlEngine = new MLEngine(Path.of("/tmp/test" + modelId));
    modelHelper = new ModelHelper(mlEngine);
}

@Test
public void testDownloadAndSplit_UrlFailure() {
modelHelper.downloadAndSplit(modelId, "model_name", "1", "http://testurl", actionListener);
modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", "http://testurl", actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(PrivilegedActionException.class, argumentCaptor.getValue().getClass());
Expand All @@ -63,20 +65,58 @@ public void testDownloadAndSplit_UrlFailure() {
@Test
public void testDownloadAndSplit() throws URISyntaxException {
    // A valid local zip URL should produce a non-empty result map via onResponse.
    String modelUrl = getClass().getResource("traced_small_model.zip").toURI().toString();
    modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, actionListener);
    ArgumentCaptor<Map> resultCaptor = ArgumentCaptor.forClass(Map.class);
    verify(actionListener).onResponse(resultCaptor.capture());
    Map result = resultCaptor.getValue();
    assertNotNull(result);
    assertNotEquals(0, result.size());
}

@Ignore
@Test
public void testVerifyModelZipFile() throws IOException {
    // Drop the leading "file:" scheme to obtain a plain filesystem path for ZipFile.
    String zipFilePath = getClass().getResource("traced_small_model.zip").toString().substring(5);
    modelHelper.verifyModelZipFile(modelFormat, zipFilePath);
}

@Test
public void testVerifyModelZipFile_WrongModelFormat_ONNX() throws IOException {
    // Declaring TORCH_SCRIPT while the zip contains an .onnx file must be rejected.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Model format is TORCH_SCRIPT, but find .onnx file");
    String zipFilePath = getClass().getResource("traced_small_model_wrong_onnx.zip").toString().substring(5);
    modelHelper.verifyModelZipFile(modelFormat, zipFilePath);
}

@Test
public void testVerifyModelZipFile_WrongModelFormat_TORCH_SCRIPT() throws IOException {
    // Declaring ONNX while the zip contains a .pt file must be rejected.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Model format is ONNX, but find .pt file");
    String zipFilePath = getClass().getResource("traced_small_model_wrong_onnx.zip").toString().substring(5);
    modelHelper.verifyModelZipFile(MLModelFormat.ONNX, zipFilePath);
}

@Test
public void testVerifyModelZipFile_DuplicateModelFile() throws IOException {
    // Two model files of the same format in one zip must be rejected.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Find multiple model files, but expected only one");
    String zipFilePath = getClass().getResource("traced_small_model_duplicate_pt.zip").toString().substring(5);
    modelHelper.verifyModelZipFile(modelFormat, zipFilePath);
}

@Test
public void testVerifyModelZipFile_MissingTokenizer() throws IOException {
    // A zip with a model file but no tokenizer file must be rejected.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("No tokenizer file");
    String zipFilePath = getClass().getResource("traced_small_model_missing_tokenizer.zip").toString().substring(5);
    modelHelper.verifyModelZipFile(modelFormat, zipFilePath);
}

@Test
public void testDownloadPrebuiltModelConfig_WrongModelName() {
String taskId = "test_task_id";
MLUploadInput unloadInput = MLUploadInput.builder()
.modelName("test_model_name")
.version("1.0.0")
.version("1.0.1")
.modelFormat(modelFormat)
.loadModel(false)
.modelNodeIds(new String[]{"node_id1"})
.build();
Expand All @@ -86,13 +126,13 @@ public void testDownloadPrebuiltModelConfig_WrongModelName() {
assertEquals(PrivilegedActionException.class, argumentCaptor.getValue().getClass());
}

@Ignore
@Test
public void testDownloadPrebuiltModelConfig() {
String taskId = "test_task_id";
MLUploadInput unloadInput = MLUploadInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("1.0.0")
.version("1.0.1")
.modelFormat(modelFormat)
.loadModel(false)
.modelNodeIds(new String[]{"node_id1"})
.build();
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
128 changes: 68 additions & 60 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -266,69 +266,77 @@ private void uploadModel(
String version,
String modelId
) {
modelHelper.downloadAndSplit(modelId, modelName, version, uploadInput.getUrl(), ActionListener.wrap(result -> {
Long modelSizeInBytes = (Long) result.get(MODEL_SIZE_IN_BYTES);
if (modelSizeInBytes >= MODEL_FILE_SIZE_LIMIT) {
throw new MLException("Model file size exceeds the limit of 4GB: " + modelSizeInBytes);
}
List<String> chunkFiles = (List<String>) result.get(CHUNK_FILES);
String hashValue = (String) result.get(MODEL_FILE_HASH);
Semaphore semaphore = new Semaphore(1);
AtomicInteger uploaded = new AtomicInteger(0);
AtomicBoolean failedToUploadChunk = new AtomicBoolean(false);
// upload chunks
for (String name : chunkFiles) {
semaphore.tryAcquire(10, TimeUnit.SECONDS);
if (failedToUploadChunk.get()) {
throw new MLException("Failed to save model chunk");
}
File file = new File(name);
byte[] bytes = Files.toByteArray(file);
int chunkNum = Integer.parseInt(file.getName());
Instant now = Instant.now();
MLModel mlModel = MLModel
.builder()
.modelId(modelId)
.name(modelName)
.algorithm(functionName)
.version(version)
.modelFormat(uploadInput.getModelFormat())
.chunkNumber(chunkNum)
.totalChunks(chunkFiles.size())
.content(Base64.getEncoder().encodeToString(bytes))
.createdTime(now)
.lastUpdateTime(now)
.build();
IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
String chunkId = getModelChunkId(modelId, chunkNum);
indexRequest.id(chunkId);
indexRequest.source(mlModel.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(indexRequest, ActionListener.wrap(r -> {
uploaded.getAndIncrement();
if (uploaded.get() == chunkFiles.size()) {
updateModelUploadStateAsDone(uploadInput, taskId, modelId, modelSizeInBytes, chunkFiles, hashValue);
} else {
deleteFileQuietly(file);
modelHelper
.downloadAndSplit(
uploadInput.getModelFormat(),
modelId,
modelName,
version,
uploadInput.getUrl(),
ActionListener.wrap(result -> {
Long modelSizeInBytes = (Long) result.get(MODEL_SIZE_IN_BYTES);
if (modelSizeInBytes >= MODEL_FILE_SIZE_LIMIT) {
throw new MLException("Model file size exceeds the limit of 4GB: " + modelSizeInBytes);
}
List<String> chunkFiles = (List<String>) result.get(CHUNK_FILES);
String hashValue = (String) result.get(MODEL_FILE_HASH);
Semaphore semaphore = new Semaphore(1);
AtomicInteger uploaded = new AtomicInteger(0);
AtomicBoolean failedToUploadChunk = new AtomicBoolean(false);
// upload chunks
for (String name : chunkFiles) {
semaphore.tryAcquire(10, TimeUnit.SECONDS);
if (failedToUploadChunk.get()) {
throw new MLException("Failed to save model chunk");
}
File file = new File(name);
byte[] bytes = Files.toByteArray(file);
int chunkNum = Integer.parseInt(file.getName());
Instant now = Instant.now();
MLModel mlModel = MLModel
.builder()
.modelId(modelId)
.name(modelName)
.algorithm(functionName)
.version(version)
.modelFormat(uploadInput.getModelFormat())
.chunkNumber(chunkNum)
.totalChunks(chunkFiles.size())
.content(Base64.getEncoder().encodeToString(bytes))
.createdTime(now)
.lastUpdateTime(now)
.build();
IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
String chunkId = getModelChunkId(modelId, chunkNum);
indexRequest.id(chunkId);
indexRequest.source(mlModel.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(indexRequest, ActionListener.wrap(r -> {
uploaded.getAndIncrement();
if (uploaded.get() == chunkFiles.size()) {
updateModelUploadStateAsDone(uploadInput, taskId, modelId, modelSizeInBytes, chunkFiles, hashValue);
} else {
deleteFileQuietly(file);
}
semaphore.release();
}, e -> {
log.error("Failed to index model chunk " + chunkId, e);
failedToUploadChunk.set(true);
handleException(functionName, taskId, e);
deleteFileQuietly(file);
// remove model doc as failed to upload model
deleteModel(modelId);
semaphore.release();
deleteFileQuietly(mlEngine.getUploadModelPath(modelId));
}));
}
semaphore.release();
}, e -> {
log.error("Failed to index model chunk " + chunkId, e);
failedToUploadChunk.set(true);
handleException(functionName, taskId, e);
deleteFileQuietly(file);
// remove model doc as failed to upload model
deleteModel(modelId);
semaphore.release();
log.error("Failed to index chunk file", e);
deleteFileQuietly(mlEngine.getUploadModelPath(modelId));
}));
}
}, e -> {
log.error("Failed to index chunk file", e);
deleteFileQuietly(mlEngine.getUploadModelPath(modelId));
deleteModel(modelId);
handleException(functionName, taskId, e);
}));
deleteModel(modelId);
handleException(functionName, taskId, e);
})
);
}

private void uploadPrebuiltModel(MLUploadInput uploadInput, MLTask mlTask) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.stats.MLNodeLevelStat;

import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -167,7 +166,7 @@ private void dispatchTaskWithLeastLoad(ActionListener<DiscoveryNode> listener) {
private void dispatchTaskWithRoundRobin(ActionListener<DiscoveryNode> listener) {
DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes();
if (eligibleNodes == null || eligibleNodes.length == 0) {
throw new MLResourceNotFoundException(
throw new IllegalArgumentException(
"No eligible node found to execute this request. It's best practice to"
+ " provision ML nodes to serve your models. You can disable this setting to serve the model on your data"
+ " node for development purposes by disabling the \"plugins.ml_commons.only_run_on_ml_node\" "
Expand Down
Loading

0 comments on commit bff96ec

Please sign in to comment.