diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 59e42a531a..cc1bc0a2ae 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -23,6 +23,10 @@ dependencies { compile group: 'commons-io', name: 'commons-io', version: '2.11.0' compile files('lib/randomcutforest-parkservices-2.0.1.jar') compile files('lib/randomcutforest-core-2.0.1.jar') + compile group: 'io.protostuff', name: 'protostuff-core', version: '1.8.0' + compile group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0' + compile group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0' + compile group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0' testCompile group: 'junit', name: 'junit', version: '4.12' testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.9.0' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '3.9.0' diff --git a/ml-algorithms/lib/randomcutforest-core-2.0.1.jar b/ml-algorithms/lib/randomcutforest-core-2.0.1.jar index 4b0f6b79f0..0131d21633 100644 Binary files a/ml-algorithms/lib/randomcutforest-core-2.0.1.jar and b/ml-algorithms/lib/randomcutforest-core-2.0.1.jar differ diff --git a/ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar b/ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar index 089a209718..692ce44f7f 100644 Binary files a/ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar and b/ml-algorithms/lib/randomcutforest-parkservices-2.0.1.jar differ diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java index 2bb76d3a23..86fb0eaaac 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java @@ -22,7 +22,6 @@ import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.engine.TrainAndPredictable; import org.opensearch.ml.engine.annotation.Function; -import org.opensearch.ml.engine.utils.ModelSerDeSer; import java.util.ArrayList; import java.util.HashMap; @@ -68,7 +67,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) { if (model == null) { throw new IllegalArgumentException("No model found for batch RCF prediction."); } - RandomCutForestState state = (RandomCutForestState) ModelSerDeSer.deserialize(model.getContent()); + RandomCutForestState state = RCFModelSerDeSer.deserializeRCF(model.getContent()); RandomCutForest forest = rcfMapper.toModel(state); List> predictResult = process(dataFrame, forest, 0); return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); @@ -83,7 +82,7 @@ public Model train(DataFrame dataFrame) { model.setName(FunctionName.BATCH_RCF.name()); model.setVersion(1); RandomCutForestState state = rcfMapper.toState(forest); - model.setContent(ModelSerDeSer.serialize(state)); + model.setContent(RCFModelSerDeSer.serializeRCF(state)); return model; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java index 8197d400ac..c6c4e62738 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java @@ -99,7 +99,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) { if (model == null) { throw new IllegalArgumentException("No model found for FIT RCF prediction."); } - ThresholdedRandomCutForestState state = (ThresholdedRandomCutForestState) ModelSerDeSer.deserialize(model.getContent()); + ThresholdedRandomCutForestState state = RCFModelSerDeSer.deserializeTRCF(model.getContent()); ThresholdedRandomCutForest forest = trcfMapper.toModel(state); List> predictResult = process(dataFrame, forest); return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); @@ -113,7 +113,7 @@ public Model train(DataFrame dataFrame) { model.setName(FunctionName.FIT_RCF.name()); model.setVersion(1); ThresholdedRandomCutForestState state = trcfMapper.toState(forest); - model.setContent(ModelSerDeSer.serialize(state)); + model.setContent(RCFModelSerDeSer.serializeTRCF(state)); return model; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java new file mode 100644 index 0000000000..8079268297 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.rcf; + +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; +import com.amazon.randomcutforest.state.RandomCutForestState; +import io.protostuff.LinkedBuffer; +import io.protostuff.ProtostuffIOUtil; +import io.protostuff.Schema; +import io.protostuff.runtime.RuntimeSchema; +import lombok.experimental.UtilityClass; + +import java.security.AccessController; +import java.security.PrivilegedAction; + +@UtilityClass +public class RCFModelSerDeSer { + private static final int SERIALIZATION_BUFFER_BYTES = 512; + private static final Schema rcfSchema = + AccessController.doPrivileged((PrivilegedAction>) () -> + RuntimeSchema.getSchema(RandomCutForestState.class)); + private static final Schema trcfSchema = + AccessController.doPrivileged((PrivilegedAction>) () -> + RuntimeSchema.getSchema(ThresholdedRandomCutForestState.class)); + + public static byte[] serializeRCF(RandomCutForestState model) { + return serialize(model, rcfSchema); + } + + public static byte[] serializeTRCF(ThresholdedRandomCutForestState model) { + return serialize(model, trcfSchema); + } + + public static RandomCutForestState deserializeRCF(byte[] bytes) { + return deserialize(bytes, rcfSchema); + } + + public static ThresholdedRandomCutForestState deserializeTRCF(byte[] bytes) { + return deserialize(bytes, trcfSchema); + } + + private static byte[] serialize(T model, Schema schema) { + LinkedBuffer buffer = LinkedBuffer.allocate(SERIALIZATION_BUFFER_BYTES); + byte[] bytes = AccessController.doPrivileged((PrivilegedAction) () -> ProtostuffIOUtil.toByteArray(model, schema, buffer)); + return bytes; + } + + private static T deserialize(byte[] bytes, Schema schema) { + T model = schema.newMessage(); + AccessController.doPrivileged((PrivilegedAction) () -> { + ProtostuffIOUtil.mergeFrom(bytes, model, schema); + return null; + }); + return model; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java index 7b7306ebb6..4d477928b8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java @@ -6,12 +6,12 @@ package org.opensearch.ml.engine.utils; import lombok.experimental.UtilityClass; +import org.apache.commons.io.serialization.ValidatingObjectInputStream; import org.opensearch.ml.engine.exceptions.ModelSerDeSerException; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectInputStream; import java.io.ObjectOutputStream; @UtilityClass @@ -41,8 +41,11 @@ public static byte[] serialize(Object model) { } public static Object deserialize(byte[] modelBin) { - try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(modelBin))) { - return objectInputStream.readObject(); + try (ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin); + ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream)){ + // Validate the model class type to avoid deserialization attack. + validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS); + return validatingObjectInputStream.readObject(); } catch (IOException | ClassNotFoundException e) { throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause()); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 0f9a624646..eb6adc8274 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -31,9 +31,9 @@ import java.io.IOException; import java.util.Arrays; -import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame; +import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; public class MLEngineTest { @Rule @@ -42,7 +42,7 @@ public class MLEngineTest { @Test public void predictKMeans() { Model model = trainKMeansModel(); - DataFrame predictionDataFrame = constructKMeansDataFrame(10); + DataFrame predictionDataFrame = constructTestDataFrame(10); MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build(); Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build(); MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(mlInput, model); @@ -107,7 +107,7 @@ public void train_EmptyDataFrame() { FunctionName algoName = FunctionName.LINEAR_REGRESSION; try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); - MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(0)).build(); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(0)).build(); MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build()); } } @@ -119,7 +119,7 @@ public void train_UnsupportedAlgorithm() { FunctionName algoName = FunctionName.LINEAR_REGRESSION; try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); - MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build(); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build(); MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build()); } } @@ -135,7 +135,7 @@ public void predictNullInput() { public void predictWithoutAlgoName() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("algorithm can't be null"); - MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build(); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build(); Input mlInput = MLInput.builder().inputDataset(inputDataset).build(); MLEngine.predict(mlInput, null); } @@ -166,7 +166,7 @@ public void predictUnsupportedAlgorithm() { public void trainAndPredictWithKmeans() { int dataSize = 100; MLAlgoParams parameters = KMeansParams.builder().build(); - DataFrame dataFrame = constructKMeansDataFrame(dataSize); + DataFrame dataFrame = constructTestDataFrame(dataSize); MLInputDataset inputData = new DataFrameInputDataset(dataFrame); Input input = new MLInput(FunctionName.KMEANS, parameters, inputData); MLPredictionOutput output = (MLPredictionOutput) MLEngine.trainAndPredict(input); @@ -217,7 +217,7 @@ private Model trainKMeansModel() { .iterations(10) .distanceType(KMeansParams.DistanceType.EUCLIDEAN) .build(); - DataFrame trainDataFrame = constructKMeansDataFrame(100); + DataFrame trainDataFrame = constructTestDataFrame(100); MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build(); Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build(); return MLEngine.train(mlInput); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java index 47034d3dfc..6602a928be 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java @@ -9,40 +9,39 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.ml.common.parameter.KMeansParams; +import org.opensearch.ml.common.parameter.LinearRegressionParams; import org.opensearch.ml.common.parameter.Model; import org.opensearch.ml.engine.algorithms.clustering.KMeans; -import org.opensearch.ml.engine.exceptions.ModelSerDeSerException; +import org.opensearch.ml.engine.algorithms.regression.LinearRegression; import org.opensearch.ml.engine.utils.ModelSerDeSer; import org.tribuo.clustering.kmeans.KMeansModel; +import org.tribuo.regression.sgd.linear.LinearSGDModel; -import java.util.Arrays; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; public class ModelSerDeSerTest { @Rule public ExpectedException thrown = ExpectedException.none(); - private final Object dummyModel = new Object(); - - @Test - public void testModelSerDeSerBlocklModel() { - thrown.expect(ModelSerDeSerException.class); - byte[] modelBin = ModelSerDeSer.serialize(dummyModel); - Object model = ModelSerDeSer.deserialize(modelBin); - assertTrue(model.equals(dummyModel)); - } - @Test public void testModelSerDeSerKMeans() { KMeansParams params = KMeansParams.builder().build(); KMeans kMeans = new KMeans(params); - Model model = kMeans.train(constructKMeansDataFrame(100)); + Model model = kMeans.train(constructTestDataFrame(100)); - KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent()); - byte[] serializedModel = ModelSerDeSer.serialize(kMeansModel); - assertFalse(Arrays.equals(serializedModel, model.getContent())); + KMeansModel deserializedModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent()); + assertNotNull(deserializedModel); } -} \ No newline at end of file + + @Test + public void testModelSerDeSerLinearRegression() { + LinearRegressionParams params = LinearRegressionParams.builder().target("f2").build(); + LinearRegression linearRegression = new LinearRegression(params); + Model model = linearRegression.train(constructTestDataFrame(100)); + + LinearSGDModel deserializedModel = (LinearSGDModel) ModelSerDeSer.deserialize(model.getContent()); + assertNotNull(deserializedModel); + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java index b6be6f0c1b..896bda2a1e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java @@ -16,8 +16,7 @@ import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.ml.common.parameter.Model; -import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; - +import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; public class KMeansTest { @Rule @@ -107,11 +106,11 @@ public void constructorWithNegtiveIterations() { } private void constructKMeansPredictionDataFrame() { - predictionDataFrame = constructKMeansDataFrame(predictionSize); + predictionDataFrame = constructTestDataFrame(predictionSize); } private void constructKMeansTrainDataFrame() { - trainDataFrame = constructKMeansDataFrame(trainSize); + trainDataFrame = constructTestDataFrame(trainSize); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSerTest.java new file mode 100644 index 0000000000..693f2402fa --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSerTest.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.rcf; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; +import com.amazon.randomcutforest.state.RandomCutForestMapper; +import com.amazon.randomcutforest.state.RandomCutForestState; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.parameter.BatchRCFParams; +import org.opensearch.ml.common.parameter.FitRCFParams; +import org.opensearch.ml.common.parameter.Model; + +import java.util.Arrays; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.engine.helper.MLTestHelper.TIME_FIELD; +import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; + +public class RCFModelSerDeSerTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private final RandomCutForestMapper rcfMapper = new RandomCutForestMapper(); + private final ThresholdedRandomCutForestMapper trcfMapper = new ThresholdedRandomCutForestMapper(); + + @Test + public void testModelSerDeSerBatchRCF() { + BatchRCFParams params = BatchRCFParams.builder().build(); + BatchRandomCutForest batchRCF = new BatchRandomCutForest(params); + Model model = batchRCF.train(constructTestDataFrame(500)); + + RandomCutForestState deserializedState = RCFModelSerDeSer.deserializeRCF(model.getContent()); + RandomCutForest forest = rcfMapper.toModel(deserializedState); + assertNotNull(forest); + byte[] serializedModel = RCFModelSerDeSer.serializeRCF(deserializedState); + assertTrue(Arrays.equals(serializedModel, model.getContent())); + } + + @Test + public void testModelSerDeSerFitRCF() { + FitRCFParams params = FitRCFParams.builder().timeField(TIME_FIELD).build(); + FixedInTimeRandomCutForest fitRCF = new FixedInTimeRandomCutForest(params); + Model model = fitRCF.train(constructTestDataFrame(500, true)); + + ThresholdedRandomCutForestState deserializedState = RCFModelSerDeSer.deserializeTRCF(model.getContent()); + ThresholdedRandomCutForest forest = trcfMapper.toModel(deserializedState); + assertNotNull(forest); + byte[] serializedModel = RCFModelSerDeSer.serializeTRCF(deserializedState); + assertTrue(Arrays.equals(serializedModel, model.getContent())); + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/KMeansHelper.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/MLTestHelper.java similarity index 58% rename from ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/KMeansHelper.java rename to ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/MLTestHelper.java index 7b45ea7f91..21ca746733 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/KMeansHelper.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/MLTestHelper.java @@ -13,13 +13,26 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Random; @UtilityClass -public class KMeansHelper { - public static DataFrame constructKMeansDataFrame(int size) { - ColumnMeta[] columnMetas = new ColumnMeta[]{new ColumnMeta("f1", ColumnType.DOUBLE), new ColumnMeta("f2", ColumnType.DOUBLE)}; +public class MLTestHelper { + public static final String TIME_FIELD = "timestamp"; + public static DataFrame constructTestDataFrame(int size) { + return constructTestDataFrame(size, false); + } + + public static DataFrame constructTestDataFrame(int size, boolean addTimeFiled) { + List columnMetaList = new ArrayList<>(); + columnMetaList.add(new ColumnMeta("f1", ColumnType.DOUBLE)); + columnMetaList.add(new ColumnMeta("f2", ColumnType.DOUBLE)); + if (addTimeFiled) { + columnMetaList.add(new ColumnMeta(TIME_FIELD, ColumnType.LONG)); + } + ColumnMeta[] columnMetas = columnMetaList.toArray(new ColumnMeta[0]); DataFrame dataFrame = DataFrameBuilder.emptyDataFrame(columnMetas); Random random = new Random(1); @@ -28,13 +41,19 @@ public static DataFrame constructKMeansDataFrame(int size) { MultivariateNormalDistribution g2 = new MultivariateNormalDistribution(new JDKRandomGenerator(random.nextInt()), new double[]{10.0, 10.0}, new double[][]{{2.0, 1.0}, {1.0, 2.0}}); MultivariateNormalDistribution[] normalDistributions = new MultivariateNormalDistribution[]{g1, g2}; + long startTime = 1648154137000l; for (int i = 0; i < size; ++i) { int id = 0; if (Math.random() < 0.5) { id = 1; } double[] sample = normalDistributions[id].sample(); - dataFrame.appendRow(Arrays.stream(sample).boxed().toArray(Double[]::new)); + Object[] row = Arrays.stream(sample).boxed().toArray(Double[]::new); + if (addTimeFiled) { + long timestamp = startTime + 60_000 * i; + row = new Object[] {row[0], row[1], timestamp}; + } + dataFrame.appendRow(row); } return dataFrame; diff --git a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java new file mode 100644 index 0000000000..f652144529 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java @@ -0,0 +1,216 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action; + +import static org.opensearch.ml.utils.TestData.TARGET_FIELD; +import static org.opensearch.ml.utils.TestData.TIME_FIELD; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; + +import org.opensearch.action.ActionFuture; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnType; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DefaultDataFrame; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.common.parameter.BatchRCFParams; +import org.opensearch.ml.common.parameter.FitRCFParams; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.KMeansParams; +import org.opensearch.ml.common.parameter.LinearRegressionParams; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLModel; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.ml.common.parameter.MLTask; +import org.opensearch.ml.common.parameter.MLTrainingOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.model.MLModelGetAction; +import org.opensearch.ml.common.transport.model.MLModelGetResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.transport.task.MLTaskGetAction; +import org.opensearch.ml.common.transport.task.MLTaskGetRequest; +import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; +import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; +import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; +import org.opensearch.ml.plugin.MachineLearningPlugin; +import org.opensearch.ml.utils.TestData; +import org.opensearch.plugins.Plugin; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchIntegTestCase; + +import com.google.common.collect.ImmutableList; +import com.google.gson.Gson; + +public class MLCommonsIntegTestCase extends OpenSearchIntegTestCase { + private Gson gson = new Gson(); + + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(MachineLearningPlugin.class); + } + + @Override + protected Collection> transportClientPlugins() { + return Collections.singletonList(MachineLearningPlugin.class); + } + + public void loadIrisData(String indexName) { + BulkRequest bulkRequest = new BulkRequest(); + String[] rows = TestData.IRIS_DATA.split("\n"); + for (int i = 1; i < rows.length; i += 2) { + IndexRequest indexRequest = new IndexRequest(indexName).id(i + ""); + indexRequest.source(rows[i], XContentType.JSON); + bulkRequest.add(indexRequest); + } + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client().bulk(bulkRequest).actionGet(5000); + } + + public DataFrame irisDataFrame() { + DataFrame dataFrame = new DefaultDataFrame( + new ColumnMeta[] { + new ColumnMeta("petal_length_in_cm", ColumnType.DOUBLE), + new ColumnMeta("petal_width_in_cm", ColumnType.DOUBLE) } + ); + String[] rows = TestData.IRIS_DATA.split("\n"); + + for (int i = 1; i < rows.length; i += 2) { + Map map = gson.fromJson(rows[i], Map.class); + dataFrame.appendRow(new Object[] { map.get("petal_length_in_cm"), map.get("petal_width_in_cm") }); + } + return dataFrame; + } + + public SearchSourceBuilder irisDataQuery() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.size(1000); + searchSourceBuilder.fetchSource(new String[] { "petal_length_in_cm", "petal_width_in_cm" }, null); + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + return searchSourceBuilder; + } + + public MLInputDataset emptyQueryInputDataSet(String indexName) { + SearchSourceBuilder searchSourceBuilder = irisDataQuery(); + searchSourceBuilder.query(QueryBuilders.matchQuery("class", "wrong_value")); + return new SearchQueryInputDataset(Collections.singletonList(indexName), searchSourceBuilder); + } + + public MLPredictionOutput trainAndPredictKmeansWithIrisData(String irisIndexName) { + MLInputDataset inputDataset = new SearchQueryInputDataset(ImmutableList.of(irisIndexName), irisDataQuery()); + return trainAndPredict(FunctionName.KMEANS, KMeansParams.builder().centroids(3).build(), inputDataset); + } + + public MLPredictionOutput trainAndPredictBatchRCFWithDataFrame(int dataSize) { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(dataSize)); + return trainAndPredict(FunctionName.BATCH_RCF, BatchRCFParams.builder().build(), inputDataset); + } + + public MLPredictionOutput trainAndPredict(FunctionName functionName, MLAlgoParams params, MLInputDataset inputDataset) { + MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputDataset).build(); + MLTrainingTaskRequest trainingRequest = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); + ActionFuture trainingFuture = client().execute(MLTrainAndPredictionTaskAction.INSTANCE, trainingRequest); + MLTaskResponse trainingResponse = trainingFuture.actionGet(); + assertNotNull(trainingResponse); + + MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) trainingResponse.getOutput(); + return mlPredictionOutput; + } + + public String trainKmeansWithIrisData(String irisIndexName, boolean async) { + MLInputDataset inputDataset = new SearchQueryInputDataset(ImmutableList.of(irisIndexName), irisDataQuery()); + return trainModel(FunctionName.KMEANS, KMeansParams.builder().centroids(3).build(), inputDataset, async); + } + + public String trainBatchRCFWithDataFrame(int dataSize, boolean async) { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(dataSize)); + return trainModel(FunctionName.BATCH_RCF, BatchRCFParams.builder().build(), inputDataset, async); + } + + public String trainFitRCFWithDataFrame(int dataSize, boolean async) { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(dataSize, true)); + return trainModel(FunctionName.FIT_RCF, FitRCFParams.builder().timeField(TIME_FIELD).build(), inputDataset, async); + } + + public LinearRegressionParams getLinearRegressionParams() { + return LinearRegressionParams + .builder() + .objectiveType(LinearRegressionParams.ObjectiveType.SQUARED_LOSS) + .optimizerType(LinearRegressionParams.OptimizerType.LINEAR_DECAY_SGD) + .learningRate(0.01) + .epochs(10) + .epsilon(1e-5) + .beta1(0.9) + .beta2(0.99) + .target(TARGET_FIELD) + .build(); + } + + public String trainLinearRegressionWithDataFrame(int dataSize, boolean async) { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrameForLinearRegression(dataSize)); + return trainModel(FunctionName.LINEAR_REGRESSION, getLinearRegressionParams(), inputDataset, async); + } + + public String trainModel(FunctionName functionName, MLAlgoParams params, MLInputDataset inputDataset, boolean async) { + MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputDataset).build(); + MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput, async); + ActionFuture trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest); + MLTaskResponse trainingResponse = trainingFuture.actionGet(); + assertNotNull(trainingResponse); + + MLTrainingOutput modelTrainingOutput = (MLTrainingOutput) trainingResponse.getOutput(); + String id = async ? modelTrainingOutput.getTaskId() : modelTrainingOutput.getModelId(); + String status = modelTrainingOutput.getStatus(); + assertNotNull(id); + assertFalse(id.isEmpty()); + if (async) { + assertEquals("CREATED", status); + } else { + assertEquals("COMPLETED", status); + } + return id; + } + + public DataFrame predictAndVerify( + String modelId, + MLInputDataset inputDataset, + FunctionName functionName, + MLAlgoParams parameters, + int size + ) { + MLInput mlInput = MLInput.builder().algorithm(functionName).inputDataset(inputDataset).parameters(parameters).build(); + MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(modelId, mlInput); + ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); + MLTaskResponse predictionResponse = predictionFuture.actionGet(); + MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) predictionResponse.getOutput(); + DataFrame predictionResult = mlPredictionOutput.getPredictionResult(); + assertEquals(size, predictionResult.size()); + return predictionResult; + } + + public MLTask getTask(String taskId) { + MLTaskGetRequest getRequest = new MLTaskGetRequest(taskId); + MLTaskGetResponse response = client().execute(MLTaskGetAction.INSTANCE, getRequest).actionGet(5000); + return response.getMlTask(); + } + + public MLModel getModel(String modelId) { + MLTaskGetRequest getRequest = new MLTaskGetRequest(modelId); + MLModelGetResponse response = client().execute(MLModelGetAction.INSTANCE, getRequest).actionGet(5000); + return response.getMlModel(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java index 64ac47ff3c..f060cae5dd 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionITTests.java @@ -5,107 +5,136 @@ package org.opensearch.ml.action.prediction; -import static org.opensearch.ml.utils.IntegTestUtils.DATA_FRAME_INPUT_DATASET; -import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA; -import static org.opensearch.ml.utils.IntegTestUtils.TESTING_INDEX_NAME; -import static org.opensearch.ml.utils.IntegTestUtils.generateEmptyDataset; -import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData; -import static org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder; -import static org.opensearch.ml.utils.IntegTestUtils.predictAndVerifyResult; -import static org.opensearch.ml.utils.IntegTestUtils.trainModel; -import static org.opensearch.ml.utils.IntegTestUtils.verifyGeneratedTestingData; -import static org.opensearch.ml.utils.IntegTestUtils.waitModelAvailable; - -import java.io.IOException; -import java.util.Collection; -import java.util.Collections; -import java.util.concurrent.ExecutionException; +import static org.opensearch.ml.utils.TestData.IRIS_DATA_SIZE; +import static org.opensearch.ml.utils.TestData.TIME_FIELD; +import java.util.ArrayList; +import java.util.List; + +import org.apache.lucene.util.LuceneTestCase; import org.junit.Before; -import org.junit.Ignore; -import org.opensearch.ResourceNotFoundException; +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ml.action.MLCommonsIntegTestCase; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnType; +import org.opensearch.ml.common.dataframe.ColumnValue; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DefaultDataFrame; +import org.opensearch.ml.common.dataframe.DoubleValue; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.common.parameter.FitRCFParams; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLModel; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import org.opensearch.ml.plugin.MachineLearningPlugin; -import org.opensearch.plugins.Plugin; -import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.ml.utils.TestData; import org.opensearch.test.OpenSearchIntegTestCase; -@OpenSearchIntegTestCase.ClusterScope(transportClientRatio = 0.9) -@Ignore("Test cases in this class are flaky, something is off with waitModelAvailable(taskId) method." - + " This issue will be tracked in an issue and will be fixed later") -public class PredictionITTests extends OpenSearchIntegTestCase { - private String taskId; - - @Before - public void initTestingData() throws ExecutionException, InterruptedException { - generateMLTestingData(); +import com.google.common.collect.ImmutableList; - SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder(); - MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder); - taskId = trainModel(inputDataset); - waitModelAvailable(taskId); - } +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 2) +public class PredictionITTests extends MLCommonsIntegTestCase { + private String irisIndexName; + private String kMeansModelId; + private String batchRcfModelId; + private String fitRcfModelId; + private String linearRegressionModelId; + private int batchRcfDataSize = 100; - @Override - protected Collection> nodePlugins() { - return Collections.singletonList(MachineLearningPlugin.class); - } + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); - @Override - protected Collection> transportClientPlugins() { - return Collections.singletonList(MachineLearningPlugin.class); + @Before + public void setUp() throws Exception { + super.setUp(); + irisIndexName = "iris_data_for_prediction_it"; + loadIrisData(irisIndexName); + + // TODO: open these lines when this bug fix merged https://github.com/oracle/tribuo/issues/223 + // modelId = trainKmeansWithIrisData(irisIndexName, false); + // MLModel kMeansModel = getModel(kMeansModelId); + // assertNotNull(kMeansModel); + + batchRcfModelId = trainBatchRCFWithDataFrame(500, false); + fitRcfModelId = trainFitRCFWithDataFrame(500, false); + linearRegressionModelId = trainLinearRegressionWithDataFrame(100, false); + MLModel batchRcfModel = getModel(batchRcfModelId); + assertNotNull(batchRcfModel); } - public void testTestingData() throws ExecutionException, InterruptedException { - verifyGeneratedTestingData(TESTING_DATA); - waitModelAvailable(taskId); + @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") + public void testPredictionWithSearchInput_KMeans() { + MLInputDataset inputDataset = new SearchQueryInputDataset(ImmutableList.of(irisIndexName), irisDataQuery()); + predictAndVerify(kMeansModelId, inputDataset, FunctionName.KMEANS, null, IRIS_DATA_SIZE); } - public void testPredictionWithSearchInput() throws IOException { - SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder(); - MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder); - - predictAndVerifyResult(taskId, inputDataset); + @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") + public void testPredictionWithDataInput_KMeans() { + MLInputDataset inputDataset = new DataFrameInputDataset(irisDataFrame()); + predictAndVerify(kMeansModelId, inputDataset, FunctionName.KMEANS, null, IRIS_DATA_SIZE); } - public void testPredictionWithDataInput() throws IOException { - predictAndVerifyResult(taskId, DATA_FRAME_INPUT_DATASET); + @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") + public void testPredictionWithoutDataset_KMeans() { + exceptionRule.expect(ActionRequestValidationException.class); + exceptionRule.expectMessage("input data can't be null"); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build(); + MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(kMeansModelId, mlInput); + ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); + predictionFuture.actionGet(); } - public void testPredictionWithoutAlgorithm() throws IOException { - MLInput mlInput = MLInput.builder().inputDataset(DATA_FRAME_INPUT_DATASET).build(); - MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput); + @LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/oracle/tribuo/issues/223") + public void testPredictionWithEmptyDataset_KMeans() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("No document found"); + MLInputDataset emptySearchInputDataset = emptyQueryInputDataSet(irisIndexName); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(emptySearchInputDataset).build(); + MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(kMeansModelId, mlInput); ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); - expectThrows(ActionRequestValidationException.class, () -> predictionFuture.actionGet()); + predictionFuture.actionGet(); } - public void testPredictionWithoutModelId() throws IOException { - MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(DATA_FRAME_INPUT_DATASET).build(); - MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest("", mlInput); - ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); - expectThrows(ResourceNotFoundException.class, () -> predictionFuture.actionGet()); + public void testPredictionWithDataFrame_BatchRCF() { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(batchRcfDataSize)); + predictAndVerify(batchRcfModelId, inputDataset, FunctionName.BATCH_RCF, null, batchRcfDataSize); } - public void testPredictionWithoutDataset() throws IOException { - MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build(); - MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput); - ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); - expectThrows(ActionRequestValidationException.class, () -> predictionFuture.actionGet()); + public void testPredictionWithDataFrame_FitRCF() { + MLInputDataset inputDataset = new DataFrameInputDataset(TestData.constructTestDataFrame(batchRcfDataSize, true)); + predictAndVerify( + fitRcfModelId, + inputDataset, + FunctionName.FIT_RCF, + FitRCFParams.builder().timeField(TIME_FIELD).build(), + batchRcfDataSize + ); } - public void testPredictionWithEmptyDataset() throws IOException { - MLInputDataset emptySearchInputDataset = generateEmptyDataset(); - MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(emptySearchInputDataset).build(); - MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput); - ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); - expectThrows(IllegalArgumentException.class, () -> predictionFuture.actionGet()); + public void testPredictionWithDataFrame_LinearRegression() { + int size = 1; + int feet = 20; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("feet", ColumnType.DOUBLE) }; + List rows = new ArrayList<>(); + rows.add(new Row(new ColumnValue[] { new DoubleValue(feet) })); + DataFrame inputDataFrame = new DefaultDataFrame(columnMetas, rows); + MLInputDataset inputDataset = new DataFrameInputDataset(inputDataFrame); + DataFrame dataFrame = predictAndVerify( + linearRegressionModelId, + inputDataset, + FunctionName.LINEAR_REGRESSION, + getLinearRegressionParams(), + size + ); + ColumnValue value = dataFrame.getRow(0).getValue(0); + assertNotNull(value); } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestData.java b/plugin/src/test/java/org/opensearch/ml/utils/TestData.java index adc2e51beb..50bf25484b 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestData.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestData.java @@ -5,11 +5,87 @@ package org.opensearch.ml.utils; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +import org.apache.commons.math3.distribution.MultivariateNormalDistribution; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnType; +import org.opensearch.ml.common.dataframe.ColumnValue; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataframe.DefaultDataFrame; +import org.opensearch.ml.common.dataframe.DoubleValue; +import org.opensearch.ml.common.dataframe.Row; + import com.google.gson.JsonArray; import com.google.gson.JsonObject; public class TestData { + public static final String TIME_FIELD = "timestamp"; + public static final String TARGET_FIELD = "price"; + + public static DataFrame constructTestDataFrame(int size) { + return constructTestDataFrame(size, false); + } + + public static DataFrame constructTestDataFrameForLinearRegression(int size) { + ColumnMeta[] columnMetas = new ColumnMeta[] { + new ColumnMeta("feet", ColumnType.DOUBLE), + new ColumnMeta(TARGET_FIELD, ColumnType.DOUBLE) }; + + List rows = new ArrayList<>(); + for (int i = 0; i < size; i++) { + rows.add(new Row(new ColumnValue[] { new DoubleValue(i), new DoubleValue(i) })); + } + return new DefaultDataFrame(columnMetas, rows); + } + + public static DataFrame constructTestDataFrame(int size, boolean addTimeFiled) { + List columnMetaList = new ArrayList<>(); + columnMetaList.add(new ColumnMeta("f1", ColumnType.DOUBLE)); + columnMetaList.add(new ColumnMeta("f2", ColumnType.DOUBLE)); + if (addTimeFiled) { + columnMetaList.add(new ColumnMeta(TIME_FIELD, ColumnType.LONG)); + } + ColumnMeta[] columnMetas = columnMetaList.toArray(new ColumnMeta[0]); + DataFrame dataFrame = DataFrameBuilder.emptyDataFrame(columnMetas); + + Random random = new Random(1); + MultivariateNormalDistribution g1 = new MultivariateNormalDistribution( + new JDKRandomGenerator(random.nextInt()), + new double[] { 0.0, 0.0 }, + new double[][] { { 2.0, 1.0 }, { 1.0, 2.0 } } + ); + MultivariateNormalDistribution g2 = new MultivariateNormalDistribution( + new JDKRandomGenerator(random.nextInt()), + new double[] { 10.0, 10.0 }, + new double[][] { { 2.0, 1.0 }, { 1.0, 2.0 } } + ); + MultivariateNormalDistribution[] normalDistributions = new MultivariateNormalDistribution[] { g1, g2 }; + long startTime = 1648154137000l; + for (int i = 0; i < size; ++i) { + int id = 0; + if (random.nextDouble() < 0.5) { + id = 1; + } + double[] sample = normalDistributions[id].sample(); + Object[] row = Arrays.stream(sample).boxed().toArray(Double[]::new); + if (addTimeFiled) { + long timestamp = startTime + 60_000 * i; + row = new Object[] { row[0], row[1], timestamp }; + } + dataFrame.appendRow(row); + } + + return dataFrame; + } + + public static final int IRIS_DATA_SIZE = 150; public static final String IRIS_DATA = "{ \"index\" : { \"_index\" : \"iris_data\" } }\n" + "{\"sepal_length_in_cm\":5.1,\"sepal_width_in_cm\":3.5,\"petal_length_in_cm\":1.4,\"petal_width_in_cm\":0.2,\"class\":\"Iris-setosa\"}\n" + "{ \"index\" : { \"_index\" : \"iris_data\" } }\n"