Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use protostuff to serialize/deserialize RCF model #251

Merged
merged 1 commit into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ dependencies {
implementation group: 'commons-io', name: 'commons-io', version: '2.11.0'
implementation files('lib/randomcutforest-parkservices-2.0.1.jar')
implementation files('lib/randomcutforest-core-2.0.1.jar')
implementation group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0'
implementation group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0'
implementation group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0'
implementation group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0'
testImplementation group: 'junit', name: 'junit', version: '4.12'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '4.4.0'
Expand Down
Binary file modified ml-algorithms/lib/randomcutforest-core-2.0.1.jar
Binary file not shown.
Binary file modified ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,19 @@
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.ColumnValue;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.engine.TrainAndPredictable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.utils.ModelSerDeSer;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -68,7 +67,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) {
if (model == null) {
throw new IllegalArgumentException("No model found for batch RCF prediction.");
}
RandomCutForestState state = (RandomCutForestState) ModelSerDeSer.deserialize(model.getContent());
RandomCutForestState state = RCFModelSerDeSer.deserializeRCF(model.getContent());
RandomCutForest forest = rcfMapper.toModel(state);
List<Map<String, Object>> predictResult = process(dataFrame, forest, 0);
return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
Expand All @@ -83,7 +82,7 @@ public Model train(DataFrame dataFrame) {
model.setName(FunctionName.BATCH_RCF.name());
model.setVersion(1);
RandomCutForestState state = rcfMapper.toState(forest);
model.setContent(ModelSerDeSer.serialize(state));
model.setContent(RCFModelSerDeSer.serializeRCF(state));
return model;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,21 @@
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.ColumnType;
import org.opensearch.ml.common.dataframe.ColumnValue;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams;
import org.opensearch.ml.engine.TrainAndPredictable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.utils.ModelSerDeSer;

import java.text.DateFormat;
import java.text.ParseException;
Expand Down Expand Up @@ -99,7 +98,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) {
if (model == null) {
throw new IllegalArgumentException("No model found for FIT RCF prediction.");
}
ThresholdedRandomCutForestState state = (ThresholdedRandomCutForestState) ModelSerDeSer.deserialize(model.getContent());
ThresholdedRandomCutForestState state = RCFModelSerDeSer.deserializeTRCF(model.getContent());
ThresholdedRandomCutForest forest = trcfMapper.toModel(state);
List<Map<String, Object>> predictResult = process(dataFrame, forest);
return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
Expand All @@ -113,7 +112,7 @@ public Model train(DataFrame dataFrame) {
model.setName(FunctionName.FIT_RCF.name());
model.setVersion(1);
ThresholdedRandomCutForestState state = trcfMapper.toState(forest);
model.setContent(ModelSerDeSer.serialize(state));
model.setContent(RCFModelSerDeSer.serializeTRCF(state));
return model;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.rcf;

import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.state.RandomCutForestState;
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;
import lombok.experimental.UtilityClass;

import java.security.AccessController;
import java.security.PrivilegedAction;

/**
 * Protostuff-based serialization helpers for the two Random Cut Forest state types used by
 * this package: {@link RandomCutForestState} and {@link ThresholdedRandomCutForestState}.
 *
 * <p>The runtime schema for each state type is resolved once, when this class is initialized,
 * and reused by every call. Schema lookup, encoding and decoding all run inside
 * {@link AccessController#doPrivileged(PrivilegedAction)}.
 * NOTE(review): presumably because protostuff's runtime schema relies on reflection that a
 * security manager would otherwise block - confirm against the plugin's security policy.
 */
@UtilityClass
public class RCFModelSerDeSer {
    /** Size, in bytes, of the {@link LinkedBuffer} allocated for each serialization call. */
    private static final int SERIALIZATION_BUFFER_BYTES = 512;

    /** Cached protostuff schema for plain RCF state. */
    private static final Schema<RandomCutForestState> rcfSchema = loadSchema(RandomCutForestState.class);

    /** Cached protostuff schema for thresholded RCF state. */
    private static final Schema<ThresholdedRandomCutForestState> trcfSchema =
        loadSchema(ThresholdedRandomCutForestState.class);

    /**
     * Encodes a plain RCF state object with the cached RCF schema.
     *
     * @param model state to encode
     * @return protostuff-encoded bytes
     */
    public static byte[] serializeRCF(RandomCutForestState model) {
        return serialize(model, rcfSchema);
    }

    /**
     * Encodes a thresholded RCF state object with the cached TRCF schema.
     *
     * @param model state to encode
     * @return protostuff-encoded bytes
     */
    public static byte[] serializeTRCF(ThresholdedRandomCutForestState model) {
        return serialize(model, trcfSchema);
    }

    /**
     * Decodes bytes produced by {@link #serializeRCF(RandomCutForestState)}.
     *
     * @param bytes protostuff-encoded RCF state
     * @return a new state instance populated from {@code bytes}
     */
    public static RandomCutForestState deserializeRCF(byte[] bytes) {
        return deserialize(bytes, rcfSchema);
    }

    /**
     * Decodes bytes produced by {@link #serializeTRCF(ThresholdedRandomCutForestState)}.
     *
     * @param bytes protostuff-encoded thresholded RCF state
     * @return a new state instance populated from {@code bytes}
     */
    public static ThresholdedRandomCutForestState deserializeTRCF(byte[] bytes) {
        return deserialize(bytes, trcfSchema);
    }

    /** Resolves the protostuff runtime schema for {@code type} inside a privileged block. */
    private static <T> Schema<T> loadSchema(Class<T> type) {
        PrivilegedAction<Schema<T>> lookup = () -> RuntimeSchema.getSchema(type);
        return AccessController.doPrivileged(lookup);
    }

    /** Encodes {@code model} with {@code schema}, using a freshly allocated buffer per call. */
    private static <T> byte[] serialize(T model, Schema<T> schema) {
        final LinkedBuffer scratch = LinkedBuffer.allocate(SERIALIZATION_BUFFER_BYTES);
        PrivilegedAction<byte[]> encode = () -> ProtostuffIOUtil.toByteArray(model, schema, scratch);
        return AccessController.doPrivileged(encode);
    }

    /** Creates an empty message from {@code schema} and merges {@code bytes} into it. */
    private static <T> T deserialize(byte[] bytes, Schema<T> schema) {
        // The empty message is created outside the privileged block; only the merge runs inside it.
        final T target = schema.newMessage();
        PrivilegedAction<T> decode = () -> {
            ProtostuffIOUtil.mergeFrom(bytes, target, schema);
            return target;
        };
        return AccessController.doPrivileged(decode);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
package org.opensearch.ml.engine.utils;

import lombok.experimental.UtilityClass;
import org.apache.commons.io.serialization.ValidatingObjectInputStream;
import org.opensearch.ml.engine.exceptions.ModelSerDeSerException;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

@UtilityClass
Expand Down Expand Up @@ -41,8 +41,11 @@ public static byte[] serialize(Object model) {
}

public static Object deserialize(byte[] modelBin) {
try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(modelBin))) {
return objectInputStream.readObject();
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin);
ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream)){
// Validate the model class type to avoid deserialization attack.
validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS);
return validatingObjectInputStream.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import java.io.IOException;
import java.util.Arrays;

import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;

public class MLEngineTest {
@Rule
Expand All @@ -41,7 +41,7 @@ public class MLEngineTest {
@Test
public void predictKMeans() {
Model model = trainKMeansModel();
DataFrame predictionDataFrame = constructKMeansDataFrame(10);
DataFrame predictionDataFrame = constructTestDataFrame(10);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build();
Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(mlInput, model);
Expand Down Expand Up @@ -106,7 +106,7 @@ public void train_EmptyDataFrame() {
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(0)).build();
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(0)).build();
MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build());
}
}
Expand All @@ -118,7 +118,7 @@ public void train_UnsupportedAlgorithm() {
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build();
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build();
MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build());
}
}
Expand All @@ -134,7 +134,7 @@ public void predictNullInput() {
public void predictWithoutAlgoName() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("algorithm can't be null");
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build();
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build();
Input mlInput = MLInput.builder().inputDataset(inputDataset).build();
MLEngine.predict(mlInput, null);
}
Expand Down Expand Up @@ -165,7 +165,7 @@ public void predictUnsupportedAlgorithm() {
public void trainAndPredictWithKmeans() {
int dataSize = 100;
MLAlgoParams parameters = KMeansParams.builder().build();
DataFrame dataFrame = constructKMeansDataFrame(dataSize);
DataFrame dataFrame = constructTestDataFrame(dataSize);
MLInputDataset inputData = new DataFrameInputDataset(dataFrame);
Input input = new MLInput(FunctionName.KMEANS, parameters, inputData);
MLPredictionOutput output = (MLPredictionOutput) MLEngine.trainAndPredict(input);
Expand Down Expand Up @@ -216,7 +216,7 @@ private Model trainKMeansModel() {
.iterations(10)
.distanceType(KMeansParams.DistanceType.EUCLIDEAN)
.build();
DataFrame trainDataFrame = constructKMeansDataFrame(100);
DataFrame trainDataFrame = constructTestDataFrame(100);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build();
Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build();
return MLEngine.train(mlInput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,40 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
import org.opensearch.ml.engine.algorithms.clustering.KMeans;
import org.opensearch.ml.engine.exceptions.ModelSerDeSerException;
import org.opensearch.ml.engine.algorithms.regression.LinearRegression;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.regression.sgd.linear.LinearSGDModel;

import java.util.Arrays;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
import static org.junit.Assert.assertNotNull;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;

public class ModelSerDeSerTest {
@Rule
public ExpectedException thrown = ExpectedException.none();

private final Object dummyModel = new Object();

@Test
public void testModelSerDeSerBlocklModel() {
thrown.expect(ModelSerDeSerException.class);
byte[] modelBin = ModelSerDeSer.serialize(dummyModel);
Object model = ModelSerDeSer.deserialize(modelBin);
assertTrue(model.equals(dummyModel));
}

@Test
public void testModelSerDeSerKMeans() {
KMeansParams params = KMeansParams.builder().build();
KMeans kMeans = new KMeans(params);
Model model = kMeans.train(constructKMeansDataFrame(100));
Model model = kMeans.train(constructTestDataFrame(100));

KMeansModel deserializedModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
assertNotNull(deserializedModel);
}

KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
byte[] serializedModel = ModelSerDeSer.serialize(kMeansModel);
assertFalse(Arrays.equals(serializedModel, model.getContent()));
@Test
public void testModelSerDeSerLinearRegression() {
LinearRegressionParams params = LinearRegressionParams.builder().target("f2").build();
LinearRegression linearRegression = new LinearRegression(params);
Model model = linearRegression.train(constructTestDataFrame(100));

LinearSGDModel deserializedModel = (LinearSGDModel) ModelSerDeSer.deserialize(model.getContent());
assertNotNull(deserializedModel);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.Model;

import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;


public class KMeansTest {
Expand Down Expand Up @@ -107,11 +107,11 @@ public void constructorWithNegtiveIterations() {
}

private void constructKMeansPredictionDataFrame() {
predictionDataFrame = constructKMeansDataFrame(predictionSize);
predictionDataFrame = constructTestDataFrame(predictionSize);
}

private void constructKMeansTrainDataFrame() {
trainDataFrame = constructKMeansDataFrame(trainSize);
trainDataFrame = constructTestDataFrame(trainSize);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.rcf;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams;
import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams;

import java.util.Arrays;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.opensearch.ml.engine.helper.MLTestHelper.TIME_FIELD;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;

/**
 * Round-trip tests for {@link RCFModelSerDeSer}.
 *
 * <p>Each test trains a model, decodes the stored content back into a state object, checks
 * that the state can be mapped to a forest, and then verifies that re-encoding the decoded
 * state reproduces the original bytes exactly.
 */
public class RCFModelSerDeSerTest {
    @Rule
    public ExpectedException thrown = ExpectedException.none();

    private final RandomCutForestMapper rcfMapper = new RandomCutForestMapper();
    private final ThresholdedRandomCutForestMapper trcfMapper = new ThresholdedRandomCutForestMapper();

    /** Batch RCF: decode the trained model content, rebuild the forest, re-encode, compare bytes. */
    @Test
    public void testModelSerDeSerBatchRCF() {
        BatchRCFParams params = BatchRCFParams.builder().build();
        BatchRandomCutForest batchRCF = new BatchRandomCutForest(params);
        Model model = batchRCF.train(constructTestDataFrame(500));

        RandomCutForestState deserializedState = RCFModelSerDeSer.deserializeRCF(model.getContent());
        RandomCutForest forest = rcfMapper.toModel(deserializedState);
        assertNotNull(forest);

        byte[] serializedModel = RCFModelSerDeSer.serializeRCF(deserializedState);
        // assertArrayEquals reports the first differing index on failure, unlike
        // assertTrue(Arrays.equals(...)), which only yields a bare AssertionError.
        assertArrayEquals(model.getContent(), serializedModel);
    }

    /** Fixed-in-time RCF: same round trip through the thresholded state type. */
    @Test
    public void testModelSerDeSerFitRCF() {
        FitRCFParams params = FitRCFParams.builder().timeField(TIME_FIELD).build();
        FixedInTimeRandomCutForest fitRCF = new FixedInTimeRandomCutForest(params);
        Model model = fitRCF.train(constructTestDataFrame(500, true));

        ThresholdedRandomCutForestState deserializedState = RCFModelSerDeSer.deserializeTRCF(model.getContent());
        ThresholdedRandomCutForest forest = trcfMapper.toModel(deserializedState);
        assertNotNull(forest);

        byte[] serializedModel = RCFModelSerDeSer.serializeTRCF(deserializedState);
        assertArrayEquals(model.getContent(), serializedModel);
    }

}
Loading