diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 968fa07f18..e882fc5b29 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -411,7 +411,7 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) { @Override public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); - List options = node.getOptions(); + java.util.Map options = node.getArguments(); TypeEnvironment currentEnv = context.peek(); currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER); @@ -430,7 +430,7 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) { TypeEnvironment currentEnv = context.peek(); currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE); - if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) { + if (Objects.isNull(node.getArguments().get(TIME_FIELD))) { currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN); } else { currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE); diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java b/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java index 34099ebbbd..5d2e32c28b 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -15,7 +16,7 @@ import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; @Getter @Setter @@ -26,7 +27,7 @@ public class Kmeans extends UnresolvedPlan { private UnresolvedPlan child; - private final List options; + private final Map arguments; @Override public UnresolvedPlan attach(UnresolvedPlan child) { diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java index c4b44317dd..22771b42de 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java @@ -1,11 +1,11 @@ package org.opensearch.sql.planner.logical; import java.util.Collections; -import java.util.List; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; -import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; /** * ml-commons logical plan. @@ -16,7 +16,7 @@ public class LogicalMLCommons extends LogicalPlan { private final String algorithm; - private final List arguments; + private final Map arguments; /** * Constructor of LogicalMLCommons. @@ -25,7 +25,7 @@ public class LogicalMLCommons extends LogicalPlan { * @param arguments arguments of the algorithm */ public LogicalMLCommons(LogicalPlan child, String algorithm, - List arguments) { + Map arguments) { super(Collections.singletonList(child)); this.algorithm = algorithm; this.arguments = arguments; diff --git a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java index 3e957f1bda..ac560e86bc 100644 --- a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java +++ b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java @@ -2,12 +2,26 @@ public class MLCommonsConstants { + // AD constants + public static final String NUMBER_OF_TREES = "number_of_trees"; public static final String SHINGLE_SIZE = "shingle_size"; + public static final String SAMPLE_SIZE = "sample_size"; + public static final String OUTPUT_AFTER = "output_after"; public static final String TIME_DECAY = "time_decay"; + public static final String ANOMALY_RATE = "anomaly_rate"; public static final String TIME_FIELD = "time_field"; + public static final String DATE_FORMAT = "date_format"; + public static final String TIME_ZONE = "time_zone"; + public static final String TRAINING_DATA_SIZE = "training_data_size"; + public static final String ANOMALY_SCORE_THRESHOLD = "anomaly_score_threshold"; public static final String RCF_SCORE = "score"; public static final String RCF_ANOMALOUS = "anomalous"; public static final String RCF_ANOMALY_GRADE = "anomaly_grade"; public static final String RCF_TIMESTAMP = "timestamp"; + + // KMEANS constants + public static final String CENTROIDS = "centroids"; + public static final String ITERATIONS = "iterations"; + public static final String DISTANCE_TYPE = "distance_type"; } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index fde22f2485..114c71aaa5 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -701,12 +701,15 @@ public void parse_relation() { @Test public void kmeanns_relation() { + Map argumentMap = new HashMap() {{ + put("centroids", new Literal(3, DataType.INTEGER)); + put("iterations", new Literal(2, DataType.INTEGER)); + put("distance_type", new Literal("COSINE", DataType.STRING)); + }}; assertAnalyzeEqual( new LogicalMLCommons(LogicalPlanDSL.relation("schema"), - "kmeans", - AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))), - new Kmeans(AstDSL.relation("schema"), - AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))) + "kmeans", argumentMap), + new Kmeans(AstDSL.relation("schema"), argumentMap) ); } @@ -715,8 +718,6 @@ public void ad_batchRCF_relation() { Map argumentMap = new HashMap() {{ put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal(null, DataType.STRING)); }}; assertAnalyzeEqual( new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 1b8d606211..a2455a9a3d 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -115,7 +115,11 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"), "kmeans", - AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))); + ImmutableMap.builder() + .put("centroids", new Literal(3, DataType.INTEGER)) + .put("iterations", new Literal(3, DataType.DOUBLE)) + .put("distance_type", new Literal(null, DataType.STRING)) + .build()); assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/docs/user/ppl/cmd/ad.rst b/docs/user/ppl/cmd/ad.rst index 5c1b7fa618..ed30a2016d 100644 --- a/docs/user/ppl/cmd/ad.rst +++ b/docs/user/ppl/cmd/ad.rst @@ -16,20 +16,28 @@ Description Fixed In Time RCF For Time-series Data Command Syntax ===================================================== -ad +ad -* shingle_size: optional. A shingle is a consecutive sequence of the most recent records. The default value is 8. -* time_decay: optional. It specifies how much of the recent past to consider when computing an anomaly score. The default value is 0.001. -* time_field: mandatory. It specifies the time filed for RCF to use as time-series data. +* number_of_trees(integer): optional. Number of trees in the forest. The default value is 30. +* shingle_size(integer): optional. A shingle is a consecutive sequence of the most recent records. The default value is 8. +* sample_size(integer): optional. The sample size used by stream samplers in this forest. The default value is 256. +* output_after(integer): optional. The number of points required by stream samplers before results are returned. The default value is 32. +* time_decay(double): optional. The decay factor used by stream samplers in this forest. The default value is 0.0001. +* anomaly_rate(double): optional. The anomaly rate. The default value is 0.005. +* time_field(string): mandatory. It specifies the time filed for RCF to use as time-series data. +* date_format(string): optional. It's used for formatting time_field field. The default formatting is "yyyy-MM-dd HH:mm:ss". +* time_zone(string): optional. It's used for setting time zone for time_field filed. The default time zone is UTC. Batch RCF for Non-time-series Data Command Syntax ================================================= -ad - -* shingle_size: optional. A shingle is a consecutive sequence of the most recent records. The default value is 8. -* time_decay: optional. It specifies how much of the recent past to consider when computing an anomaly score. The default value is 0.001. +ad +* number_of_trees(integer): optional. Number of trees in the forest. The default value is 30. +* sample_size(integer): optional. Number of random samples given to each tree from the training data set. The default value is 256. +* output_after(integer): optional. The number of points required by stream samplers before results are returned. The default value is 32. +* training_data_size(integer): optional. The default value is the size of your training data set. +* anomaly_score_threshold(double): optional. The threshold of anomaly score. The default value is 1.0. Example1: Detecting events in New York City from taxi ridership data with time-series data ========================================================================================== diff --git a/docs/user/ppl/cmd/kmeans.rst b/docs/user/ppl/cmd/kmeans.rst index a70c000b71..4608473c2c 100644 --- a/docs/user/ppl/cmd/kmeans.rst +++ b/docs/user/ppl/cmd/kmeans.rst @@ -16,9 +16,11 @@ Description Syntax ====== -kmeans +kmeans -* cluster-number: mandatory. The number of clusters you want to group your data points into. +* centroids: optional. The number of clusters you want to group your data points into. The default value is 2. +* iterations: optional. Number of iterations. The default value is 10. +* distance_type: optional. The distance type can be COSINE, L1, or EUCLIDEAN, The default type is EUCLIDEAN. Example: Clustering of Iris Dataset @@ -28,7 +30,7 @@ The example shows how to classify three Iris species (Iris setosa, Iris virginic PPL query:: - os> source=iris_data | fields sepal_length_in_cm, sepal_width_in_cm, petal_length_in_cm, petal_width_in_cm | kmeans 3 + os> source=iris_data | fields sepal_length_in_cm, sepal_width_in_cm, petal_length_in_cm, petal_width_in_cm | kmeans centroids=3 +--------------------+-------------------+--------------------+-------------------+-----------+ | sepal_length_in_cm | sepal_width_in_cm | petal_length_in_cm | petal_width_in_cm | ClusterID | |--------------------+-------------------+--------------------+-------------------+-----------| diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java index acf3bbdc22..9f2a2c99ef 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -6,9 +6,17 @@ package org.opensearch.sql.opensearch.planner.physical; +import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE; +import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD; +import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT; +import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES; +import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER; +import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE; import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE; import java.util.Collections; import java.util.Iterator; @@ -97,18 +105,55 @@ public List getChild() { } protected MLAlgoParams convertArgumentToMLParameter(Map arguments) { - if (arguments.get(TIME_FIELD).getValue() == null) { + if (arguments.get(TIME_FIELD) == null) { rcfType = FunctionName.BATCH_RCF; return BatchRCFParams.builder() - .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) + .numberOfTrees(arguments.containsKey(NUMBER_OF_TREES) + ? ((Integer) arguments.get(NUMBER_OF_TREES).getValue()) + : null) + .sampleSize(arguments.containsKey(SAMPLE_SIZE) + ? ((Integer) arguments.get(SAMPLE_SIZE).getValue()) + : null) + .outputAfter(arguments.containsKey(OUTPUT_AFTER) + ? ((Integer) arguments.get(OUTPUT_AFTER).getValue()) + : null) + .trainingDataSize(arguments.containsKey(TRAINING_DATA_SIZE) + ? ((Integer) arguments.get(TRAINING_DATA_SIZE).getValue()) + : null) + .anomalyScoreThreshold(arguments.containsKey(ANOMALY_SCORE_THRESHOLD) + ? ((Double) arguments.get(ANOMALY_SCORE_THRESHOLD).getValue()) + : null) .build(); } rcfType = FunctionName.FIT_RCF; return FitRCFParams.builder() - .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) - .timeDecay((Double) arguments.get(TIME_DECAY).getValue()) - .timeField((String) arguments.get(TIME_FIELD).getValue()) - .dateFormat("yyyy-MM-dd HH:mm:ss") + .numberOfTrees(arguments.containsKey(NUMBER_OF_TREES) + ? ((Integer) arguments.get(NUMBER_OF_TREES).getValue()) + : null) + .shingleSize(arguments.containsKey(SHINGLE_SIZE) + ? ((Integer) arguments.get(SHINGLE_SIZE).getValue()) + : null) + .sampleSize(arguments.containsKey(SAMPLE_SIZE) + ? ((Integer) arguments.get(SAMPLE_SIZE).getValue()) + : null) + .outputAfter(arguments.containsKey(OUTPUT_AFTER) + ? ((Integer) arguments.get(OUTPUT_AFTER).getValue()) + : null) + .timeDecay(arguments.containsKey(TIME_DECAY) + ? ((Double) arguments.get(TIME_DECAY).getValue()) + : null) + .anomalyRate(arguments.containsKey(ANOMALY_RATE) + ? ((Double) arguments.get(ANOMALY_RATE).getValue()) + : null) + .timeField(arguments.containsKey(TIME_FIELD) + ? ((String) arguments.get(TIME_FIELD).getValue()) + : null) + .dateFormat(arguments.containsKey(DATE_FORMAT) + ? ((String) arguments.get(DATE_FORMAT).getValue()) + : "yyyy-MM-dd HH:mm:ss") + .timeZone(arguments.containsKey(TIME_ZONE) + ? ((String) arguments.get(TIME_ZONE).getValue()) + : null) .build(); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java index 75870b5ee1..863df58011 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java @@ -6,10 +6,14 @@ package org.opensearch.sql.opensearch.planner.physical; import static org.opensearch.ml.common.parameter.FunctionName.KMEANS; +import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS; +import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE; +import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS; import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -21,6 +25,7 @@ import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -39,7 +44,7 @@ public class MLCommonsOperator extends MLCommonsOperatorActions { private final String algorithm; @Getter - private final List arguments; + private final Map arguments; @Getter private final NodeClient nodeClient; @@ -51,7 +56,7 @@ public class MLCommonsOperator extends MLCommonsOperatorActions { public void open() { super.open(); DataFrame inputDataFrame = generateInputDataset(input); - MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments.get(0), algorithm); + MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments, algorithm); MLPredictionOutput predictionResult = getMLPredictionResult(FunctionName.valueOf(algorithm.toUpperCase()), mlAlgoParams, inputDataFrame, nodeClient); @@ -91,15 +96,24 @@ public List getChild() { return Collections.singletonList(input); } - protected MLAlgoParams convertArgumentToMLParameter(Argument argument, String algorithm) { + protected MLAlgoParams convertArgumentToMLParameter(Map arguments, + String algorithm) { switch (FunctionName.valueOf(algorithm.toUpperCase())) { case KMEANS: - if (argument.getValue().getValue() instanceof Number) { - return KMeansParams.builder().centroids((Integer) argument.getValue().getValue()).build(); - } else { - throw new IllegalArgumentException("unsupported Kmeans argument type:" - + argument.getValue().getType()); - } + return KMeansParams.builder() + .centroids(arguments.containsKey(CENTROIDS) + ? ((Integer) arguments.get(CENTROIDS).getValue()) + : null) + .iterations(arguments.containsKey(ITERATIONS) + ? ((Integer) arguments.get(ITERATIONS).getValue()) + : null) + .distanceType(arguments.containsKey(DISTANCE_TYPE) + ? (arguments.get(DISTANCE_TYPE).getValue() != null + ? KMeansParams.DistanceType.valueOf(( + (String) arguments.get(DISTANCE_TYPE).getValue()).toUpperCase()) + : null) + : null) + .build(); default: // TODO: update available algorithms in the message when adding a new case throw new IllegalArgumentException( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index 2427ac4fe5..5bffa1cfa8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -262,11 +262,13 @@ public void testVisitMlCommons() { NodeClient nodeClient = mock(NodeClient.class); MLCommonsOperator mlCommonsOperator = new MLCommonsOperator( - values(emptyList()), - "kmeans", - AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3))), - nodeClient - ); + values(emptyList()), "kmeans", + new HashMap() {{ + put("centroids", new Literal(3, DataType.INTEGER)); + put("iterations", new Literal(2, DataType.INTEGER)); + put("distance_type", new Literal(null, DataType.STRING)); + } + }, nodeClient); assertEquals(executionProtector.doProtect(mlCommonsOperator), executionProtector.visitMLCommons(mlCommonsOperator, null)); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java index 260f52770f..0d14e67f74 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java @@ -17,6 +17,8 @@ import com.google.common.collect.ImmutableMap; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -39,6 +41,7 @@ import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -63,14 +66,17 @@ public class MLCommonsOperatorTest { @BeforeEach void setUp() { - mlCommonsOperator = new MLCommonsOperator(input, "kmeans", - AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3)), - AstDSL.argument("k2", AstDSL.stringLiteral("v1")), - AstDSL.argument("k3", AstDSL.booleanLiteral(true)), - AstDSL.argument("k4", AstDSL.doubleLiteral(2.0D)), - AstDSL.argument("k5", AstDSL.shortLiteral((short)2)), - AstDSL.argument("k6", AstDSL.longLiteral(2L)), - AstDSL.argument("k7", AstDSL.floatLiteral(2F))), + Map arguments = new HashMap<>(); + arguments.put("k1",AstDSL.intLiteral(3)); + arguments.put("k2",AstDSL.stringLiteral("v1")); + arguments.put("k3",AstDSL.booleanLiteral(true)); + arguments.put("k4",AstDSL.doubleLiteral(2.0D)); + arguments.put("k5",AstDSL.shortLiteral((short)2)); + arguments.put("k6",AstDSL.longLiteral(2L)); + arguments.put("k7",AstDSL.floatLiteral(2F)); + + + mlCommonsOperator = new MLCommonsOperator(input, "kmeans", arguments, nodeClient); when(input.hasNext()).thenReturn(true).thenReturn(false); ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); @@ -122,16 +128,10 @@ public void testAccept() { @Test public void testConvertArgumentToMLParameter_UnsupportedType() { - Argument argument = AstDSL.argument("k2", AstDSL.dateLiteral("2020-10-31")); + Map argument = new HashMap<>(); + argument.put("k2",AstDSL.dateLiteral("2020-10-31")); assertThrows(IllegalArgumentException.class, () -> mlCommonsOperator .convertArgumentToMLParameter(argument, "LINEAR_REGRESSION")); } - @Test - public void testConvertArgumentToMLParameter_KMeansUnsupportedType() { - Argument argument = AstDSL.argument("k2", AstDSL.dateLiteral("string value")); - assertThrows(IllegalArgumentException.class, () -> mlCommonsOperator - .convertArgumentToMLParameter(argument, "KMEANS")); - } - } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 189a329de6..aee51d0a10 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -50,9 +50,19 @@ DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; PARTITIONS: 'PARTITIONS'; ALLNUM: 'ALLNUM'; DELIM: 'DELIM'; +CENTROIDS: 'CENTROIDS'; +ITERATIONS: 'ITERATIONS'; +DISTANCE_TYPE: 'DISTANCE_TYPE'; +NUMBER_OF_TREES: 'NUMBER_OF_TREES'; SHINGLE_SIZE: 'SHINGLE_SIZE'; +SAMPLE_SIZE: 'SAMPLE_SIZE'; +OUTPUT_AFTER: 'OUTPUT_AFTER'; TIME_DECAY: 'TIME_DECAY'; +ANOMALY_RATE: 'ANOMALY_RATE'; TIME_FIELD: 'TIME_FIELD'; +TIME_ZONE: 'TIME_ZONE'; +TRAINING_DATA_SIZE: 'TRAINING_DATA_SIZE'; +ANOMALY_SCORE_THRESHOLD: 'ANOMALY_SCORE_THRESHOLD'; // COMPARISON FUNCTION KEYWORDS CASE: 'CASE'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index d6cd1e99b8..da37f8e22b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -89,15 +89,31 @@ parseCommand ; kmeansCommand - : KMEANS - k=integerLiteral + : KMEANS (kmeansParameter)* + ; + +kmeansParameter + : (CENTROIDS EQUAL centroids=integerLiteral) + | (ITERATIONS EQUAL iterations=integerLiteral) + | (DISTANCE_TYPE EQUAL distance_type=stringLiteral) ; adCommand - : AD - (SHINGLE_SIZE EQUAL shingle_size=integerLiteral)? - (TIME_DECAY EQUAL time_decay=decimalLiteral)? - (TIME_FIELD EQUAL time_field=stringLiteral)? + : AD (adParameter)* + ; + +adParameter + : (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) + | (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) + | (SAMPLE_SIZE EQUAL sample_size=integerLiteral) + | (OUTPUT_AFTER EQUAL output_after=integerLiteral) + | (TIME_DECAY EQUAL time_decay=decimalLiteral) + | (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) + | (TIME_FIELD EQUAL time_field=stringLiteral) + | (DATE_FORMAT EQUAL date_format=stringLiteral) + | (TIME_ZONE EQUAL time_zone=stringLiteral) + | (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) + | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) ; /** clauses */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 88b61fbcb8..2b25004f15 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -23,6 +23,7 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WhereCommandContext; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -307,14 +308,33 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla return aggregate; } + /** + * Kmeans command. + */ @Override public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { - return new Kmeans(ArgumentFactory.getArgumentList(ctx)); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.kmeansParameter() + .forEach(x -> { + builder.put(x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + return new Kmeans(builder.build()); } + /** + * AD command. + */ @Override public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { - return new AD(ArgumentFactory.getArgumentMap(ctx)); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.adParameter() + .forEach(x -> { + builder.put(x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + + return new AD(builder.build()); } /** diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 09cef7c911..09afd2075f 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -15,22 +15,15 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TopCommandContext; -import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.antlr.v4.runtime.ParserRuleContext; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; /** * Util class to get all arguments as a list from the PPL command. @@ -143,46 +136,16 @@ public static List getArgumentList(RareCommandContext ctx) { } /** - * Get list of {@link Argument}. - * - * @param ctx KmeansCommandContext instance - * @return the list of arguments fetched from the kmeans command - */ - public static List getArgumentList(KmeansCommandContext ctx) { - // TODO: add iterations and distanceType parameters for Kemans - return Collections - .singletonList(new Argument("k", getArgumentValue(ctx.k))); - } - - /** - * Get map of {@link Argument}. - * - * @param ctx ADCommandContext instance - * @return the list of arguments fetched from the AD command + * parse argument value into Literal. + * @param ctx ParserRuleContext instance + * @return Literal */ - public static Map getArgumentMap(AdCommandContext ctx) { - return new HashMap() {{ - put(SHINGLE_SIZE, (ctx.shingle_size != null) - ? getArgumentValue(ctx.shingle_size) - : new Literal(null, DataType.INTEGER)); - put(TIME_DECAY, (ctx.time_decay != null) - ? getArgumentValue(ctx.time_decay) - : new Literal(null, DataType.DOUBLE)); - put(TIME_FIELD, (ctx.time_field != null) - ? getArgumentValue(ctx.time_field) - : new Literal(null, DataType.STRING)); - } - }; - } - private static Literal getArgumentValue(ParserRuleContext ctx) { return ctx instanceof IntegerLiteralContext - ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) - : ctx instanceof BooleanLiteralContext - ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) - : ctx instanceof DecimalLiteralContext - ? new Literal(Double.valueOf(ctx.getText()), DataType.DOUBLE) - : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); + ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) + : ctx instanceof BooleanLiteralContext + ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) + : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 5f729e5d06..5ee0e2be6b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -39,7 +39,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.span; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; -import java.util.HashMap; +import com.google.common.collect.ImmutableMap; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -582,30 +582,68 @@ public void testParseCommand() { @Test public void testKmeansCommand() { - assertEqual("source=t | kmeans 3", - new Kmeans(relation("t"),exprList(argument("k", intLiteral(3))))); - } - - @Test - public void test_fitRCFADCommand() { - assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp'", - new AD(relation("t"),new HashMap() {{ - put("shingle_size", new Literal(10, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal("timestamp", DataType.STRING)); - } - })); + assertEqual("source=t | kmeans centroids=3 iterations=2 distance_type='l1'", + new Kmeans(relation("t"), ImmutableMap.builder() + .put("centroids", new Literal(3, DataType.INTEGER)) + .put("iterations", new Literal(2, DataType.INTEGER)) + .put("distance_type", new Literal("l1", DataType.STRING)) + .build() + )); + } + + @Test + public void testKmeansCommandWithoutParameter() { + assertEqual("source=t | kmeans", + new Kmeans(relation("t"), ImmutableMap.of())); + } + + @Test + public void test_fitRCFADCommand_withoutDataFormat() { + assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' " + + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " + + "number_of_trees=256 time_zone='PST' output_after=256 " + + "training_data_size=256", + new AD(relation("t"), ImmutableMap.builder() + .put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)) + .put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)) + .put("sample_size", new Literal(256, DataType.INTEGER)) + .put("number_of_trees", new Literal(256, DataType.INTEGER)) + .put("time_zone", new Literal("PST", DataType.STRING)) + .put("output_after", new Literal(256, DataType.INTEGER)) + .put("shingle_size", new Literal(10, DataType.INTEGER)) + .put("time_decay", new Literal(0.0001, DataType.DOUBLE)) + .put("time_field", new Literal("timestamp", DataType.STRING)) + .put("training_data_size", new Literal(256, DataType.INTEGER)) + .build() + )); + } + + @Test + public void test_fitRCFADCommand_withDataFormat() { + assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' " + + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " + + "number_of_trees=256 time_zone='PST' output_after=256 " + + "training_data_size=256 date_format='HH:mm:ss yyyy-MM-dd'", + new AD(relation("t"), ImmutableMap.builder() + .put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)) + .put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)) + .put("sample_size", new Literal(256, DataType.INTEGER)) + .put("number_of_trees", new Literal(256, DataType.INTEGER)) + .put("date_format", new Literal("HH:mm:ss yyyy-MM-dd", DataType.STRING)) + .put("time_zone", new Literal("PST", DataType.STRING)) + .put("output_after", new Literal(256, DataType.INTEGER)) + .put("shingle_size", new Literal(10, DataType.INTEGER)) + .put("time_decay", new Literal(0.0001, DataType.DOUBLE)) + .put("time_field", new Literal("timestamp", DataType.STRING)) + .put("training_data_size", new Literal(256, DataType.INTEGER)) + .build() + )); } @Test public void test_batchRCFADCommand() { assertEqual("source=t | AD", - new AD(relation("t"),new HashMap() {{ - put("shingle_size", new Literal(null, DataType.INTEGER)); - put("time_decay", new Literal(null, DataType.DOUBLE)); - put("time_field", new Literal(null, DataType.STRING)); - } - })); + new AD(relation("t"),ImmutableMap.of())); } protected void assertEqual(String query, Node expectedPlan) {