Skip to content

Commit

Permalink
use protostuff to serialize/deserialize RCF model (#252)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Mar 28, 2022
1 parent 770fc79 commit 88ee33f
Show file tree
Hide file tree
Showing 15 changed files with 312 additions and 51 deletions.
4 changes: 4 additions & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ dependencies {
compile group: 'commons-io', name: 'commons-io', version: '2.11.0'
compile files('lib/randomcutforest-parkservices-2.0.1.jar')
compile files('lib/randomcutforest-core-2.0.1.jar')
compile group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0'
compile group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0'
compile group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0'
compile group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0'
testCompile group: 'junit', name: 'junit', version: '4.12'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.9.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '3.9.0'
Expand Down
Binary file modified ml-algorithms/lib/randomcutforest-core-2.0.1.jar
Binary file not shown.
Binary file modified ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.opensearch.ml.common.parameter.Model;
import org.opensearch.ml.engine.TrainAndPredictable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.utils.ModelSerDeSer;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -68,7 +67,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) {
if (model == null) {
throw new IllegalArgumentException("No model found for batch RCF prediction.");
}
RandomCutForestState state = (RandomCutForestState) ModelSerDeSer.deserialize(model.getContent());
RandomCutForestState state = RCFModelSerDeSer.deserializeRCF(model.getContent());
RandomCutForest forest = rcfMapper.toModel(state);
List<Map<String, Object>> predictResult = process(dataFrame, forest, 0);
return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
Expand All @@ -83,7 +82,7 @@ public Model train(DataFrame dataFrame) {
model.setName(FunctionName.BATCH_RCF.name());
model.setVersion(1);
RandomCutForestState state = rcfMapper.toState(forest);
model.setContent(ModelSerDeSer.serialize(state));
model.setContent(RCFModelSerDeSer.serializeRCF(state));
return model;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) {
if (model == null) {
throw new IllegalArgumentException("No model found for FIT RCF prediction.");
}
ThresholdedRandomCutForestState state = (ThresholdedRandomCutForestState) ModelSerDeSer.deserialize(model.getContent());
ThresholdedRandomCutForestState state = RCFModelSerDeSer.deserializeTRCF(model.getContent());
ThresholdedRandomCutForest forest = trcfMapper.toModel(state);
List<Map<String, Object>> predictResult = process(dataFrame, forest);
return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
Expand All @@ -113,7 +113,7 @@ public Model train(DataFrame dataFrame) {
model.setName(FunctionName.FIT_RCF.name());
model.setVersion(1);
ThresholdedRandomCutForestState state = trcfMapper.toState(forest);
model.setContent(ModelSerDeSer.serialize(state));
model.setContent(RCFModelSerDeSer.serializeTRCF(state));
return model;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.rcf;

import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.state.RandomCutForestState;
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;
import lombok.experimental.UtilityClass;

import java.security.AccessController;
import java.security.PrivilegedAction;

@UtilityClass
public class RCFModelSerDeSer {
private static final int SERIALIZATION_BUFFER_BYTES = 512;
private static final Schema<RandomCutForestState> rcfSchema =
AccessController.doPrivileged((PrivilegedAction<Schema<RandomCutForestState>>) () ->
RuntimeSchema.getSchema(RandomCutForestState.class));
private static final Schema<ThresholdedRandomCutForestState> trcfSchema =
AccessController.doPrivileged((PrivilegedAction<Schema<ThresholdedRandomCutForestState>>) () ->
RuntimeSchema.getSchema(ThresholdedRandomCutForestState.class));

public static byte[] serializeRCF(RandomCutForestState model) {
return serialize(model, rcfSchema);
}

public static byte[] serializeTRCF(ThresholdedRandomCutForestState model) {
return serialize(model, trcfSchema);
}

public static RandomCutForestState deserializeRCF(byte[] bytes) {
return deserialize(bytes, rcfSchema);
}

public static ThresholdedRandomCutForestState deserializeTRCF(byte[] bytes) {
return deserialize(bytes, trcfSchema);
}

private static <T> byte[] serialize(T model, Schema<T> schema) {
LinkedBuffer buffer = LinkedBuffer.allocate(SERIALIZATION_BUFFER_BYTES);
byte[] bytes = AccessController.doPrivileged((PrivilegedAction<byte[]>) () -> ProtostuffIOUtil.toByteArray(model, schema, buffer));
return bytes;
}

private static <T> T deserialize(byte[] bytes, Schema<T> schema) {
T model = schema.newMessage();
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
ProtostuffIOUtil.mergeFrom(bytes, model, schema);
return null;
});
return model;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
package org.opensearch.ml.engine.utils;

import lombok.experimental.UtilityClass;
import org.apache.commons.io.serialization.ValidatingObjectInputStream;
import org.opensearch.ml.engine.exceptions.ModelSerDeSerException;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

@UtilityClass
Expand Down Expand Up @@ -41,8 +41,11 @@ public static byte[] serialize(Object model) {
}

public static Object deserialize(byte[] modelBin) {
try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(modelBin))) {
return objectInputStream.readObject();
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin);
ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream)){
// Validate the model class type to avoid deserialization attack.
validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS);
return validatingObjectInputStream.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import java.io.IOException;
import java.util.Arrays;

import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;

public class MLEngineTest {
@Rule
Expand All @@ -41,7 +41,7 @@ public class MLEngineTest {
@Test
public void predictKMeans() {
Model model = trainKMeansModel();
DataFrame predictionDataFrame = constructKMeansDataFrame(10);
DataFrame predictionDataFrame = constructTestDataFrame(10);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build();
Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(mlInput, model);
Expand Down Expand Up @@ -106,7 +106,7 @@ public void train_EmptyDataFrame() {
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(0)).build();
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(0)).build();
MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build());
}
}
Expand All @@ -118,7 +118,7 @@ public void train_UnsupportedAlgorithm() {
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build();
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build();
MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build());
}
}
Expand All @@ -134,7 +134,7 @@ public void predictNullInput() {
public void predictWithoutAlgoName() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("algorithm can't be null");
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build();
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build();
Input mlInput = MLInput.builder().inputDataset(inputDataset).build();
MLEngine.predict(mlInput, null);
}
Expand Down Expand Up @@ -165,7 +165,7 @@ public void predictUnsupportedAlgorithm() {
public void trainAndPredictWithKmeans() {
int dataSize = 100;
MLAlgoParams parameters = KMeansParams.builder().build();
DataFrame dataFrame = constructKMeansDataFrame(dataSize);
DataFrame dataFrame = constructTestDataFrame(dataSize);
MLInputDataset inputData = new DataFrameInputDataset(dataFrame);
Input input = new MLInput(FunctionName.KMEANS, parameters, inputData);
MLPredictionOutput output = (MLPredictionOutput) MLEngine.trainAndPredict(input);
Expand Down Expand Up @@ -216,7 +216,7 @@ private Model trainKMeansModel() {
.iterations(10)
.distanceType(KMeansParams.DistanceType.EUCLIDEAN)
.build();
DataFrame trainDataFrame = constructKMeansDataFrame(100);
DataFrame trainDataFrame = constructTestDataFrame(100);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build();
Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build();
return MLEngine.train(mlInput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,39 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.parameter.KMeansParams;
import org.opensearch.ml.common.parameter.LinearRegressionParams;
import org.opensearch.ml.common.parameter.Model;
import org.opensearch.ml.engine.algorithms.clustering.KMeans;
import org.opensearch.ml.engine.exceptions.ModelSerDeSerException;
import org.opensearch.ml.engine.algorithms.regression.LinearRegression;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.regression.sgd.linear.LinearSGDModel;

import java.util.Arrays;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
import static org.junit.Assert.assertNotNull;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;

public class ModelSerDeSerTest {
@Rule
public ExpectedException thrown = ExpectedException.none();

private final Object dummyModel = new Object();

@Test
public void testModelSerDeSerBlocklModel() {
thrown.expect(ModelSerDeSerException.class);
byte[] modelBin = ModelSerDeSer.serialize(dummyModel);
Object model = ModelSerDeSer.deserialize(modelBin);
assertTrue(model.equals(dummyModel));
}

@Test
public void testModelSerDeSerKMeans() {
KMeansParams params = KMeansParams.builder().build();
KMeans kMeans = new KMeans(params);
Model model = kMeans.train(constructKMeansDataFrame(100));
Model model = kMeans.train(constructTestDataFrame(100));

KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
byte[] serializedModel = ModelSerDeSer.serialize(kMeansModel);
assertFalse(Arrays.equals(serializedModel, model.getContent()));
KMeansModel deserializedModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
assertNotNull(deserializedModel);
}
}

@Test
public void testModelSerDeSerLinearRegression() {
LinearRegressionParams params = LinearRegressionParams.builder().target("f2").build();
LinearRegression linearRegression = new LinearRegression(params);
Model model = linearRegression.train(constructTestDataFrame(100));

LinearSGDModel deserializedModel = (LinearSGDModel) ModelSerDeSer.deserialize(model.getContent());
assertNotNull(deserializedModel);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
import org.opensearch.ml.common.parameter.MLPredictionOutput;
import org.opensearch.ml.common.parameter.Model;

import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;

import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;

public class KMeansTest {
@Rule
Expand Down Expand Up @@ -107,11 +106,11 @@ public void constructorWithNegtiveIterations() {
}

private void constructKMeansPredictionDataFrame() {
predictionDataFrame = constructKMeansDataFrame(predictionSize);
predictionDataFrame = constructTestDataFrame(predictionSize);
}

private void constructKMeansTrainDataFrame() {
trainDataFrame = constructKMeansDataFrame(trainSize);
trainDataFrame = constructTestDataFrame(trainSize);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.rcf;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.parameter.BatchRCFParams;
import org.opensearch.ml.common.parameter.FitRCFParams;
import org.opensearch.ml.common.parameter.Model;

import java.util.Arrays;

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.opensearch.ml.engine.helper.MLTestHelper.TIME_FIELD;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;

public class RCFModelSerDeSerTest {
@Rule
public ExpectedException thrown = ExpectedException.none();

private final RandomCutForestMapper rcfMapper = new RandomCutForestMapper();
private final ThresholdedRandomCutForestMapper trcfMapper = new ThresholdedRandomCutForestMapper();

@Test
public void testModelSerDeSerBatchRCF() {
BatchRCFParams params = BatchRCFParams.builder().build();
BatchRandomCutForest batchRCF = new BatchRandomCutForest(params);
Model model = batchRCF.train(constructTestDataFrame(500));

RandomCutForestState deserializedState = RCFModelSerDeSer.deserializeRCF(model.getContent());
RandomCutForest forest = rcfMapper.toModel(deserializedState);
assertNotNull(forest);
byte[] serializedModel = RCFModelSerDeSer.serializeRCF(deserializedState);
assertTrue(Arrays.equals(serializedModel, model.getContent()));
}

@Test
public void testModelSerDeSerFitRCF() {
FitRCFParams params = FitRCFParams.builder().timeField(TIME_FIELD).build();
FixedInTimeRandomCutForest fitRCF = new FixedInTimeRandomCutForest(params);
Model model = fitRCF.train(constructTestDataFrame(500, true));

ThresholdedRandomCutForestState deserializedState = RCFModelSerDeSer.deserializeTRCF(model.getContent());
ThresholdedRandomCutForest forest = trcfMapper.toModel(deserializedState);
assertNotNull(forest);
byte[] serializedModel = RCFModelSerDeSer.serializeTRCF(deserializedState);
assertTrue(Arrays.equals(serializedModel, model.getContent()));
}

}
Loading

0 comments on commit 88ee33f

Please sign in to comment.