diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 8aef8cf349..bd3e14adf4 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -23,6 +23,10 @@ dependencies { implementation group: 'commons-io', name: 'commons-io', version: '2.11.0' implementation files('lib/randomcutforest-parkservices-2.0.1.jar') implementation files('lib/randomcutforest-core-2.0.1.jar') + implementation group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0' + implementation group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0' + implementation group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0' + implementation group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0' testImplementation group: 'junit', name: 'junit', version: '4.12' testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '4.4.0' diff --git a/ml-algorithms/lib/randomcutforest-core-2.0.1.jar b/ml-algorithms/lib/randomcutforest-core-2.0.1.jar index 4b0f6b79f0..0131d21633 100644 Binary files a/ml-algorithms/lib/randomcutforest-core-2.0.1.jar and b/ml-algorithms/lib/randomcutforest-core-2.0.1.jar differ diff --git a/ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar b/ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar index 089a209718..692ce44f7f 100644 Binary files a/ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar and b/ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar differ diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java index e09514364c..4446e816fd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java @@ -9,20 +9,19 @@ import com.amazon.randomcutforest.state.RandomCutForestMapper; import com.amazon.randomcutforest.state.RandomCutForestState; import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.Model; import org.opensearch.ml.common.dataframe.ColumnMeta; import org.opensearch.ml.common.dataframe.ColumnValue; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataframe.Row; -import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; -import org.opensearch.ml.common.Model; import org.opensearch.ml.engine.TrainAndPredictable; import org.opensearch.ml.engine.annotation.Function; -import org.opensearch.ml.engine.utils.ModelSerDeSer; import java.util.ArrayList; import java.util.HashMap; @@ -68,7 +67,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) { if (model == null) { throw new IllegalArgumentException("No model found for batch RCF prediction."); } - RandomCutForestState state = (RandomCutForestState) ModelSerDeSer.deserialize(model.getContent()); + RandomCutForestState state = RCFModelSerDeSer.deserializeRCF(model.getContent()); RandomCutForest forest = rcfMapper.toModel(state); List> predictResult = process(dataFrame, forest, 0); return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); @@ -83,7 +82,7 @@ public Model train(DataFrame dataFrame) { model.setName(FunctionName.BATCH_RCF.name()); model.setVersion(1); RandomCutForestState state = rcfMapper.toState(forest); - model.setContent(ModelSerDeSer.serialize(state)); + model.setContent(RCFModelSerDeSer.serializeRCF(state)); return model; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java index 85ccb9c818..bd881e7beb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java @@ -12,6 +12,8 @@ import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.Model; import org.opensearch.ml.common.dataframe.ColumnMeta; import org.opensearch.ml.common.dataframe.ColumnType; import org.opensearch.ml.common.dataframe.ColumnValue; @@ -19,15 +21,12 @@ import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataframe.Row; import org.opensearch.ml.common.exception.MLValidationException; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; -import org.opensearch.ml.common.Model; -import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; import org.opensearch.ml.engine.TrainAndPredictable; import org.opensearch.ml.engine.annotation.Function; -import org.opensearch.ml.engine.utils.ModelSerDeSer; import java.text.DateFormat; import java.text.ParseException; @@ -99,7 +98,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) { if (model == null) { throw new IllegalArgumentException("No model found for FIT RCF prediction."); } - ThresholdedRandomCutForestState state = (ThresholdedRandomCutForestState) ModelSerDeSer.deserialize(model.getContent()); + ThresholdedRandomCutForestState state = RCFModelSerDeSer.deserializeTRCF(model.getContent()); ThresholdedRandomCutForest forest = trcfMapper.toModel(state); List> predictResult = process(dataFrame, forest); return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); @@ -113,7 +112,7 @@ public Model train(DataFrame dataFrame) { model.setName(FunctionName.FIT_RCF.name()); model.setVersion(1); ThresholdedRandomCutForestState state = trcfMapper.toState(forest); - model.setContent(ModelSerDeSer.serialize(state)); + model.setContent(RCFModelSerDeSer.serializeTRCF(state)); return model; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java new file mode 100644 index 0000000000..8079268297 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.rcf; + +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; +import com.amazon.randomcutforest.state.RandomCutForestState; +import io.protostuff.LinkedBuffer; +import io.protostuff.ProtostuffIOUtil; +import io.protostuff.Schema; +import io.protostuff.runtime.RuntimeSchema; +import lombok.experimental.UtilityClass; + +import java.security.AccessController; +import java.security.PrivilegedAction; + +@UtilityClass +public class RCFModelSerDeSer { + private static final int SERIALIZATION_BUFFER_BYTES = 512; + private static final Schema rcfSchema = + AccessController.doPrivileged((PrivilegedAction>) () -> + RuntimeSchema.getSchema(RandomCutForestState.class)); + private static final Schema trcfSchema = + AccessController.doPrivileged((PrivilegedAction>) () -> + RuntimeSchema.getSchema(ThresholdedRandomCutForestState.class)); + + public static byte[] serializeRCF(RandomCutForestState model) { + return serialize(model, rcfSchema); + } + + public static byte[] serializeTRCF(ThresholdedRandomCutForestState model) { + return serialize(model, trcfSchema); + } + + public static RandomCutForestState deserializeRCF(byte[] bytes) { + return deserialize(bytes, rcfSchema); + } + + public static ThresholdedRandomCutForestState deserializeTRCF(byte[] bytes) { + return deserialize(bytes, trcfSchema); + } + + private static byte[] serialize(T model, Schema schema) { + LinkedBuffer buffer = LinkedBuffer.allocate(SERIALIZATION_BUFFER_BYTES); + byte[] bytes = AccessController.doPrivileged((PrivilegedAction) () -> ProtostuffIOUtil.toByteArray(model, schema, buffer)); + return bytes; + } + + private static T deserialize(byte[] bytes, Schema schema) { + T model = schema.newMessage(); + AccessController.doPrivileged((PrivilegedAction) () -> { + ProtostuffIOUtil.mergeFrom(bytes, model, schema); + return null; + }); + return model; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java index 7b7306ebb6..4d477928b8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java @@ -6,12 +6,12 @@ package org.opensearch.ml.engine.utils; import lombok.experimental.UtilityClass; +import org.apache.commons.io.serialization.ValidatingObjectInputStream; import org.opensearch.ml.engine.exceptions.ModelSerDeSerException; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectInputStream; import java.io.ObjectOutputStream; @UtilityClass @@ -41,8 +41,11 @@ public static byte[] serialize(Object model) { } public static Object deserialize(byte[] modelBin) { - try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(modelBin))) { - return objectInputStream.readObject(); + try (ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin); + ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream)){ + // Validate the model class type to avoid deserialization attack. + validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS); + return validatingObjectInputStream.readObject(); } catch (IOException | ClassNotFoundException e) { throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause()); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index e9b92f7317..aa154d62bb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -30,9 +30,9 @@ import java.io.IOException; import java.util.Arrays; -import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame; +import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; public class MLEngineTest { @Rule @@ -41,7 +41,7 @@ public class MLEngineTest { @Test public void predictKMeans() { Model model = trainKMeansModel(); - DataFrame predictionDataFrame = constructKMeansDataFrame(10); + DataFrame predictionDataFrame = constructTestDataFrame(10); MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build(); Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build(); MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(mlInput, model); @@ -106,7 +106,7 @@ public void train_EmptyDataFrame() { FunctionName algoName = FunctionName.LINEAR_REGRESSION; try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); - MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(0)).build(); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(0)).build(); MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build()); } } @@ -118,7 +118,7 @@ public void train_UnsupportedAlgorithm() { FunctionName algoName = FunctionName.LINEAR_REGRESSION; try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); - MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build(); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build(); MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build()); } } @@ -134,7 +134,7 @@ public void predictNullInput() { public void predictWithoutAlgoName() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("algorithm can't be null"); - MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build(); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build(); Input mlInput = MLInput.builder().inputDataset(inputDataset).build(); MLEngine.predict(mlInput, null); } @@ -165,7 +165,7 @@ public void predictUnsupportedAlgorithm() { public void trainAndPredictWithKmeans() { int dataSize = 100; MLAlgoParams parameters = KMeansParams.builder().build(); - DataFrame dataFrame = constructKMeansDataFrame(dataSize); + DataFrame dataFrame = constructTestDataFrame(dataSize); MLInputDataset inputData = new DataFrameInputDataset(dataFrame); Input input = new MLInput(FunctionName.KMEANS, parameters, inputData); MLPredictionOutput output = (MLPredictionOutput) MLEngine.trainAndPredict(input); @@ -216,7 +216,7 @@ private Model trainKMeansModel() { .iterations(10) .distanceType(KMeansParams.DistanceType.EUCLIDEAN) .build(); - DataFrame trainDataFrame = constructKMeansDataFrame(100); + DataFrame trainDataFrame = constructTestDataFrame(100); MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build(); Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build(); return MLEngine.train(mlInput); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java index be317a350a..ebff98035b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java @@ -8,41 +8,40 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; import org.opensearch.ml.common.Model; +import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; +import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams; import org.opensearch.ml.engine.algorithms.clustering.KMeans; -import org.opensearch.ml.engine.exceptions.ModelSerDeSerException; +import org.opensearch.ml.engine.algorithms.regression.LinearRegression; import org.opensearch.ml.engine.utils.ModelSerDeSer; import org.tribuo.clustering.kmeans.KMeansModel; +import org.tribuo.regression.sgd.linear.LinearSGDModel; -import java.util.Arrays; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; public class ModelSerDeSerTest { @Rule public ExpectedException thrown = ExpectedException.none(); - private final Object dummyModel = new Object(); - - @Test - public void testModelSerDeSerBlocklModel() { - thrown.expect(ModelSerDeSerException.class); - byte[] modelBin = ModelSerDeSer.serialize(dummyModel); - Object model = ModelSerDeSer.deserialize(modelBin); - assertTrue(model.equals(dummyModel)); - } - @Test public void testModelSerDeSerKMeans() { KMeansParams params = KMeansParams.builder().build(); KMeans kMeans = new KMeans(params); - Model model = kMeans.train(constructKMeansDataFrame(100)); + Model model = kMeans.train(constructTestDataFrame(100)); + + KMeansModel deserializedModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent()); + assertNotNull(deserializedModel); + } - KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent()); - byte[] serializedModel = ModelSerDeSer.serialize(kMeansModel); - assertFalse(Arrays.equals(serializedModel, model.getContent())); + @Test + public void testModelSerDeSerLinearRegression() { + LinearRegressionParams params = LinearRegressionParams.builder().target("f2").build(); + LinearRegression linearRegression = new LinearRegression(params); + Model model = linearRegression.train(constructTestDataFrame(100)); + + LinearSGDModel deserializedModel = (LinearSGDModel) ModelSerDeSer.deserialize(model.getContent()); + assertNotNull(deserializedModel); } -} \ No newline at end of file + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java index 04695e277d..9ef7630f07 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java @@ -16,7 +16,7 @@ import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.Model; -import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; +import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; public class KMeansTest { @@ -107,11 +107,11 @@ public void constructorWithNegtiveIterations() { } private void constructKMeansPredictionDataFrame() { - predictionDataFrame = constructKMeansDataFrame(predictionSize); + predictionDataFrame = constructTestDataFrame(predictionSize); } private void constructKMeansTrainDataFrame() { - trainDataFrame = constructKMeansDataFrame(trainSize); + trainDataFrame = constructTestDataFrame(trainSize); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSerTest.java new file mode 100644 index 0000000000..9690009505 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSerTest.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.rcf; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; +import com.amazon.randomcutforest.state.RandomCutForestMapper; +import com.amazon.randomcutforest.state.RandomCutForestState; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.Model; +import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; +import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; + +import java.util.Arrays; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.engine.helper.MLTestHelper.TIME_FIELD; +import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; + +public class RCFModelSerDeSerTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private final RandomCutForestMapper rcfMapper = new RandomCutForestMapper(); + private final ThresholdedRandomCutForestMapper trcfMapper = new ThresholdedRandomCutForestMapper(); + + @Test + public void testModelSerDeSerBatchRCF() { + BatchRCFParams params = BatchRCFParams.builder().build(); + BatchRandomCutForest batchRCF = new BatchRandomCutForest(params); + Model model = batchRCF.train(constructTestDataFrame(500)); + + RandomCutForestState deserializedState = RCFModelSerDeSer.deserializeRCF(model.getContent()); + RandomCutForest forest = rcfMapper.toModel(deserializedState); + assertNotNull(forest); + byte[] serializedModel = RCFModelSerDeSer.serializeRCF(deserializedState); + assertTrue(Arrays.equals(serializedModel, model.getContent())); + } + + @Test + public void testModelSerDeSerFitRCF() { + FitRCFParams params = FitRCFParams.builder().timeField(TIME_FIELD).build(); + FixedInTimeRandomCutForest fitRCF = new FixedInTimeRandomCutForest(params); + Model model = fitRCF.train(constructTestDataFrame(500, true)); + + ThresholdedRandomCutForestState deserializedState = RCFModelSerDeSer.deserializeTRCF(model.getContent()); + ThresholdedRandomCutForest forest = trcfMapper.toModel(deserializedState); + assertNotNull(forest); + byte[] serializedModel = RCFModelSerDeSer.serializeTRCF(deserializedState); + assertTrue(Arrays.equals(serializedModel, model.getContent())); + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/KMeansHelper.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/MLTestHelper.java similarity index 58% rename from ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/KMeansHelper.java rename to ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/MLTestHelper.java index 7b45ea7f91..c98d2a0f3e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/KMeansHelper.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/MLTestHelper.java @@ -13,13 +13,27 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Random; @UtilityClass -public class KMeansHelper { - public static DataFrame constructKMeansDataFrame(int size) { - ColumnMeta[] columnMetas = new ColumnMeta[]{new ColumnMeta("f1", ColumnType.DOUBLE), new ColumnMeta("f2", ColumnType.DOUBLE)}; +public class MLTestHelper { + + public static final String TIME_FIELD = "timestamp"; + public static DataFrame constructTestDataFrame(int size) { + return constructTestDataFrame(size, false); + } + + public static DataFrame constructTestDataFrame(int size, boolean addTimeFiled) { + List columnMetaList = new ArrayList<>(); + columnMetaList.add(new ColumnMeta("f1", ColumnType.DOUBLE)); + columnMetaList.add(new ColumnMeta("f2", ColumnType.DOUBLE)); + if (addTimeFiled) { + columnMetaList.add(new ColumnMeta(TIME_FIELD, ColumnType.LONG)); + } + ColumnMeta[] columnMetas = columnMetaList.toArray(new ColumnMeta[0]); DataFrame dataFrame = DataFrameBuilder.emptyDataFrame(columnMetas); Random random = new Random(1); @@ -28,13 +42,19 @@ public static DataFrame constructKMeansDataFrame(int size) { MultivariateNormalDistribution g2 = new MultivariateNormalDistribution(new JDKRandomGenerator(random.nextInt()), new double[]{10.0, 10.0}, new double[][]{{2.0, 1.0}, {1.0, 2.0}}); MultivariateNormalDistribution[] normalDistributions = new MultivariateNormalDistribution[]{g1, g2}; + long startTime = 1648154137000l; for (int i = 0; i < size; ++i) { int id = 0; if (Math.random() < 0.5) { id = 1; } double[] sample = normalDistributions[id].sample(); - dataFrame.appendRow(Arrays.stream(sample).boxed().toArray(Double[]::new)); + Object[] row = Arrays.stream(sample).boxed().toArray(Double[]::new); + if (addTimeFiled) { + long timestamp = startTime + 60_000 * i; + row = new Object[] {row[0], row[1], timestamp}; + } + dataFrame.appendRow(row); } return dataFrame; diff --git a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java index 40dd6b1006..d6f9628579 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java @@ -5,6 +5,9 @@ package org.opensearch.ml.action; +import static org.opensearch.ml.utils.TestData.TARGET_FIELD; +import static org.opensearch.ml.utils.TestData.TIME_FIELD; + import java.util.Collection; import java.util.Collections; import java.util.Map; @@ -29,6 +32,8 @@ import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; +import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; +import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; @@ -135,6 +140,30 @@ public String trainBatchRCFWithDataFrame(int dataSize, boolean async) { return trainModel(FunctionName.BATCH_RCF, BatchRCFParams.builder().build(), inputDataset, async); } + public String trainFitRCFWithDataFrame(int dataSize, boolean async) { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(dataSize, true)); + return trainModel(FunctionName.FIT_RCF, FitRCFParams.builder().timeField(TIME_FIELD).build(), inputDataset, async); + } + + public LinearRegressionParams getLinearRegressionParams() { + return LinearRegressionParams + .builder() + .objectiveType(LinearRegressionParams.ObjectiveType.SQUARED_LOSS) + .optimizerType(LinearRegressionParams.OptimizerType.LINEAR_DECAY_SGD) + .learningRate(0.01) + .epochs(10) + .epsilon(1e-5) + .beta1(0.9) + .beta2(0.99) + .target(TARGET_FIELD) + .build(); + } + + public String trainLinearRegressionWithDataFrame(int dataSize, boolean async) { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrameForLinearRegression(dataSize)); + return trainModel(FunctionName.LINEAR_REGRESSION, getLinearRegressionParams(), inputDataset, async); + } + public String trainModel(FunctionName functionName, MLAlgoParams params, MLInputDataset inputDataset, boolean async) { MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputDataset).build(); MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput, async); @@ -155,8 +184,14 @@ public String trainModel(FunctionName functionName, MLAlgoParams params, MLInput return id; } - public DataFrame predictAndVerify(String modelId, MLInputDataset inputDataset, FunctionName functionName, int size) { - MLInput mlInput = MLInput.builder().algorithm(functionName).inputDataset(inputDataset).build(); + public DataFrame predictAndVerify( + String modelId, + MLInputDataset inputDataset, + FunctionName functionName, + MLAlgoParams parameters, + int size + ) { + MLInput mlInput = MLInput.builder().algorithm(functionName).inputDataset(inputDataset).parameters(parameters).build(); MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(modelId, mlInput); ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); MLTaskResponse predictionResponse = predictionFuture.actionGet(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java index 3f5734192c..a6caff2811 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java @@ -6,6 +6,10 @@ package org.opensearch.ml.action.prediction; import static org.opensearch.ml.utils.TestData.IRIS_DATA_SIZE; +import static org.opensearch.ml.utils.TestData.TIME_FIELD; + +import java.util.ArrayList; +import java.util.List; import org.apache.lucene.util.LuceneTestCase; import org.junit.Before; @@ -16,10 +20,18 @@ import org.opensearch.ml.action.MLCommonsIntegTestCase; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnType; +import org.opensearch.ml.common.dataframe.ColumnValue; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DefaultDataFrame; +import org.opensearch.ml.common.dataframe.DoubleValue; +import org.opensearch.ml.common.dataframe.Row; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -33,6 +45,8 @@ public class PredictionITTests extends MLCommonsIntegTestCase { private String irisIndexName; private String kMeansModelId; private String batchRcfModelId; + private String fitRcfModelId; + private String linearRegressionModelId; private int batchRcfDataSize = 100; @Rule @@ -50,6 +64,8 @@ public void setUp() throws Exception { // assertNotNull(kMeansModel); batchRcfModelId = trainBatchRCFWithDataFrame(500, false); + fitRcfModelId = trainFitRCFWithDataFrame(500, false); + linearRegressionModelId = trainLinearRegressionWithDataFrame(100, false); MLModel batchRcfModel = getModel(batchRcfModelId); assertNotNull(batchRcfModel); } @@ -57,13 +73,13 @@ public void setUp() throws Exception { @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") public void testPredictionWithSearchInput_KMeans() { MLInputDataset inputDataset = new SearchQueryInputDataset(ImmutableList.of(irisIndexName), irisDataQuery()); - predictAndVerify(kMeansModelId, inputDataset, FunctionName.KMEANS, IRIS_DATA_SIZE); + predictAndVerify(kMeansModelId, inputDataset, FunctionName.KMEANS, null, IRIS_DATA_SIZE); } @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") public void testPredictionWithDataInput_KMeans() { MLInputDataset inputDataset = new DataFrameInputDataset(irisDataFrame()); - predictAndVerify(kMeansModelId, inputDataset, FunctionName.KMEANS, IRIS_DATA_SIZE); + predictAndVerify(kMeansModelId, inputDataset, FunctionName.KMEANS, null, IRIS_DATA_SIZE); } @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") @@ -89,6 +105,37 @@ public void testPredictionWithEmptyDataset_KMeans() { public void testPredictionWithDataFrame_BatchRCF() { MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(batchRcfDataSize)); - predictAndVerify(batchRcfModelId, inputDataset, FunctionName.BATCH_RCF, batchRcfDataSize); + predictAndVerify(batchRcfModelId, inputDataset, FunctionName.BATCH_RCF, null, batchRcfDataSize); + } + + public void testPredictionWithDataFrame_FitRCF() { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(batchRcfDataSize, true)); + DataFrame dataFrame = predictAndVerify( + fitRcfModelId, + inputDataset, + FunctionName.FIT_RCF, + FitRCFParams.builder().timeField(TIME_FIELD).build(), + batchRcfDataSize + ); + System.out.println(dataFrame); + } + + public void testPredictionWithDataFrame_LinearRegression() { + int size = 1; + int feet = 20; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("feet", ColumnType.DOUBLE) }; + List rows = new ArrayList<>(); + rows.add(new Row(new ColumnValue[] { new DoubleValue(feet) })); + DataFrame inputDataFrame = new DefaultDataFrame(columnMetas, rows); + MLInputDataset inputDataset = new DataFrameInputDataset(inputDataFrame); + DataFrame dataFrame = predictAndVerify( + linearRegressionModelId, + inputDataset, + FunctionName.LINEAR_REGRESSION, + getLinearRegressionParams(), + size + ); + ColumnValue value = dataFrame.getRow(0).getValue(0); + assertNotNull(value); } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestData.java b/plugin/src/test/java/org/opensearch/ml/utils/TestData.java index 6aef9620dc..55bf304c26 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestData.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestData.java @@ -5,23 +5,54 @@ package org.opensearch.ml.utils; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Random; import org.apache.commons.math3.distribution.MultivariateNormalDistribution; import org.apache.commons.math3.random.JDKRandomGenerator; import org.opensearch.ml.common.dataframe.ColumnMeta; import org.opensearch.ml.common.dataframe.ColumnType; +import org.opensearch.ml.common.dataframe.ColumnValue; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataframe.DefaultDataFrame; +import org.opensearch.ml.common.dataframe.DoubleValue; +import org.opensearch.ml.common.dataframe.Row; import com.google.gson.JsonArray; import com.google.gson.JsonObject; public class TestData { + public static final String TIME_FIELD = "timestamp"; + public static final String TARGET_FIELD = "price"; + public static DataFrame constructTestDataFrame(int size) { - ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("f1", ColumnType.DOUBLE), new ColumnMeta("f2", ColumnType.DOUBLE) }; + return constructTestDataFrame(size, false); + } + + public static DataFrame constructTestDataFrameForLinearRegression(int size) { + ColumnMeta[] columnMetas = new ColumnMeta[] { + new ColumnMeta("feet", ColumnType.DOUBLE), + new ColumnMeta(TARGET_FIELD, ColumnType.DOUBLE) }; + + List rows = new ArrayList<>(); + for (int i = 0; i < size; i++) { + rows.add(new Row(new ColumnValue[] { new DoubleValue(i), new DoubleValue(i) })); + } + return new DefaultDataFrame(columnMetas, rows); + } + + public static DataFrame constructTestDataFrame(int size, boolean addTimeFiled) { + List columnMetaList = new ArrayList<>(); + columnMetaList.add(new ColumnMeta("f1", ColumnType.DOUBLE)); + columnMetaList.add(new ColumnMeta("f2", ColumnType.DOUBLE)); + if (addTimeFiled) { + columnMetaList.add(new ColumnMeta(TIME_FIELD, ColumnType.LONG)); + } + ColumnMeta[] columnMetas = columnMetaList.toArray(new ColumnMeta[0]); DataFrame dataFrame = DataFrameBuilder.emptyDataFrame(columnMetas); Random random = new Random(1); @@ -36,13 +67,19 @@ public static DataFrame constructTestDataFrame(int size) { new double[][] { { 2.0, 1.0 }, { 1.0, 2.0 } } ); MultivariateNormalDistribution[] normalDistributions = new MultivariateNormalDistribution[] { g1, g2 }; + long startTime = 1648154137000l; for (int i = 0; i < size; ++i) { int id = 0; - if (random.nextDouble() < 0.5) { + if (Math.random() < 0.5) { id = 1; } double[] sample = normalDistributions[id].sample(); - dataFrame.appendRow(Arrays.stream(sample).boxed().toArray(Double[]::new)); + Object[] row = Arrays.stream(sample).boxed().toArray(Double[]::new); + if (addTimeFiled) { + long timestamp = startTime + 60_000 * i; + row = new Object[] { row[0], row[1], timestamp }; + } + dataFrame.appendRow(row); } return dataFrame;