Skip to content

Commit

Permalink
enable prebuilt model (opensearch-project#729)
Browse files Browse the repository at this point in the history
* enable prebuilt model

Signed-off-by: Yaliang Wu <[email protected]>

* address comments

Signed-off-by: Yaliang Wu <[email protected]>

* add unit test for prebuilt model url

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Mar 4, 2023
1 parent b65b990 commit bfdc4f6
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 39 deletions.
4 changes: 0 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ ext {
noticeFile = rootProject.file('NOTICE.txt')
}

dependencies {
implementation 'junit:junit:${versions.junit}'
}

// updateVersion: Task to auto increment to the next development iteration
task updateVersion {
onlyIf { System.getProperty('newVersion') }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public String getName() {

public static PoolingMode from(String value) {
try {
return PoolingMode.valueOf(value);
return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong pooling method");
}
Expand All @@ -197,7 +197,7 @@ public enum FrameworkType {

public static FrameworkType from(String value) {
try {
return FrameworkType.valueOf(value);
return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong framework type");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,12 @@ public MLUploadInput(FunctionName functionName,
if (version == null) {
throw new IllegalArgumentException("model version is null");
}
//TODO: enable prebuilt model in 2.6
// if (url != null) {
// if (modelFormat == null) {
// throw new IllegalArgumentException("model format is null");
// }
// if (modelConfig == null) {
// throw new IllegalArgumentException("model config is null");
// }
// }
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
if (modelConfig == null) {
if (url != null && modelConfig == null) {
throw new IllegalArgumentException("model config is null");
}
if (url == null) {
throw new IllegalArgumentException("model file url is null");
}
this.modelName = modelName;
this.version = version;
this.description = description;
Expand Down
18 changes: 13 additions & 5 deletions ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;

Expand All @@ -25,6 +26,8 @@
*/
public class MLEngine {

private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";

@Getter
private final Path djlCachePath;
private final Path djlModelsCachePath;
Expand All @@ -34,13 +37,18 @@ public MLEngine(Path opensearchDataFolder) {
djlModelsCachePath = djlCachePath.resolve("models_cache");
}

public String getCIPrebuiltModelConfigPath(String modelName, String version) {
return String.format("https://ci.opensearch.org/ci/dbc/models/ml-models/%s/%s/config.json", modelName, version, Locale.ROOT);
public String getPrebuiltModelConfigPath(String modelName, String version, MLModelFormat modelFormat) {
String format = modelFormat.name().toLowerCase(Locale.ROOT);
return String.format("%s/%s/%s/%s/config.json", MODEL_REPO, modelName, version, format, Locale.ROOT);
}

public String getCIPrebuiltModelPath(String modelName, String version) {
int index = modelName.lastIndexOf("/") + 1;
return String.format("https://ci.opensearch.org/ci/dbc/models/ml-models/%s/%s/%s.zip", modelName, version, modelName.substring(index), Locale.ROOT);
public String getPrebuiltModelPath(String modelName, String version, MLModelFormat modelFormat) {
int index = modelName.indexOf("/") + 1;
// /huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.0/onnx/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.0-torch_script.zip
// /huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.0/onnx/config.json
String format = modelFormat.name().toLowerCase(Locale.ROOT);
String modelZipFileName = modelName.substring(index).replace("/", "_") + "-" + version + "-" + format;
return String.format("%s/%s/%s/%s/%s.zip", MODEL_REPO, modelName, version, format, modelZipFileName, Locale.ROOT);
}

public Path getUploadModelPath(String modelId, String modelName, String version) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
Expand Down Expand Up @@ -55,6 +56,7 @@ public ModelHelper(MLEngine mlEngine) {
public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput, ActionListener<MLUploadInput> listener) {
String modelName = uploadInput.getModelName();
String version = uploadInput.getVersion();
MLModelFormat modelFormat = uploadInput.getModelFormat();
boolean loadModel = uploadInput.isLoadModel();
String[] modelNodeIds = uploadInput.getModelNodeIds();
try {
Expand All @@ -63,8 +65,8 @@ public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput
Path modelUploadPath = mlEngine.getUploadModelPath(taskId, modelName, version);
String configCacheFilePath = modelUploadPath.resolve("config.json").toString();

String configFileUrl = mlEngine.getCIPrebuiltModelConfigPath(modelName, version);
String modelZipFileUrl = mlEngine.getCIPrebuiltModelPath(modelName, version);
String configFileUrl = mlEngine.getPrebuiltModelConfigPath(modelName, version, modelFormat);
String modelZipFileUrl = mlEngine.getPrebuiltModelPath(modelName, version, modelFormat);
DownloadUtils.download(configFileUrl, configCacheFilePath, new ProgressBar());

Map<?, ?> config = null;
Expand Down Expand Up @@ -103,7 +105,7 @@ public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput
configBuilder.frameworkType(TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString()));
break;
case TextEmbeddingModelConfig.POOLING_MODE_FIELD:
configBuilder.poolingMode(TextEmbeddingModelConfig.PoolingMode.from(configEntry.getValue().toString()));
configBuilder.poolingMode(TextEmbeddingModelConfig.PoolingMode.from(configEntry.getValue().toString().toUpperCase(Locale.ROOT)));
break;
case TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD:
configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
Expand All @@ -36,6 +38,7 @@
import java.util.Arrays;
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;
Expand All @@ -51,6 +54,17 @@ public void setUp() {
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()));
}

@Test
public void testPrebuiltModelPath() {
String modelName = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b";
String version = "1.0.1";
MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT;
String prebuiltModelPath = mlEngine.getPrebuiltModelPath(modelName, version, modelFormat);
String prebuiltModelConfigPath = mlEngine.getPrebuiltModelConfigPath(modelName, version, modelFormat);
assertEquals("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.1-torch_script.zip", prebuiltModelPath);
assertEquals("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/config.json", prebuiltModelConfigPath);
}

@Test
public void predictKMeans() {
MLModel model = trainKMeansModel();
Expand All @@ -59,7 +73,7 @@ public void predictKMeans() {
Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
MLPredictionOutput output = (MLPredictionOutput)mlEngine.predict(mlInput, model);
DataFrame predictions = output.getPredictionResult();
Assert.assertEquals(10, predictions.size());
assertEquals(10, predictions.size());
predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
}

Expand All @@ -71,7 +85,7 @@ public void predictLinearRegression() {
Input mlInput = MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build();
MLPredictionOutput output = (MLPredictionOutput)mlEngine.predict(mlInput, model);
DataFrame predictions = output.getPredictionResult();
Assert.assertEquals(2, predictions.size());
assertEquals(2, predictions.size());
}


Expand All @@ -83,7 +97,7 @@ public void loadLinearRegressionModel() {
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build();
MLPredictionOutput output = (MLPredictionOutput)predictor.predict(MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build());
DataFrame predictions = output.getPredictionResult();
Assert.assertEquals(2, predictions.size());
assertEquals(2, predictions.size());
}

@Test
Expand All @@ -99,16 +113,16 @@ public void loadLinearRegressionModel_NullModel() {
@Test
public void trainKMeans() {
MLModel model = trainKMeansModel();
Assert.assertEquals(FunctionName.KMEANS.name(), model.getName());
Assert.assertEquals("1.0.0", model.getVersion());
assertEquals(FunctionName.KMEANS.name(), model.getName());
assertEquals("1.0.0", model.getVersion());
Assert.assertNotNull(model.getContent());
}

@Test
public void trainLinearRegression() {
MLModel model = trainLinearRegressionModel();
Assert.assertEquals(FunctionName.LINEAR_REGRESSION.name(), model.getName());
Assert.assertEquals("1.0.0", model.getVersion());
assertEquals(FunctionName.LINEAR_REGRESSION.name(), model.getName());
assertEquals("1.0.0", model.getVersion());
Assert.assertNotNull(model.getContent());
}

Expand Down Expand Up @@ -216,7 +230,7 @@ public void trainAndPredictWithKmeans() {
MLInputDataset inputData = new DataFrameInputDataset(dataFrame);
Input input = new MLInput(FunctionName.KMEANS, parameters, inputData);
MLPredictionOutput output = (MLPredictionOutput) mlEngine.trainAndPredict(input);
Assert.assertEquals(dataSize, output.getPredictionResult().size());
assertEquals(dataSize, output.getPredictionResult().size());
}

@Test
Expand All @@ -231,7 +245,7 @@ public void trainAndPredictWithInvalidInput() {
public void executeLocalSampleCalculator() {
Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));
LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) mlEngine.execute(input);
Assert.assertEquals(3.0, output.getResult(), 1e-5);
assertEquals(3.0, output.getResult(), 1e-5);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ public void uploadMLModel(MLUploadInput uploadInput, MLTask mlTask) {
if (uploadInput.getUrl() != null) {
uploadModelFromUrl(uploadInput, mlTask);
} else {
throw new IllegalArgumentException("model file URL is null");
// TODO: support prebuilt model later
// uploadPrebuiltModel(uploadInput, mlTask);
uploadPrebuiltModel(uploadInput, mlTask);
}
} catch (Exception e) {
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), UPLOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
Expand Down

0 comments on commit bfdc4f6

Please sign in to comment.