From 79b69392d5b715c95ec99f7cc11baa9cc5526269 Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Thu, 24 Mar 2022 10:08:43 -0700 Subject: [PATCH 1/5] Support more and orderless parameters for AD command Signed-off-by: jackieyanghan --- .../sql/utils/MLCommonsConstants.java | 8 ++ .../planner/physical/ADOperator.java | 20 +++- ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 7 ++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 19 +++- .../sql/ppl/utils/ArgumentFactory.java | 99 ++++++++++++++++--- .../sql/ppl/parser/AstBuilderTest.java | 45 ++++++++- 6 files changed, 180 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java index 3e957f1bda..1325884c87 100644 --- a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java +++ b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java @@ -2,9 +2,17 @@ public class MLCommonsConstants { + public static final String NUMBER_OF_TREES = "number_of_trees"; public static final String SHINGLE_SIZE = "shingle_size"; + public static final String SAMPLE_SIZE = "sample_size"; + public static final String OUTPUT_AFTER = "output_after"; public static final String TIME_DECAY = "time_decay"; + public static final String ANOMALY_RATE = "anomaly_rate"; public static final String TIME_FIELD = "time_field"; + public static final String DATE_FORMAT = "date_format"; + public static final String TIME_ZONE = "time_zone"; + public static final String TRAINING_DATA_SIZE = "training_data_size"; + public static final String ANOMALY_SCORE_THRESHOLD = "anomaly_score_threshold"; public static final String RCF_SCORE = "score"; public static final String RCF_ANOMALOUS = "anomalous"; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java index acf3bbdc22..6dd03e6253 100644 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -6,9 +6,17 @@ package org.opensearch.sql.opensearch.planner.physical; +import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE; +import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD; +import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT; +import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES; +import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER; +import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE; import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE; import java.util.Collections; import java.util.Iterator; @@ -100,15 +108,25 @@ protected MLAlgoParams convertArgumentToMLParameter(Map argumen if (arguments.get(TIME_FIELD).getValue() == null) { rcfType = FunctionName.BATCH_RCF; return BatchRCFParams.builder() + .numberOfTrees((Integer) arguments.get(NUMBER_OF_TREES).getValue()) .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) + .sampleSize((Integer) arguments.get(SAMPLE_SIZE).getValue()) + .outputAfter((Integer) arguments.get(OUTPUT_AFTER).getValue()) + .trainingDataSize((Integer) arguments.get(TRAINING_DATA_SIZE).getValue()) + .anomalyScoreThreshold((Double) arguments.get(ANOMALY_SCORE_THRESHOLD).getValue()) .build(); } rcfType = FunctionName.FIT_RCF; return FitRCFParams.builder() + .numberOfTrees((Integer) arguments.get(NUMBER_OF_TREES).getValue()) .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) + .sampleSize((Integer) 
arguments.get(SAMPLE_SIZE).getValue()) + .outputAfter((Integer) arguments.get(OUTPUT_AFTER).getValue()) .timeDecay((Double) arguments.get(TIME_DECAY).getValue()) + .anomalyRate((Double) arguments.get(ANOMALY_RATE).getValue()) .timeField((String) arguments.get(TIME_FIELD).getValue()) - .dateFormat("yyyy-MM-dd HH:mm:ss") + .dateFormat((String) arguments.get(DATE_FORMAT).getValue()) + .timeZone((String) arguments.get(TIME_ZONE).getValue()) .build(); } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 189a329de6..518af6323a 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -50,9 +50,16 @@ DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; PARTITIONS: 'PARTITIONS'; ALLNUM: 'ALLNUM'; DELIM: 'DELIM'; +NUMBER_OF_TREES: 'NUMBER_OF_TREES'; SHINGLE_SIZE: 'SHINGLE_SIZE'; +SAMPLE_SIZE: 'SAMPLE_SIZE'; +OUTPUT_AFTER: 'OUTPUT_AFTER'; TIME_DECAY: 'TIME_DECAY'; +ANOMALY_RATE: 'ANOMALY_RATE'; TIME_FIELD: 'TIME_FIELD'; +TIME_ZONE: 'TIME_ZONE'; +TRAINING_DATA_SIZE: 'TRAINING_DATA_SIZE'; +ANOMALY_SCORE_THRESHOLD: 'ANOMALY_SCORE_THRESHOLD'; // COMPARISON FUNCTION KEYWORDS CASE: 'CASE'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index d6cd1e99b8..a3e552ca03 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -94,10 +94,21 @@ kmeansCommand ; adCommand - : AD - (SHINGLE_SIZE EQUAL shingle_size=integerLiteral)? - (TIME_DECAY EQUAL time_decay=decimalLiteral)? - (TIME_FIELD EQUAL time_field=stringLiteral)? + : AD (adParameter)? 
(COMMA adParameter)* + ; + +adParameter + : (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) + | (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) + | (SAMPLE_SIZE EQUAL sample_size=integerLiteral) + | (OUTPUT_AFTER EQUAL output_after=integerLiteral) + | (TIME_DECAY EQUAL time_decay=decimalLiteral) + | (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) + | (TIME_FIELD EQUAL time_field=stringLiteral) + | (DATE_FORMAT EQUAL date_format=stringLiteral) + | (TIME_ZONE EQUAL time_zone=stringLiteral) + | (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) + | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) ; /** clauses */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 09cef7c911..7c50973401 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -14,15 +14,25 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RareCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsCommandContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TopCommandContext; +import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE; +import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD; +import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT; +import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES; +import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER; +import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE; import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; import 
static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import org.antlr.v4.runtime.ParserRuleContext; import org.opensearch.sql.ast.expression.Argument; @@ -30,6 +40,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdParameterContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; /** @@ -161,18 +172,84 @@ public static List getArgumentList(KmeansCommandContext ctx) { * @return the list of arguments fetched from the AD command */ public static Map getArgumentMap(AdCommandContext ctx) { - return new HashMap() {{ - put(SHINGLE_SIZE, (ctx.shingle_size != null) - ? getArgumentValue(ctx.shingle_size) - : new Literal(null, DataType.INTEGER)); - put(TIME_DECAY, (ctx.time_decay != null) - ? getArgumentValue(ctx.time_decay) - : new Literal(null, DataType.DOUBLE)); - put(TIME_FIELD, (ctx.time_field != null) - ? 
getArgumentValue(ctx.time_field) - : new Literal(null, DataType.STRING)); + List adParameters = ctx.adParameter(); + Literal numberOfTrees = null; + Literal shingleSize = null; + Literal sampleSize = null; + Literal outputAfter = null; + Literal timeDecay = null; + Literal anomalyRate = null; + Literal timeField = null; + Literal dateFormat = null; + Literal timeZone = null; + Literal trainingDataSize = null; + Literal anomalyScoreThreshold = null; + + for (AdParameterContext p : adParameters) { + if (p.number_of_trees != null) { + numberOfTrees = getArgumentValue(p.number_of_trees); + } + if (p.shingle_size != null) { + shingleSize = getArgumentValue(p.shingle_size); + } + if (p.sample_size != null) { + sampleSize = getArgumentValue(p.sample_size); + } + if (p.output_after != null) { + outputAfter = getArgumentValue(p.output_after); + } + if (p.time_decay != null) { + timeDecay = getArgumentValue(p.time_decay); + } + if (p.anomaly_rate != null) { + anomalyRate = getArgumentValue(p.anomaly_rate); + } + if (p.time_field != null) { + timeField = getArgumentValue(p.time_field); } - }; + if (p.date_format != null) { + dateFormat = getArgumentValue(p.date_format); + } + if (p.time_zone != null) { + timeZone = getArgumentValue(p.time_zone); + } + if (p.training_data_size != null) { + trainingDataSize = getArgumentValue(p.training_data_size); + } + if (p.anomaly_score_threshold != null) { + anomalyScoreThreshold = getArgumentValue(p.anomaly_score_threshold); + } + } + + if (timeField != null && dateFormat == null) { + dateFormat = new Literal("yyyy-MM-dd HH:mm:ss", DataType.STRING); + } + + HashMap params = new HashMap<>(); + params.put(NUMBER_OF_TREES, numberOfTrees != null + ? numberOfTrees : new Literal(null, DataType.INTEGER)); + params.put(SHINGLE_SIZE, shingleSize != null + ? shingleSize : new Literal(null, DataType.INTEGER)); + params.put(SAMPLE_SIZE, sampleSize != null + ? 
sampleSize : new Literal(null, DataType.INTEGER)); + params.put(OUTPUT_AFTER, outputAfter != null + ? outputAfter : new Literal(null, DataType.INTEGER)); + params.put(TIME_DECAY, timeDecay != null + ? timeDecay : new Literal(null, DataType.DOUBLE)); + params.put(ANOMALY_RATE, anomalyRate != null + ? anomalyRate : new Literal(null, DataType.DOUBLE)); + params.put(TIME_FIELD, timeField != null + ? timeField : new Literal(null, DataType.STRING)); + params.put(DATE_FORMAT, dateFormat != null + ? dateFormat : new Literal(null, DataType.STRING)); + params.put(TIME_ZONE, timeZone != null + ? timeZone : new Literal(null, DataType.STRING)); + params.put(TRAINING_DATA_SIZE, trainingDataSize != null + ? trainingDataSize : new Literal(null, DataType.INTEGER)); + params.put(ANOMALY_SCORE_THRESHOLD, anomalyScoreThreshold != null + ? anomalyScoreThreshold : new Literal(null, DataType.DOUBLE)); + + return params; } private static Literal getArgumentValue(ParserRuleContext ctx) { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 5f729e5d06..89533db287 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -587,12 +587,45 @@ public void testKmeansCommand() { } @Test - public void test_fitRCFADCommand() { - assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp'", + public void test_fitRCFADCommand_withoutDataFormat() { + assertEqual("source=t | AD shingle_size=10, time_decay=0.0001, time_field='timestamp', " + + "anomaly_rate=0.1, anomaly_score_threshold=0.1, sample_size=256, " + + "number_of_trees=256, time_zone='PST', output_after=256, " + + "training_data_size=256", new AD(relation("t"),new HashMap() {{ + put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)); + put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)); + 
put("sample_size", new Literal(256, DataType.INTEGER)); + put("number_of_trees", new Literal(256, DataType.INTEGER)); + put("date_format", new Literal("yyyy-MM-dd HH:mm:ss", DataType.STRING)); + put("time_zone", new Literal("PST", DataType.STRING)); + put("output_after", new Literal(256, DataType.INTEGER)); put("shingle_size", new Literal(10, DataType.INTEGER)); put("time_decay", new Literal(0.0001, DataType.DOUBLE)); put("time_field", new Literal("timestamp", DataType.STRING)); + put("training_data_size", new Literal(256, DataType.INTEGER)); + } + })); + } + + @Test + public void test_fitRCFADCommand_withDataFormat() { + assertEqual("source=t | AD shingle_size=10, time_decay=0.0001, time_field='timestamp', " + + "anomaly_rate=0.1, anomaly_score_threshold=0.1, sample_size=256, " + + "number_of_trees=256, time_zone='PST', output_after=256, " + + "training_data_size=256, date_format='HH:mm:ss yyyy-MM-dd'", + new AD(relation("t"),new HashMap() {{ + put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)); + put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)); + put("sample_size", new Literal(256, DataType.INTEGER)); + put("number_of_trees", new Literal(256, DataType.INTEGER)); + put("date_format", new Literal("HH:mm:ss yyyy-MM-dd", DataType.STRING)); + put("time_zone", new Literal("PST", DataType.STRING)); + put("output_after", new Literal(256, DataType.INTEGER)); + put("shingle_size", new Literal(10, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("timestamp", DataType.STRING)); + put("training_data_size", new Literal(256, DataType.INTEGER)); } })); } @@ -601,6 +634,14 @@ public void test_fitRCFADCommand() { public void test_batchRCFADCommand() { assertEqual("source=t | AD", new AD(relation("t"),new HashMap() {{ + put("anomaly_rate", new Literal(null, DataType.DOUBLE)); + put("anomaly_score_threshold", new Literal(null, DataType.DOUBLE)); + put("sample_size", new Literal(null, 
DataType.INTEGER)); + put("number_of_trees", new Literal(null, DataType.INTEGER)); + put("date_format", new Literal(null, DataType.STRING)); + put("time_zone", new Literal(null, DataType.STRING)); + put("output_after", new Literal(null, DataType.INTEGER)); + put("training_data_size", new Literal(null, DataType.INTEGER)); put("shingle_size", new Literal(null, DataType.INTEGER)); put("time_decay", new Literal(null, DataType.DOUBLE)); put("time_field", new Literal(null, DataType.STRING)); From 02ec7257477d750784186e667e592de078a97146 Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Fri, 25 Mar 2022 11:14:03 -0700 Subject: [PATCH 2/5] Support more parameters for KMEANS command, and update AD and KMEANS documentation Signed-off-by: jackieyanghan --- .../org/opensearch/sql/analysis/Analyzer.java | 2 +- .../org/opensearch/sql/ast/tree/Kmeans.java | 5 ++- .../sql/planner/logical/LogicalMLCommons.java | 8 ++-- .../sql/utils/MLCommonsConstants.java | 6 +++ .../opensearch/sql/analysis/AnalyzerTest.java | 11 +++-- .../logical/LogicalPlanNodeVisitorTest.java | 7 ++- docs/user/ppl/cmd/ad.rst | 24 +++++++---- docs/user/ppl/cmd/kmeans.rst | 8 ++-- .../planner/physical/ADOperator.java | 1 - .../planner/physical/MLCommonsOperator.java | 26 +++++++---- .../OpenSearchExecutionProtectorTest.java | 12 +++--- .../physical/MLCommonsOperatorTest.java | 32 +++++++------- ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 3 ++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 9 +++- .../opensearch/sql/ppl/parser/AstBuilder.java | 2 +- .../sql/ppl/utils/ArgumentFactory.java | 43 +++++++++++++++---- .../sql/ppl/parser/AstBuilderTest.java | 22 ++++++++-- 17 files changed, 153 insertions(+), 68 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 968fa07f18..270887aa53 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ 
b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -411,7 +411,7 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) { @Override public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); - List options = node.getOptions(); + java.util.Map options = node.getArguments(); TypeEnvironment currentEnv = context.peek(); currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER); diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java b/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java index 34099ebbbd..5d2e32c28b 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -15,7 +16,7 @@ import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; @Getter @Setter @@ -26,7 +27,7 @@ public class Kmeans extends UnresolvedPlan { private UnresolvedPlan child; - private final List options; + private final Map arguments; @Override public UnresolvedPlan attach(UnresolvedPlan child) { diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java index c4b44317dd..22771b42de 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java @@ -1,11 +1,11 @@ package org.opensearch.sql.planner.logical; import java.util.Collections; -import java.util.List; +import java.util.Map; import 
lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; -import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; /** * ml-commons logical plan. @@ -16,7 +16,7 @@ public class LogicalMLCommons extends LogicalPlan { private final String algorithm; - private final List arguments; + private final Map arguments; /** * Constructor of LogicalMLCommons. @@ -25,7 +25,7 @@ public class LogicalMLCommons extends LogicalPlan { * @param arguments arguments of the algorithm */ public LogicalMLCommons(LogicalPlan child, String algorithm, - List arguments) { + Map arguments) { super(Collections.singletonList(child)); this.algorithm = algorithm; this.arguments = arguments; diff --git a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java index 1325884c87..ac560e86bc 100644 --- a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java +++ b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java @@ -2,6 +2,7 @@ public class MLCommonsConstants { + // AD constants public static final String NUMBER_OF_TREES = "number_of_trees"; public static final String SHINGLE_SIZE = "shingle_size"; public static final String SAMPLE_SIZE = "sample_size"; @@ -18,4 +19,9 @@ public class MLCommonsConstants { public static final String RCF_ANOMALOUS = "anomalous"; public static final String RCF_ANOMALY_GRADE = "anomaly_grade"; public static final String RCF_TIMESTAMP = "timestamp"; + + // KMEANS constants + public static final String CENTROIDS = "centroids"; + public static final String ITERATIONS = "iterations"; + public static final String DISTANCE_TYPE = "distance_type"; } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index fde22f2485..ad0455c505 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ 
b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -701,12 +701,15 @@ public void parse_relation() { @Test public void kmeanns_relation() { + Map argumentMap = new HashMap() {{ + put("centroids", new Literal(3, DataType.INTEGER)); + put("iterations", new Literal(2, DataType.INTEGER)); + put("distance_type", new Literal("COSINE", DataType.STRING)); + }}; assertAnalyzeEqual( new LogicalMLCommons(LogicalPlanDSL.relation("schema"), - "kmeans", - AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))), - new Kmeans(AstDSL.relation("schema"), - AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))) + "kmeans", argumentMap), + new Kmeans(AstDSL.relation("schema"), argumentMap) ); } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 1b8d606211..1b42092dac 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -115,7 +115,12 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"), "kmeans", - AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))); + new HashMap() {{ + put("centroids", new Literal(3, DataType.INTEGER)); + put("iterations", new Literal(3, DataType.DOUBLE)); + put("distance_type", new Literal(null, DataType.STRING)); + } + }); assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/docs/user/ppl/cmd/ad.rst b/docs/user/ppl/cmd/ad.rst index 5c1b7fa618..172a744df7 100644 --- a/docs/user/ppl/cmd/ad.rst +++ b/docs/user/ppl/cmd/ad.rst @@ -16,20 +16,28 @@ Description Fixed In Time RCF For Time-series Data Command Syntax ===================================================== -ad +ad , , , , , , , , -* shingle_size: optional. 
A shingle is a consecutive sequence of the most recent records. The default value is 8. -* time_decay: optional. It specifies how much of the recent past to consider when computing an anomaly score. The default value is 0.001. -* time_field: mandatory. It specifies the time filed for RCF to use as time-series data. +* number_of_trees(integer): optional. Number of trees in the forest. The default value is 30. +* shingle_size(integer): optional. A shingle is a consecutive sequence of the most recent records. The default value is 8. +* sample_size(integer): optional. The sample size used by stream samplers in this forest. The default value is 256. +* output_after(integer): optional. The number of points required by stream samplers before results are returned. The default value is 32. +* time_decay(double): optional. The decay factor used by stream samplers in this forest. The default value is 0.0001. +* anomaly_rate(double): optional. The anomaly rate. The default value is 0.005. +* time_field(string): mandatory. It specifies the time field for RCF to use as time-series data. +* date_format(string): optional. It's used for formatting time_field field. The default formatting is "yyyy-MM-dd HH:mm:ss". +* time_zone(string): optional. It's used for setting time zone for time_field field. The default time zone is UTC. Batch RCF for Non-time-series Data Command Syntax ================================================= -ad - -* shingle_size: optional. A shingle is a consecutive sequence of the most recent records. The default value is 8. -* time_decay: optional. It specifies how much of the recent past to consider when computing an anomaly score. The default value is 0.001. +ad , , , , +* number_of_trees(integer): optional. Number of trees in the forest. The default value is 30. +* sample_size(integer): optional. Number of random samples given to each tree from the training data set. The default value is 256. +* output_after(integer): optional. 
The number of points required by stream samplers before results are returned. The default value is 32. +* training_data_size(integer): optional. The default value is the size of your training data set. +* anomaly_score_threshold(double): optional. The threshold of anomaly score. The default value is 1.0. Example1: Detecting events in New York City from taxi ridership data with time-series data ========================================================================================== diff --git a/docs/user/ppl/cmd/kmeans.rst b/docs/user/ppl/cmd/kmeans.rst index a70c000b71..d790713ecf 100644 --- a/docs/user/ppl/cmd/kmeans.rst +++ b/docs/user/ppl/cmd/kmeans.rst @@ -16,9 +16,11 @@ Description Syntax ====== -kmeans +kmeans , -* cluster-number: mandatory. The number of clusters you want to group your data points into. +* centroids: optional. The number of clusters you want to group your data points into. The default value is 2. +* iterations: optional. Number of iterations. The default value is 10. +* distance_type: optional. The distance type can be COSINE, L1, or EUCLIDEAN. The default type is EUCLIDEAN. 
Example: Clustering of Iris Dataset @@ -28,7 +30,7 @@ The example shows how to classify three Iris species (Iris setosa, Iris virginic PPL query:: - os> source=iris_data | fields sepal_length_in_cm, sepal_width_in_cm, petal_length_in_cm, petal_width_in_cm | kmeans 3 + os> source=iris_data | fields sepal_length_in_cm, sepal_width_in_cm, petal_length_in_cm, petal_width_in_cm | kmeans centroids=3 +--------------------+-------------------+--------------------+-------------------+-----------+ | sepal_length_in_cm | sepal_width_in_cm | petal_length_in_cm | petal_width_in_cm | ClusterID | |--------------------+-------------------+--------------------+-------------------+-----------| diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java index 6dd03e6253..337db50260 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -109,7 +109,6 @@ protected MLAlgoParams convertArgumentToMLParameter(Map argumen rcfType = FunctionName.BATCH_RCF; return BatchRCFParams.builder() .numberOfTrees((Integer) arguments.get(NUMBER_OF_TREES).getValue()) - .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) .sampleSize((Integer) arguments.get(SAMPLE_SIZE).getValue()) .outputAfter((Integer) arguments.get(OUTPUT_AFTER).getValue()) .trainingDataSize((Integer) arguments.get(TRAINING_DATA_SIZE).getValue()) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java index 75870b5ee1..45e8c36632 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java +++ 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java @@ -6,10 +6,14 @@ package org.opensearch.sql.opensearch.planner.physical; import static org.opensearch.ml.common.parameter.FunctionName.KMEANS; +import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS; +import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE; +import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS; import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -21,6 +25,7 @@ import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -39,7 +44,7 @@ public class MLCommonsOperator extends MLCommonsOperatorActions { private final String algorithm; @Getter - private final List arguments; + private final Map arguments; @Getter private final NodeClient nodeClient; @@ -51,7 +56,7 @@ public class MLCommonsOperator extends MLCommonsOperatorActions { public void open() { super.open(); DataFrame inputDataFrame = generateInputDataset(input); - MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments.get(0), algorithm); + MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments, algorithm); MLPredictionOutput predictionResult = getMLPredictionResult(FunctionName.valueOf(algorithm.toUpperCase()), mlAlgoParams, inputDataFrame, nodeClient); @@ -91,15 +96,18 @@ public List getChild() { return Collections.singletonList(input); } - protected MLAlgoParams convertArgumentToMLParameter(Argument argument, String algorithm) { + protected MLAlgoParams 
convertArgumentToMLParameter(Map arguments, + String algorithm) { switch (FunctionName.valueOf(algorithm.toUpperCase())) { case KMEANS: - if (argument.getValue().getValue() instanceof Number) { - return KMeansParams.builder().centroids((Integer) argument.getValue().getValue()).build(); - } else { - throw new IllegalArgumentException("unsupported Kmeans argument type:" - + argument.getValue().getType()); - } + return KMeansParams.builder() + .centroids((Integer) arguments.get(CENTROIDS).getValue()) + .iterations((Integer) arguments.get(ITERATIONS).getValue()) + .distanceType(arguments.get(DISTANCE_TYPE).getValue() != null + ? KMeansParams.DistanceType.valueOf(( + (String) arguments.get(DISTANCE_TYPE).getValue()).toUpperCase()) + : null) + .build(); default: // TODO: update available algorithms in the message when adding a new case throw new IllegalArgumentException( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index 2427ac4fe5..5bffa1cfa8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -262,11 +262,13 @@ public void testVisitMlCommons() { NodeClient nodeClient = mock(NodeClient.class); MLCommonsOperator mlCommonsOperator = new MLCommonsOperator( - values(emptyList()), - "kmeans", - AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3))), - nodeClient - ); + values(emptyList()), "kmeans", + new HashMap() {{ + put("centroids", new Literal(3, DataType.INTEGER)); + put("iterations", new Literal(2, DataType.INTEGER)); + put("distance_type", new Literal(null, DataType.STRING)); + } + }, nodeClient); assertEquals(executionProtector.doProtect(mlCommonsOperator), 
executionProtector.visitMLCommons(mlCommonsOperator, null)); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java index 260f52770f..0d14e67f74 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java @@ -17,6 +17,8 @@ import com.google.common.collect.ImmutableMap; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -39,6 +41,7 @@ import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -63,14 +66,17 @@ public class MLCommonsOperatorTest { @BeforeEach void setUp() { - mlCommonsOperator = new MLCommonsOperator(input, "kmeans", - AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3)), - AstDSL.argument("k2", AstDSL.stringLiteral("v1")), - AstDSL.argument("k3", AstDSL.booleanLiteral(true)), - AstDSL.argument("k4", AstDSL.doubleLiteral(2.0D)), - AstDSL.argument("k5", AstDSL.shortLiteral((short)2)), - AstDSL.argument("k6", AstDSL.longLiteral(2L)), - AstDSL.argument("k7", AstDSL.floatLiteral(2F))), + Map arguments = new HashMap<>(); + arguments.put("k1",AstDSL.intLiteral(3)); + arguments.put("k2",AstDSL.stringLiteral("v1")); + arguments.put("k3",AstDSL.booleanLiteral(true)); + arguments.put("k4",AstDSL.doubleLiteral(2.0D)); + arguments.put("k5",AstDSL.shortLiteral((short)2)); + 
arguments.put("k6",AstDSL.longLiteral(2L)); + arguments.put("k7",AstDSL.floatLiteral(2F)); + + + mlCommonsOperator = new MLCommonsOperator(input, "kmeans", arguments, nodeClient); when(input.hasNext()).thenReturn(true).thenReturn(false); ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); @@ -122,16 +128,10 @@ public void testAccept() { @Test public void testConvertArgumentToMLParameter_UnsupportedType() { - Argument argument = AstDSL.argument("k2", AstDSL.dateLiteral("2020-10-31")); + Map argument = new HashMap<>(); + argument.put("k2",AstDSL.dateLiteral("2020-10-31")); assertThrows(IllegalArgumentException.class, () -> mlCommonsOperator .convertArgumentToMLParameter(argument, "LINEAR_REGRESSION")); } - @Test - public void testConvertArgumentToMLParameter_KMeansUnsupportedType() { - Argument argument = AstDSL.argument("k2", AstDSL.dateLiteral("string value")); - assertThrows(IllegalArgumentException.class, () -> mlCommonsOperator - .convertArgumentToMLParameter(argument, "KMEANS")); - } - } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 518af6323a..aee51d0a10 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -50,6 +50,9 @@ DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; PARTITIONS: 'PARTITIONS'; ALLNUM: 'ALLNUM'; DELIM: 'DELIM'; +CENTROIDS: 'CENTROIDS'; +ITERATIONS: 'ITERATIONS'; +DISTANCE_TYPE: 'DISTANCE_TYPE'; NUMBER_OF_TREES: 'NUMBER_OF_TREES'; SHINGLE_SIZE: 'SHINGLE_SIZE'; SAMPLE_SIZE: 'SAMPLE_SIZE'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index a3e552ca03..d6abe1159b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -89,8 +89,13 @@ parseCommand ; kmeansCommand - : KMEANS - k=integerLiteral + : KMEANS (kmeansParameter)? 
(COMMA kmeansParameter)* + ; + +kmeansParameter + : (CENTROIDS EQUAL centroids=integerLiteral) + | (ITERATIONS EQUAL iterations=integerLiteral) + | (DISTANCE_TYPE EQUAL distance_type=stringLiteral) ; adCommand diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 88b61fbcb8..e0469e3db1 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -309,7 +309,7 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla @Override public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { - return new Kmeans(ArgumentFactory.getArgumentList(ctx)); + return new Kmeans(ArgumentFactory.getArgumentMap(ctx)); } @Override diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 7c50973401..ba58f3fcf2 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -18,7 +18,10 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TopCommandContext; import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE; import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD; +import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS; import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT; +import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE; +import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS; import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES; import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER; import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE; @@ -39,9 +42,11 @@ import 
org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdParameterContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansParameterContext; /** * Util class to get all arguments as a list from the PPL command. @@ -154,22 +159,44 @@ public static List getArgumentList(RareCommandContext ctx) { } /** - * Get list of {@link Argument}. + * Get a map of {@link Argument}. * * @param ctx KmeansCommandContext instance - * @return the list of arguments fetched from the kmeans command + * @return a map of arguments fetched from the kmeans command */ - public static List getArgumentList(KmeansCommandContext ctx) { - // TODO: add iterations and distanceType parameters for Kemans - return Collections - .singletonList(new Argument("k", getArgumentValue(ctx.k))); + public static Map getArgumentMap(KmeansCommandContext ctx) { + List kmeansParameters = ctx.kmeansParameter(); + Literal centroids = null; + Literal iterations = null; + Literal distanceType = null; + + for (KmeansParameterContext p : kmeansParameters) { + if (p.centroids != null) { + centroids = getArgumentValue(p.centroids); + } + if (p.iterations != null) { + iterations = getArgumentValue(p.iterations); + } + if (p.distance_type != null) { + distanceType = getArgumentValue(p.distance_type); + } + } + + Map params = new HashMap<>(); + params.put(CENTROIDS, centroids != null + ? centroids : new Literal(null, DataType.INTEGER)); + params.put(ITERATIONS, iterations != null + ? iterations : new Literal(null, DataType.INTEGER)); + params.put(DISTANCE_TYPE, distanceType != null + ? 
distanceType : new Literal(null, DataType.STRING)); + return params; } /** - * Get map of {@link Argument}. + * Get a map of {@link Argument}. * * @param ctx ADCommandContext instance - * @return the list of arguments fetched from the AD command + * @return a map of arguments fetched from the AD command */ public static Map getArgumentMap(AdCommandContext ctx) { List adParameters = ctx.adParameter(); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 89533db287..5fb67ed5d8 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -582,8 +582,24 @@ public void testParseCommand() { @Test public void testKmeansCommand() { - assertEqual("source=t | kmeans 3", - new Kmeans(relation("t"),exprList(argument("k", intLiteral(3))))); + assertEqual("source=t | kmeans centroids=3, iterations=2, distance_type='l1'", + new Kmeans(relation("t"), new HashMap() {{ + put("centroids", new Literal(3, DataType.INTEGER)); + put("iterations", new Literal(2, DataType.INTEGER)); + put("distance_type", new Literal("l1", DataType.STRING)); + } + })); + } + + @Test + public void testKmeansCommandWithoutParameter() { + assertEqual("source=t | kmeans", + new Kmeans(relation("t"), new HashMap() {{ + put("centroids", new Literal(null, DataType.INTEGER)); + put("iterations", new Literal(null, DataType.INTEGER)); + put("distance_type", new Literal(null, DataType.STRING)); + } + })); } @Test @@ -592,7 +608,7 @@ public void test_fitRCFADCommand_withoutDataFormat() { + "anomaly_rate=0.1, anomaly_score_threshold=0.1, sample_size=256, " + "number_of_trees=256, time_zone='PST', output_after=256, " + "training_data_size=256", - new AD(relation("t"),new HashMap() {{ + new AD(relation("t"), new HashMap() {{ put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)); put("anomaly_score_threshold", new 
Literal(0.1, DataType.DOUBLE)); put("sample_size", new Literal(256, DataType.INTEGER)); From 94f4fcaa46881da8d8fe9ce5de7f30a64b5bc1ad Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Tue, 29 Mar 2022 21:26:10 -0700 Subject: [PATCH 3/5] Apply labels on Kmeans and AD parameters Signed-off-by: jackieyanghan --- .../org/opensearch/sql/analysis/Analyzer.java | 2 +- .../opensearch/sql/analysis/AnalyzerTest.java | 2 - docs/user/ppl/cmd/ad.rst | 4 +- docs/user/ppl/cmd/kmeans.rst | 2 +- .../planner/physical/ADOperator.java | 58 +++++-- .../planner/physical/MLCommonsOperator.java | 12 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 32 ++-- .../opensearch/sql/ppl/parser/AstBuilder.java | 21 ++- .../sql/ppl/parser/AstExpressionBuilder.java | 105 +++++++++++++ .../sql/ppl/utils/ArgumentFactory.java | 147 +----------------- .../sql/ppl/parser/AstBuilderTest.java | 53 ++++--- 11 files changed, 226 insertions(+), 212 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 270887aa53..e882fc5b29 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -430,7 +430,7 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) { TypeEnvironment currentEnv = context.peek(); currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE); - if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) { + if (Objects.isNull(node.getArguments().get(TIME_FIELD))) { currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN); } else { currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE); diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index ad0455c505..114c71aaa5 100644 --- 
a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -718,8 +718,6 @@ public void ad_batchRCF_relation() { Map argumentMap = new HashMap() {{ put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal(null, DataType.STRING)); }}; assertAnalyzeEqual( new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), diff --git a/docs/user/ppl/cmd/ad.rst b/docs/user/ppl/cmd/ad.rst index 172a744df7..ed30a2016d 100644 --- a/docs/user/ppl/cmd/ad.rst +++ b/docs/user/ppl/cmd/ad.rst @@ -16,7 +16,7 @@ Description Fixed In Time RCF For Time-series Data Command Syntax ===================================================== -ad , , , , , , , , +ad * number_of_trees(integer): optional. Number of trees in the forest. The default value is 30. * shingle_size(integer): optional. A shingle is a consecutive sequence of the most recent records. The default value is 8. @@ -31,7 +31,7 @@ ad , , , , , , , , +ad * number_of_trees(integer): optional. Number of trees in the forest. The default value is 30. * sample_size(integer): optional. Number of random samples given to each tree from the training data set. The default value is 256. diff --git a/docs/user/ppl/cmd/kmeans.rst b/docs/user/ppl/cmd/kmeans.rst index d790713ecf..4608473c2c 100644 --- a/docs/user/ppl/cmd/kmeans.rst +++ b/docs/user/ppl/cmd/kmeans.rst @@ -16,7 +16,7 @@ Description Syntax ====== -kmeans , +kmeans * centroids: optional. The number of clusters you want to group your data points into. The default value is 2. * iterations: optional. Number of iterations. The default value is 10. 
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java index 337db50260..9f2a2c99ef 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -105,27 +105,55 @@ public List getChild() { } protected MLAlgoParams convertArgumentToMLParameter(Map arguments) { - if (arguments.get(TIME_FIELD).getValue() == null) { + if (arguments.get(TIME_FIELD) == null) { rcfType = FunctionName.BATCH_RCF; return BatchRCFParams.builder() - .numberOfTrees((Integer) arguments.get(NUMBER_OF_TREES).getValue()) - .sampleSize((Integer) arguments.get(SAMPLE_SIZE).getValue()) - .outputAfter((Integer) arguments.get(OUTPUT_AFTER).getValue()) - .trainingDataSize((Integer) arguments.get(TRAINING_DATA_SIZE).getValue()) - .anomalyScoreThreshold((Double) arguments.get(ANOMALY_SCORE_THRESHOLD).getValue()) + .numberOfTrees(arguments.containsKey(NUMBER_OF_TREES) + ? ((Integer) arguments.get(NUMBER_OF_TREES).getValue()) + : null) + .sampleSize(arguments.containsKey(SAMPLE_SIZE) + ? ((Integer) arguments.get(SAMPLE_SIZE).getValue()) + : null) + .outputAfter(arguments.containsKey(OUTPUT_AFTER) + ? ((Integer) arguments.get(OUTPUT_AFTER).getValue()) + : null) + .trainingDataSize(arguments.containsKey(TRAINING_DATA_SIZE) + ? ((Integer) arguments.get(TRAINING_DATA_SIZE).getValue()) + : null) + .anomalyScoreThreshold(arguments.containsKey(ANOMALY_SCORE_THRESHOLD) + ? 
((Double) arguments.get(ANOMALY_SCORE_THRESHOLD).getValue()) + : null) .build(); } rcfType = FunctionName.FIT_RCF; return FitRCFParams.builder() - .numberOfTrees((Integer) arguments.get(NUMBER_OF_TREES).getValue()) - .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) - .sampleSize((Integer) arguments.get(SAMPLE_SIZE).getValue()) - .outputAfter((Integer) arguments.get(OUTPUT_AFTER).getValue()) - .timeDecay((Double) arguments.get(TIME_DECAY).getValue()) - .anomalyRate((Double) arguments.get(ANOMALY_RATE).getValue()) - .timeField((String) arguments.get(TIME_FIELD).getValue()) - .dateFormat((String) arguments.get(DATE_FORMAT).getValue()) - .timeZone((String) arguments.get(TIME_ZONE).getValue()) + .numberOfTrees(arguments.containsKey(NUMBER_OF_TREES) + ? ((Integer) arguments.get(NUMBER_OF_TREES).getValue()) + : null) + .shingleSize(arguments.containsKey(SHINGLE_SIZE) + ? ((Integer) arguments.get(SHINGLE_SIZE).getValue()) + : null) + .sampleSize(arguments.containsKey(SAMPLE_SIZE) + ? ((Integer) arguments.get(SAMPLE_SIZE).getValue()) + : null) + .outputAfter(arguments.containsKey(OUTPUT_AFTER) + ? ((Integer) arguments.get(OUTPUT_AFTER).getValue()) + : null) + .timeDecay(arguments.containsKey(TIME_DECAY) + ? ((Double) arguments.get(TIME_DECAY).getValue()) + : null) + .anomalyRate(arguments.containsKey(ANOMALY_RATE) + ? ((Double) arguments.get(ANOMALY_RATE).getValue()) + : null) + .timeField(arguments.containsKey(TIME_FIELD) + ? ((String) arguments.get(TIME_FIELD).getValue()) + : null) + .dateFormat(arguments.containsKey(DATE_FORMAT) + ? ((String) arguments.get(DATE_FORMAT).getValue()) + : "yyyy-MM-dd HH:mm:ss") + .timeZone(arguments.containsKey(TIME_ZONE) + ? 
((String) arguments.get(TIME_ZONE).getValue()) + : null) .build(); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java index 45e8c36632..863df58011 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java @@ -101,12 +101,18 @@ protected MLAlgoParams convertArgumentToMLParameter(Map argumen switch (FunctionName.valueOf(algorithm.toUpperCase())) { case KMEANS: return KMeansParams.builder() - .centroids((Integer) arguments.get(CENTROIDS).getValue()) - .iterations((Integer) arguments.get(ITERATIONS).getValue()) - .distanceType(arguments.get(DISTANCE_TYPE).getValue() != null + .centroids(arguments.containsKey(CENTROIDS) + ? ((Integer) arguments.get(CENTROIDS).getValue()) + : null) + .iterations(arguments.containsKey(ITERATIONS) + ? ((Integer) arguments.get(ITERATIONS).getValue()) + : null) + .distanceType(arguments.containsKey(DISTANCE_TYPE) + ? (arguments.get(DISTANCE_TYPE).getValue() != null ? KMeansParams.DistanceType.valueOf(( (String) arguments.get(DISTANCE_TYPE).getValue()).toUpperCase()) : null) + : null) .build(); default: // TODO: update available algorithms in the message when adding a new case diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index d6abe1159b..b4124b0c73 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -89,31 +89,31 @@ parseCommand ; kmeansCommand - : KMEANS (kmeansParameter)? 
(COMMA kmeansParameter)* + : KMEANS (kmeansParameter)* ; kmeansParameter - : (CENTROIDS EQUAL centroids=integerLiteral) - | (ITERATIONS EQUAL iterations=integerLiteral) - | (DISTANCE_TYPE EQUAL distance_type=stringLiteral) + : (CENTROIDS EQUAL centroids=integerLiteral) #centroids + | (ITERATIONS EQUAL iterations=integerLiteral) #iterations + | (DISTANCE_TYPE EQUAL distance_type=stringLiteral) #distance_type ; adCommand - : AD (adParameter)? (COMMA adParameter)* + : AD (adParameter)* ; adParameter - : (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) - | (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) - | (SAMPLE_SIZE EQUAL sample_size=integerLiteral) - | (OUTPUT_AFTER EQUAL output_after=integerLiteral) - | (TIME_DECAY EQUAL time_decay=decimalLiteral) - | (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) - | (TIME_FIELD EQUAL time_field=stringLiteral) - | (DATE_FORMAT EQUAL date_format=stringLiteral) - | (TIME_ZONE EQUAL time_zone=stringLiteral) - | (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) - | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) + : (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) #number_of_trees + | (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) #shingle_size + | (SAMPLE_SIZE EQUAL sample_size=integerLiteral) #sample_size + | (OUTPUT_AFTER EQUAL output_after=integerLiteral) #output_after + | (TIME_DECAY EQUAL time_decay=decimalLiteral) #time_decay + | (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) #anomaly_rate + | (TIME_FIELD EQUAL time_field=stringLiteral) #time_field + | (DATE_FORMAT EQUAL date_format=stringLiteral) #date_format + | (TIME_ZONE EQUAL time_zone=stringLiteral) #time_zone + | (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) #training_data_size + | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) #anomaly_score_threshold ; /** clauses */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java 
b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index e0469e3db1..e7cc81ed71 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -32,6 +32,7 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; @@ -307,14 +308,30 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla return aggregate; } + /** + * Kmeans command. + */ @Override public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { - return new Kmeans(ArgumentFactory.getArgumentMap(ctx)); + return new Kmeans(ctx.kmeansParameter().stream() + .map(p -> (Argument) internalVisitExpression(p)) + .collect(Collectors.toMap( + Argument::getArgName, Argument::getValue, + (value1, value2) -> value2) + )); } + /** + * AD command. 
+ */ @Override public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { - return new AD(ArgumentFactory.getArgumentMap(ctx)); + return new AD(ctx.adParameter().stream() + .map(p -> (Argument) internalVisitExpression(p)) + .collect(Collectors.toMap( + Argument::getArgName, Argument::getValue, + (value1, value2) -> value2) + )); } /** diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 79612ff2cb..c7f24e20b6 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -9,15 +9,20 @@ import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Anomaly_rateContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Anomaly_score_thresholdContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CentroidsContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; +import 
static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Date_formatContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Distance_typeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; @@ -27,19 +32,43 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntervalLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IterationsContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalAndContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalNotContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Number_of_treesContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Output_afterContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ParentheticBinaryArithmeticContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.PercentileAggFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RelevanceExpressionContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Sample_sizeContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Shingle_sizeContext; import static 
org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TableSourceContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_decayContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_fieldContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_zoneContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Training_data_sizeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WcFieldExpressionContext; +import static org.opensearch.sql.ppl.utils.ArgumentFactory.getArgumentValue; +import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE; +import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD; +import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS; +import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT; +import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE; +import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS; +import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES; +import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER; +import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE; +import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE; +import static 
org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -100,6 +129,82 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression())); } + /** + * Kmeans arguments. + */ + @Override + public UnresolvedExpression visitCentroids(CentroidsContext ctx) { + return new Argument(CENTROIDS, getArgumentValue(ctx.centroids)); + } + + @Override + public UnresolvedExpression visitIterations(IterationsContext ctx) { + return new Argument(ITERATIONS, getArgumentValue(ctx.iterations)); + } + + @Override + public UnresolvedExpression visitDistance_type(Distance_typeContext ctx) { + return new Argument(DISTANCE_TYPE, getArgumentValue(ctx.distance_type)); + } + + /** + * AD arguments. + */ + @Override + public UnresolvedExpression visitNumber_of_trees(Number_of_treesContext ctx) { + return new Argument(NUMBER_OF_TREES, getArgumentValue(ctx.number_of_trees)); + } + + @Override + public UnresolvedExpression visitShingle_size(Shingle_sizeContext ctx) { + return new Argument(SHINGLE_SIZE, getArgumentValue(ctx.shingle_size)); + } + + @Override + public UnresolvedExpression visitSample_size(Sample_sizeContext ctx) { + return new Argument(SAMPLE_SIZE, getArgumentValue(ctx.sample_size)); + } + + @Override + public UnresolvedExpression visitOutput_after(Output_afterContext ctx) { + return new Argument(OUTPUT_AFTER, getArgumentValue(ctx.output_after)); + } + + @Override + public UnresolvedExpression visitTime_decay(Time_decayContext ctx) { + return new Argument(TIME_DECAY, getArgumentValue(ctx.time_decay)); + } + + @Override + public UnresolvedExpression visitAnomaly_rate(Anomaly_rateContext ctx) { + return new Argument(ANOMALY_RATE, getArgumentValue(ctx.anomaly_rate)); + } + + @Override + public UnresolvedExpression visitTime_field(Time_fieldContext ctx) { + return new Argument(TIME_FIELD, 
getArgumentValue(ctx.time_field)); + } + + @Override + public UnresolvedExpression visitDate_format(Date_formatContext ctx) { + return new Argument(DATE_FORMAT, getArgumentValue(ctx.date_format)); + } + + @Override + public UnresolvedExpression visitTime_zone(Time_zoneContext ctx) { + return new Argument(TIME_ZONE, getArgumentValue(ctx.time_zone)); + } + + @Override + public UnresolvedExpression visitTraining_data_size(Training_data_sizeContext ctx) { + return new Argument(TRAINING_DATA_SIZE, getArgumentValue(ctx.training_data_size)); + } + + @Override + public UnresolvedExpression visitAnomaly_score_threshold(Anomaly_score_thresholdContext ctx) { + return new Argument(ANOMALY_SCORE_THRESHOLD, getArgumentValue(ctx.anomaly_score_threshold)); + } + /** * Logical expression excluding boolean, comparison. */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index ba58f3fcf2..a635013b1e 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -14,39 +14,16 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RareCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsCommandContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TopCommandContext; -import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE; -import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD; -import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS; -import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT; -import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE; 
-import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS; -import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES; -import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER; -import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE; -import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE; -import static org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Locale; -import java.util.Map; import org.antlr.v4.runtime.ParserRuleContext; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdParameterContext; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansParameterContext; /** * Util class to get all arguments as a list from the PPL command. @@ -159,127 +136,11 @@ public static List getArgumentList(RareCommandContext ctx) { } /** - * Get a map of {@link Argument}. 
- * - * @param ctx KmeansCommandContext instance - * @return a map of arguments fetched from the kmeans command - */ - public static Map getArgumentMap(KmeansCommandContext ctx) { - List kmeansParameters = ctx.kmeansParameter(); - Literal centroids = null; - Literal iterations = null; - Literal distanceType = null; - - for (KmeansParameterContext p : kmeansParameters) { - if (p.centroids != null) { - centroids = getArgumentValue(p.centroids); - } - if (p.iterations != null) { - iterations = getArgumentValue(p.iterations); - } - if (p.distance_type != null) { - distanceType = getArgumentValue(p.distance_type); - } - } - - Map params = new HashMap<>(); - params.put(CENTROIDS, centroids != null - ? centroids : new Literal(null, DataType.INTEGER)); - params.put(ITERATIONS, iterations != null - ? iterations : new Literal(null, DataType.INTEGER)); - params.put(DISTANCE_TYPE, distanceType != null - ? distanceType : new Literal(null, DataType.STRING)); - return params; - } - - /** - * Get a map of {@link Argument}. - * - * @param ctx ADCommandContext instance - * @return a map of arguments fetched from the AD command + * parse argument value into Literal. 
+ * @param ctx ParserRuleContext instance + * @return Literal */ - public static Map getArgumentMap(AdCommandContext ctx) { - List adParameters = ctx.adParameter(); - Literal numberOfTrees = null; - Literal shingleSize = null; - Literal sampleSize = null; - Literal outputAfter = null; - Literal timeDecay = null; - Literal anomalyRate = null; - Literal timeField = null; - Literal dateFormat = null; - Literal timeZone = null; - Literal trainingDataSize = null; - Literal anomalyScoreThreshold = null; - - for (AdParameterContext p : adParameters) { - if (p.number_of_trees != null) { - numberOfTrees = getArgumentValue(p.number_of_trees); - } - if (p.shingle_size != null) { - shingleSize = getArgumentValue(p.shingle_size); - } - if (p.sample_size != null) { - sampleSize = getArgumentValue(p.sample_size); - } - if (p.output_after != null) { - outputAfter = getArgumentValue(p.output_after); - } - if (p.time_decay != null) { - timeDecay = getArgumentValue(p.time_decay); - } - if (p.anomaly_rate != null) { - anomalyRate = getArgumentValue(p.anomaly_rate); - } - if (p.time_field != null) { - timeField = getArgumentValue(p.time_field); - } - if (p.date_format != null) { - dateFormat = getArgumentValue(p.date_format); - } - if (p.time_zone != null) { - timeZone = getArgumentValue(p.time_zone); - } - if (p.training_data_size != null) { - trainingDataSize = getArgumentValue(p.training_data_size); - } - if (p.anomaly_score_threshold != null) { - anomalyScoreThreshold = getArgumentValue(p.anomaly_score_threshold); - } - } - - if (timeField != null && dateFormat == null) { - dateFormat = new Literal("yyyy-MM-dd HH:mm:ss", DataType.STRING); - } - - HashMap params = new HashMap<>(); - params.put(NUMBER_OF_TREES, numberOfTrees != null - ? numberOfTrees : new Literal(null, DataType.INTEGER)); - params.put(SHINGLE_SIZE, shingleSize != null - ? shingleSize : new Literal(null, DataType.INTEGER)); - params.put(SAMPLE_SIZE, sampleSize != null - ? 
sampleSize : new Literal(null, DataType.INTEGER)); - params.put(OUTPUT_AFTER, outputAfter != null - ? outputAfter : new Literal(null, DataType.INTEGER)); - params.put(TIME_DECAY, timeDecay != null - ? timeDecay : new Literal(null, DataType.DOUBLE)); - params.put(ANOMALY_RATE, anomalyRate != null - ? anomalyRate : new Literal(null, DataType.DOUBLE)); - params.put(TIME_FIELD, timeField != null - ? timeField : new Literal(null, DataType.STRING)); - params.put(DATE_FORMAT, dateFormat != null - ? dateFormat : new Literal(null, DataType.STRING)); - params.put(TIME_ZONE, timeZone != null - ? timeZone : new Literal(null, DataType.STRING)); - params.put(TRAINING_DATA_SIZE, trainingDataSize != null - ? trainingDataSize : new Literal(null, DataType.INTEGER)); - params.put(ANOMALY_SCORE_THRESHOLD, anomalyScoreThreshold != null - ? anomalyScoreThreshold : new Literal(null, DataType.DOUBLE)); - - return params; - } - - private static Literal getArgumentValue(ParserRuleContext ctx) { + public static Literal getArgumentValue(ParserRuleContext ctx) { return ctx instanceof IntegerLiteralContext ? 
new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) : ctx instanceof BooleanLiteralContext diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 5fb67ed5d8..40f64aafad 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -582,7 +582,7 @@ public void testParseCommand() { @Test public void testKmeansCommand() { - assertEqual("source=t | kmeans centroids=3, iterations=2, distance_type='l1'", + assertEqual("source=t | kmeans centroids=3 iterations=2 distance_type='l1'", new Kmeans(relation("t"), new HashMap() {{ put("centroids", new Literal(3, DataType.INTEGER)); put("iterations", new Literal(2, DataType.INTEGER)); @@ -591,29 +591,41 @@ public void testKmeansCommand() { })); } + @Test + public void testKmeansCommand_withDuplicateParameters() { + assertEqual("source=t | kmeans centroids=3 centroids=2", + new Kmeans(relation("t"), new HashMap() {{ + put("centroids", new Literal(2, DataType.INTEGER)); + } + })); + } + @Test public void testKmeansCommandWithoutParameter() { assertEqual("source=t | kmeans", - new Kmeans(relation("t"), new HashMap() {{ - put("centroids", new Literal(null, DataType.INTEGER)); - put("iterations", new Literal(null, DataType.INTEGER)); - put("distance_type", new Literal(null, DataType.STRING)); + new Kmeans(relation("t"), new HashMap() {})); + } + + @Test + public void test_fitRCFADCommand_withDuplicateParameters() { + assertEqual("source=t | AD shingle_size=10 shingle_size=8", + new AD(relation("t"), new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); } })); } @Test public void test_fitRCFADCommand_withoutDataFormat() { - assertEqual("source=t | AD shingle_size=10, time_decay=0.0001, time_field='timestamp', " - + "anomaly_rate=0.1, anomaly_score_threshold=0.1, sample_size=256, " - + "number_of_trees=256, 
time_zone='PST', output_after=256, " + assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' " + + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " + + "number_of_trees=256 time_zone='PST' output_after=256 " + "training_data_size=256", new AD(relation("t"), new HashMap() {{ put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)); put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)); put("sample_size", new Literal(256, DataType.INTEGER)); put("number_of_trees", new Literal(256, DataType.INTEGER)); - put("date_format", new Literal("yyyy-MM-dd HH:mm:ss", DataType.STRING)); put("time_zone", new Literal("PST", DataType.STRING)); put("output_after", new Literal(256, DataType.INTEGER)); put("shingle_size", new Literal(10, DataType.INTEGER)); @@ -626,10 +638,10 @@ public void test_fitRCFADCommand_withoutDataFormat() { @Test public void test_fitRCFADCommand_withDataFormat() { - assertEqual("source=t | AD shingle_size=10, time_decay=0.0001, time_field='timestamp', " - + "anomaly_rate=0.1, anomaly_score_threshold=0.1, sample_size=256, " - + "number_of_trees=256, time_zone='PST', output_after=256, " - + "training_data_size=256, date_format='HH:mm:ss yyyy-MM-dd'", + assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' " + + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " + + "number_of_trees=256 time_zone='PST' output_after=256 " + + "training_data_size=256 date_format='HH:mm:ss yyyy-MM-dd'", new AD(relation("t"),new HashMap() {{ put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)); put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)); @@ -649,20 +661,7 @@ public void test_fitRCFADCommand_withDataFormat() { @Test public void test_batchRCFADCommand() { assertEqual("source=t | AD", - new AD(relation("t"),new HashMap() {{ - put("anomaly_rate", new Literal(null, DataType.DOUBLE)); - put("anomaly_score_threshold", new Literal(null, DataType.DOUBLE)); - 
put("sample_size", new Literal(null, DataType.INTEGER)); - put("number_of_trees", new Literal(null, DataType.INTEGER)); - put("date_format", new Literal(null, DataType.STRING)); - put("time_zone", new Literal(null, DataType.STRING)); - put("output_after", new Literal(null, DataType.INTEGER)); - put("training_data_size", new Literal(null, DataType.INTEGER)); - put("shingle_size", new Literal(null, DataType.INTEGER)); - put("time_decay", new Literal(null, DataType.DOUBLE)); - put("time_field", new Literal(null, DataType.STRING)); - } - })); + new AD(relation("t"),new HashMap() { })); } protected void assertEqual(String query, Node expectedPlan) { From 465bdecbcd7148d3202953ad8d18dfb10151ac5e Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Thu, 31 Mar 2022 09:49:53 -0700 Subject: [PATCH 4/5] Remove HashMap usage Signed-off-by: jackieyanghan --- .../logical/LogicalPlanNodeVisitorTest.java | 11 ++- .../physical/MLCommonsOperatorActions.java | 10 +-- .../sql/ppl/parser/AstBuilderTest.java | 88 +++++++++---------- 3 files changed, 53 insertions(+), 56 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 1b42092dac..a2455a9a3d 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -115,12 +115,11 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"), "kmeans", - new HashMap() {{ - put("centroids", new Literal(3, DataType.INTEGER)); - put("iterations", new Literal(3, DataType.DOUBLE)); - put("distance_type", new Literal(null, DataType.STRING)); - } - }); + ImmutableMap.builder() + .put("centroids", new Literal(3, DataType.INTEGER)) + .put("iterations", new Literal(3, DataType.DOUBLE)) + 
.put("distance_type", new Literal(null, DataType.STRING)) + .build()); assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java index 21b232c031..3574630452 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java @@ -7,7 +7,6 @@ package org.opensearch.sql.opensearch.planner.physical; import com.google.common.collect.ImmutableMap; -import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -49,12 +48,11 @@ public abstract class MLCommonsOperatorActions extends PhysicalPlan { */ protected DataFrame generateInputDataset(PhysicalPlan input) { List> inputData = new LinkedList<>(); + ImmutableMap.Builder inputDataBuilder = new ImmutableMap.Builder<>(); while (input.hasNext()) { - inputData.add(new HashMap() { - { - input.next().tupleValue().forEach((key, value) -> put(key, value.value())); - } - }); + input.next().tupleValue().forEach((key, value) + -> inputDataBuilder.put(key, value.value())); + inputData.add(inputDataBuilder.build()); } return DataFrameBuilder.load(inputData); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 40f64aafad..453ef4490e 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -39,7 +39,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.span; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; -import java.util.HashMap; +import com.google.common.collect.ImmutableMap; import 
org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -583,36 +583,36 @@ public void testParseCommand() { @Test public void testKmeansCommand() { assertEqual("source=t | kmeans centroids=3 iterations=2 distance_type='l1'", - new Kmeans(relation("t"), new HashMap() {{ - put("centroids", new Literal(3, DataType.INTEGER)); - put("iterations", new Literal(2, DataType.INTEGER)); - put("distance_type", new Literal("l1", DataType.STRING)); - } - })); + new Kmeans(relation("t"), ImmutableMap.builder() + .put("centroids", new Literal(3, DataType.INTEGER)) + .put("iterations", new Literal(2, DataType.INTEGER)) + .put("distance_type", new Literal("l1", DataType.STRING)) + .build() + )); } @Test public void testKmeansCommand_withDuplicateParameters() { assertEqual("source=t | kmeans centroids=3 centroids=2", - new Kmeans(relation("t"), new HashMap() {{ - put("centroids", new Literal(2, DataType.INTEGER)); - } - })); + new Kmeans(relation("t"), ImmutableMap.builder() + .put("centroids", new Literal(2, DataType.INTEGER)) + .build() + )); } @Test public void testKmeansCommandWithoutParameter() { assertEqual("source=t | kmeans", - new Kmeans(relation("t"), new HashMap() {})); + new Kmeans(relation("t"), ImmutableMap.of())); } @Test public void test_fitRCFADCommand_withDuplicateParameters() { assertEqual("source=t | AD shingle_size=10 shingle_size=8", - new AD(relation("t"), new HashMap() {{ - put("shingle_size", new Literal(8, DataType.INTEGER)); - } - })); + new AD(relation("t"), ImmutableMap.builder() + .put("shingle_size", new Literal(8, DataType.INTEGER)) + .build() + )); } @Test @@ -621,19 +621,19 @@ public void test_fitRCFADCommand_withoutDataFormat() { + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " + "number_of_trees=256 time_zone='PST' output_after=256 " + "training_data_size=256", - new AD(relation("t"), new HashMap() {{ - put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)); - put("anomaly_score_threshold", new Literal(0.1, 
DataType.DOUBLE)); - put("sample_size", new Literal(256, DataType.INTEGER)); - put("number_of_trees", new Literal(256, DataType.INTEGER)); - put("time_zone", new Literal("PST", DataType.STRING)); - put("output_after", new Literal(256, DataType.INTEGER)); - put("shingle_size", new Literal(10, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal("timestamp", DataType.STRING)); - put("training_data_size", new Literal(256, DataType.INTEGER)); - } - })); + new AD(relation("t"), ImmutableMap.builder() + .put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)) + .put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)) + .put("sample_size", new Literal(256, DataType.INTEGER)) + .put("number_of_trees", new Literal(256, DataType.INTEGER)) + .put("time_zone", new Literal("PST", DataType.STRING)) + .put("output_after", new Literal(256, DataType.INTEGER)) + .put("shingle_size", new Literal(10, DataType.INTEGER)) + .put("time_decay", new Literal(0.0001, DataType.DOUBLE)) + .put("time_field", new Literal("timestamp", DataType.STRING)) + .put("training_data_size", new Literal(256, DataType.INTEGER)) + .build() + )); } @Test @@ -642,26 +642,26 @@ public void test_fitRCFADCommand_withDataFormat() { + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " + "number_of_trees=256 time_zone='PST' output_after=256 " + "training_data_size=256 date_format='HH:mm:ss yyyy-MM-dd'", - new AD(relation("t"),new HashMap() {{ - put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)); - put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)); - put("sample_size", new Literal(256, DataType.INTEGER)); - put("number_of_trees", new Literal(256, DataType.INTEGER)); - put("date_format", new Literal("HH:mm:ss yyyy-MM-dd", DataType.STRING)); - put("time_zone", new Literal("PST", DataType.STRING)); - put("output_after", new Literal(256, DataType.INTEGER)); - put("shingle_size", new Literal(10, DataType.INTEGER)); - 
put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal("timestamp", DataType.STRING)); - put("training_data_size", new Literal(256, DataType.INTEGER)); - } - })); + new AD(relation("t"), ImmutableMap.builder() + .put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)) + .put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)) + .put("sample_size", new Literal(256, DataType.INTEGER)) + .put("number_of_trees", new Literal(256, DataType.INTEGER)) + .put("date_format", new Literal("HH:mm:ss yyyy-MM-dd", DataType.STRING)) + .put("time_zone", new Literal("PST", DataType.STRING)) + .put("output_after", new Literal(256, DataType.INTEGER)) + .put("shingle_size", new Literal(10, DataType.INTEGER)) + .put("time_decay", new Literal(0.0001, DataType.DOUBLE)) + .put("time_field", new Literal("timestamp", DataType.STRING)) + .put("training_data_size", new Literal(256, DataType.INTEGER)) + .build() + )); } @Test public void test_batchRCFADCommand() { assertEqual("source=t | AD", - new AD(relation("t"),new HashMap() { })); + new AD(relation("t"),ImmutableMap.of())); } protected void assertEqual(String query, Node expectedPlan) { From 9c38c3c97de9bc6539c95c44d166984a3d395e68 Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Thu, 7 Apr 2022 17:35:36 -0700 Subject: [PATCH 5/5] Remove labels and visitors for AD and KMEANS command Signed-off-by: jackieyanghan --- .../physical/MLCommonsOperatorActions.java | 10 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 28 ++--- .../opensearch/sql/ppl/parser/AstBuilder.java | 29 ++--- .../sql/ppl/parser/AstExpressionBuilder.java | 105 ------------------ .../sql/ppl/utils/ArgumentFactory.java | 12 +- .../sql/ppl/parser/AstBuilderTest.java | 18 --- 6 files changed, 41 insertions(+), 161 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java index 3574630452..21b232c031 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java @@ -7,6 +7,7 @@ package org.opensearch.sql.opensearch.planner.physical; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -48,11 +49,12 @@ public abstract class MLCommonsOperatorActions extends PhysicalPlan { */ protected DataFrame generateInputDataset(PhysicalPlan input) { List> inputData = new LinkedList<>(); - ImmutableMap.Builder inputDataBuilder = new ImmutableMap.Builder<>(); while (input.hasNext()) { - input.next().tupleValue().forEach((key, value) - -> inputDataBuilder.put(key, value.value())); - inputData.add(inputDataBuilder.build()); + inputData.add(new HashMap() { + { + input.next().tupleValue().forEach((key, value) -> put(key, value.value())); + } + }); } return DataFrameBuilder.load(inputData); diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index b4124b0c73..da37f8e22b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -93,9 +93,9 @@ kmeansCommand ; kmeansParameter - : (CENTROIDS EQUAL centroids=integerLiteral) #centroids - | (ITERATIONS EQUAL iterations=integerLiteral) #iterations - | (DISTANCE_TYPE EQUAL distance_type=stringLiteral) #distance_type + : (CENTROIDS EQUAL centroids=integerLiteral) + | (ITERATIONS EQUAL iterations=integerLiteral) + | (DISTANCE_TYPE EQUAL distance_type=stringLiteral) ; adCommand @@ -103,17 +103,17 @@ adCommand ; adParameter - : (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) #number_of_trees - | (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) #shingle_size 
- | (SAMPLE_SIZE EQUAL sample_size=integerLiteral) #sample_size - | (OUTPUT_AFTER EQUAL output_after=integerLiteral) #output_after - | (TIME_DECAY EQUAL time_decay=decimalLiteral) #time_decay - | (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) #anomaly_rate - | (TIME_FIELD EQUAL time_field=stringLiteral) #time_field - | (DATE_FORMAT EQUAL date_format=stringLiteral) #date_format - | (TIME_ZONE EQUAL time_zone=stringLiteral) #time_zone - | (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) #training_data_size - | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) #anomaly_score_threshold + : (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) + | (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) + | (SAMPLE_SIZE EQUAL sample_size=integerLiteral) + | (OUTPUT_AFTER EQUAL output_after=integerLiteral) + | (TIME_DECAY EQUAL time_decay=decimalLiteral) + | (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) + | (TIME_FIELD EQUAL time_field=stringLiteral) + | (DATE_FORMAT EQUAL date_format=stringLiteral) + | (TIME_ZONE EQUAL time_zone=stringLiteral) + | (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) + | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) ; /** clauses */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index e7cc81ed71..2b25004f15 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -23,6 +23,7 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WhereCommandContext; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -32,7 +33,6 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.expression.Alias; -import 
org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; @@ -313,12 +313,13 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla */ @Override public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { - return new Kmeans(ctx.kmeansParameter().stream() - .map(p -> (Argument) internalVisitExpression(p)) - .collect(Collectors.toMap( - Argument::getArgName, Argument::getValue, - (value1, value2) -> value2) - )); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.kmeansParameter() + .forEach(x -> { + builder.put(x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + return new Kmeans(builder.build()); } /** @@ -326,12 +327,14 @@ public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { */ @Override public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { - return new AD(ctx.adParameter().stream() - .map(p -> (Argument) internalVisitExpression(p)) - .collect(Collectors.toMap( - Argument::getArgName, Argument::getValue, - (value1, value2) -> value2) - )); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.adParameter() + .forEach(x -> { + builder.put(x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + + return new AD(builder.build()); } /** diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index c7f24e20b6..79612ff2cb 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -9,20 +9,15 @@ import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Anomaly_rateContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Anomaly_score_thresholdContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CentroidsContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Date_formatContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Distance_typeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; @@ -32,43 +27,19 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntervalLiteralContext; -import static 
org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IterationsContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalAndContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalNotContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Number_of_treesContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Output_afterContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ParentheticBinaryArithmeticContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.PercentileAggFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RelevanceExpressionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Sample_sizeContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Shingle_sizeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TableSourceContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_decayContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_fieldContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_zoneContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Training_data_sizeContext; import static 
org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WcFieldExpressionContext; -import static org.opensearch.sql.ppl.utils.ArgumentFactory.getArgumentValue; -import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE; -import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD; -import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS; -import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT; -import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE; -import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS; -import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES; -import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER; -import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE; -import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE; -import static org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -129,82 +100,6 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression())); } - /** - * Kmeans arguments. - */ - @Override - public UnresolvedExpression visitCentroids(CentroidsContext ctx) { - return new Argument(CENTROIDS, getArgumentValue(ctx.centroids)); - } - - @Override - public UnresolvedExpression visitIterations(IterationsContext ctx) { - return new Argument(ITERATIONS, getArgumentValue(ctx.iterations)); - } - - @Override - public UnresolvedExpression visitDistance_type(Distance_typeContext ctx) { - return new Argument(DISTANCE_TYPE, getArgumentValue(ctx.distance_type)); - } - - /** - * AD arguments. 
- */ - @Override - public UnresolvedExpression visitNumber_of_trees(Number_of_treesContext ctx) { - return new Argument(NUMBER_OF_TREES, getArgumentValue(ctx.number_of_trees)); - } - - @Override - public UnresolvedExpression visitShingle_size(Shingle_sizeContext ctx) { - return new Argument(SHINGLE_SIZE, getArgumentValue(ctx.shingle_size)); - } - - @Override - public UnresolvedExpression visitSample_size(Sample_sizeContext ctx) { - return new Argument(SAMPLE_SIZE, getArgumentValue(ctx.sample_size)); - } - - @Override - public UnresolvedExpression visitOutput_after(Output_afterContext ctx) { - return new Argument(OUTPUT_AFTER, getArgumentValue(ctx.output_after)); - } - - @Override - public UnresolvedExpression visitTime_decay(Time_decayContext ctx) { - return new Argument(TIME_DECAY, getArgumentValue(ctx.time_decay)); - } - - @Override - public UnresolvedExpression visitAnomaly_rate(Anomaly_rateContext ctx) { - return new Argument(ANOMALY_RATE, getArgumentValue(ctx.anomaly_rate)); - } - - @Override - public UnresolvedExpression visitTime_field(Time_fieldContext ctx) { - return new Argument(TIME_FIELD, getArgumentValue(ctx.time_field)); - } - - @Override - public UnresolvedExpression visitDate_format(Date_formatContext ctx) { - return new Argument(DATE_FORMAT, getArgumentValue(ctx.date_format)); - } - - @Override - public UnresolvedExpression visitTime_zone(Time_zoneContext ctx) { - return new Argument(TIME_ZONE, getArgumentValue(ctx.time_zone)); - } - - @Override - public UnresolvedExpression visitTraining_data_size(Training_data_sizeContext ctx) { - return new Argument(TRAINING_DATA_SIZE, getArgumentValue(ctx.training_data_size)); - } - - @Override - public UnresolvedExpression visitAnomaly_score_threshold(Anomaly_score_thresholdContext ctx) { - return new Argument(ANOMALY_SCORE_THRESHOLD, getArgumentValue(ctx.anomaly_score_threshold)); - } - /** * Logical expression excluding boolean, comparison. 
*/ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index a635013b1e..09afd2075f 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -140,14 +140,12 @@ public static List getArgumentList(RareCommandContext ctx) { * @param ctx ParserRuleContext instance * @return Literal */ - public static Literal getArgumentValue(ParserRuleContext ctx) { + private static Literal getArgumentValue(ParserRuleContext ctx) { return ctx instanceof IntegerLiteralContext - ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) - : ctx instanceof BooleanLiteralContext - ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) - : ctx instanceof DecimalLiteralContext - ? new Literal(Double.valueOf(ctx.getText()), DataType.DOUBLE) - : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); + ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) + : ctx instanceof BooleanLiteralContext + ? 
new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) + : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 453ef4490e..5ee0e2be6b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -591,30 +591,12 @@ public void testKmeansCommand() { )); } - @Test - public void testKmeansCommand_withDuplicateParameters() { - assertEqual("source=t | kmeans centroids=3 centroids=2", - new Kmeans(relation("t"), ImmutableMap.builder() - .put("centroids", new Literal(2, DataType.INTEGER)) - .build() - )); - } - @Test public void testKmeansCommandWithoutParameter() { assertEqual("source=t | kmeans", new Kmeans(relation("t"), ImmutableMap.of())); } - @Test - public void test_fitRCFADCommand_withDuplicateParameters() { - assertEqual("source=t | AD shingle_size=10 shingle_size=8", - new AD(relation("t"), ImmutableMap.builder() - .put("shingle_size", new Literal(8, DataType.INTEGER)) - .build() - )); - } - @Test public void test_fitRCFADCommand_withoutDataFormat() { assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' "