diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index e050908a87..188de5aa64 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -11,10 +11,18 @@ import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; +import static org.opensearch.sql.utils.MLCommonsConstants.ACTION; +import static org.opensearch.sql.utils.MLCommonsConstants.MODELID; +import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP; +import static org.opensearch.sql.utils.MLCommonsConstants.STATUS; +import static org.opensearch.sql.utils.MLCommonsConstants.TASKID; import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT; import static org.opensearch.sql.utils.SystemIndexUtils.CATALOGS_TABLE_NAME; import com.google.common.collect.ImmutableList; @@ -49,6 +57,7 @@ import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; +import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -82,6 +91,7 @@ import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalProject; @@ -505,6 +515,19 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) { return new LogicalAD(child, options); } + /** + * Build {@link LogicalML} for ml command. + */ + @Override + public LogicalPlan visitML(ML node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + TypeEnvironment currentEnv = context.peek(); + node.getOutputSchema(currentEnv).entrySet().stream() + .forEach(v -> currentEnv.define(new Symbol(Namespace.FIELD_NAME, v.getKey()), v.getValue())); + + return new LogicalML(child, node.getArguments()); + } + /** * The first argument is always "asc", others are optional. * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. diff --git a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java index 1be195e056..c86d8109ad 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java +++ b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java @@ -6,6 +6,8 @@ package org.opensearch.sql.analysis; +import static org.opensearch.sql.analysis.symbol.Namespace.FIELD_NAME; + import java.util.LinkedHashMap; import java.util.Map; import java.util.Optional; @@ -82,7 +84,7 @@ public void define(Symbol symbol, ExprType type) { * @param ref {@link ReferenceExpression} */ public void define(ReferenceExpression ref) { - define(new Symbol(Namespace.FIELD_NAME, ref.getAttr()), ref.type()); + define(new Symbol(FIELD_NAME, ref.getAttr()), ref.type()); } public void remove(Symbol symbol) { @@ -93,6 +95,14 @@ public void remove(Symbol symbol) { * Remove ref. */ public void remove(ReferenceExpression ref) { - remove(new Symbol(Namespace.FIELD_NAME, ref.getAttr())); + remove(new Symbol(FIELD_NAME, ref.getAttr())); + } + + /** + * Clear all fields in the current environment. + */ + public void clearAllFields() { + lookupAllFields(FIELD_NAME).keySet().stream() + .forEach(v -> remove(new Symbol(Namespace.FIELD_NAME, v))); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 295db7680f..60e7d6f06e 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -43,6 +43,7 @@ import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; +import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -266,6 +267,10 @@ public T visitAD(AD node, C context) { return visitChildren(node, context); } + public T visitML(ML node, C context) { + return visitChildren(node, context); + } + public T visitHighlightFunction(HighlightFunction node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/ML.java b/core/src/main/java/org/opensearch/sql/ast/tree/ML.java new file mode 100644 index 0000000000..2f83a993b7 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/ML.java @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.ast.tree; + +import static org.opensearch.sql.utils.MLCommonsConstants.ACTION; +import static org.opensearch.sql.utils.MLCommonsConstants.ALGO; +import static org.opensearch.sql.utils.MLCommonsConstants.ASYNC; +import static org.opensearch.sql.utils.MLCommonsConstants.CLUSTERID; +import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS; +import static org.opensearch.sql.utils.MLCommonsConstants.MODELID; +import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIME_FIELD; +import static org.opensearch.sql.utils.MLCommonsConstants.STATUS; +import static org.opensearch.sql.utils.MLCommonsConstants.TASKID; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT; + +import com.google.common.collect.ImmutableList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.analysis.TypeEnvironment; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.data.type.ExprCoreType; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = true) +@RequiredArgsConstructor +@AllArgsConstructor +public class ML extends UnresolvedPlan { + private UnresolvedPlan child; + + private final Map arguments; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitML(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + private String getAction() { + return (String) arguments.get(ACTION).getValue(); + } + + /** + * Generate the ml output schema. + * + * @param env the current environment + * @return the schema + */ + public Map getOutputSchema(TypeEnvironment env) { + switch (getAction()) { + case TRAIN: + env.clearAllFields(); + return getTrainOutputSchema(); + case PREDICT: + case TRAINANDPREDICT: + return getPredictOutputSchema(); + default: + throw new IllegalArgumentException( + "Action error. Please indicate train, predict or trainandpredict."); + } + } + + /** + * Generate the ml predict output schema. + * + * @return the schema + */ + public Map getPredictOutputSchema() { + HashMap res = new HashMap<>(); + String algo = arguments.containsKey(ALGO) ? (String) arguments.get(ALGO).getValue() : null; + switch (algo) { + case KMEANS: + res.put(CLUSTERID, ExprCoreType.INTEGER); + break; + case RCF: + res.put(RCF_SCORE, ExprCoreType.DOUBLE); + if (arguments.containsKey(RCF_TIME_FIELD)) { + res.put(RCF_ANOMALY_GRADE, ExprCoreType.DOUBLE); + res.put((String) arguments.get(RCF_TIME_FIELD).getValue(), ExprCoreType.TIMESTAMP); + } else { + res.put(RCF_ANOMALOUS, ExprCoreType.BOOLEAN); + } + break; + default: + throw new IllegalArgumentException("Unsupported algorithm: " + algo); + } + return res; + } + + /** + * Generate the ml train output schema. + * + * @return the schema + */ + public Map getTrainOutputSchema() { + boolean isAsync = arguments.containsKey(ASYNC) + ? (boolean) arguments.get(ASYNC).getValue() : false; + Map res = new HashMap<>(Map.of(STATUS, ExprCoreType.STRING)); + if (isAsync) { + res.put(TASKID, ExprCoreType.STRING); + } else { + res.put(MODELID, ExprCoreType.STRING); + } + return res; + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalML.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalML.java new file mode 100644 index 0000000000..c54ee92e08 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalML.java @@ -0,0 +1,33 @@ +package org.opensearch.sql.planner.logical; + +import java.util.Collections; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.expression.Literal; + +/** + * ML logical plan. + */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +public class LogicalML extends LogicalPlan { + private final Map arguments; + + /** + * Constructor of LogicalML. + * @param child child logical plan + * @param arguments arguments of the algorithm + */ + public LogicalML(LogicalPlan child, Map arguments) { + super(Collections.singletonList(child)); + this.arguments = arguments; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitML(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index df23b9cd20..28539562e7 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -78,6 +78,10 @@ public R visitMLCommons(LogicalMLCommons plan, C context) { return visitNode(plan, context); } + public R visitML(LogicalML plan, C context) { + return visitNode(plan, context); + } + public R visitAD(LogicalAD plan, C context) { return visitNode(plan, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index 646aae8220..63dd05cc6b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -79,4 +79,8 @@ public R visitMLCommons(PhysicalPlan node, C context) { public R visitAD(PhysicalPlan node, C context) { return visitNode(node, context); } + + public R visitML(PhysicalPlan node, C context) { + return visitNode(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java index 883d012d2f..90bca8fe8a 100644 --- a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java +++ b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java @@ -25,4 +25,20 @@ public class MLCommonsConstants { public static final String CENTROIDS = "centroids"; public static final String ITERATIONS = "iterations"; public static final String DISTANCE_TYPE = "distance_type"; + + public static final String ACTION = "action"; + public static final String TRAIN = "train"; + public static final String PREDICT = "predict"; + public static final String TRAINANDPREDICT = "trainandpredict"; + public static final String ASYNC = "async"; + public static final String ALGO = "algorithm"; + public static final String KMEANS = "kmeans"; + public static final String CLUSTERID = "ClusterID"; + public static final String RCF = "rcf"; + public static final String RCF_TIME_FIELD = "timeField"; + public static final String MODELID = "model_id"; + public static final String TASKID = "task_id"; + public static final String STATUS = "status"; + public static final String LIR = "linear_regression"; + public static final String LIR_TARGET = "target"; } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 33e5985bd7..97c560d505 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -37,6 +37,21 @@ import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.opensearch.sql.utils.MLCommonsConstants.ACTION; +import static org.opensearch.sql.utils.MLCommonsConstants.ALGO; +import static org.opensearch.sql.utils.MLCommonsConstants.ASYNC; +import static org.opensearch.sql.utils.MLCommonsConstants.CLUSTERID; +import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS; +import static org.opensearch.sql.utils.MLCommonsConstants.MODELID; +import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIME_FIELD; +import static org.opensearch.sql.utils.MLCommonsConstants.STATUS; +import static org.opensearch.sql.utils.MLCommonsConstants.TASKID; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -59,6 +74,7 @@ import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; @@ -1067,4 +1083,119 @@ public void show_catalogs() { } + @Test + public void ml_relation_unsupported_action() { + Map argumentMap = new HashMap<>() {{ + put(ACTION, new Literal("unsupported", DataType.STRING)); + put(ALGO, new Literal(KMEANS, DataType.STRING)); + }}; + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> analyze(AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields()))); + assertEquals( + "Action error. Please indicate train, predict or trainandpredict.", + exception.getMessage()); + } + + @Test + public void ml_relation_unsupported_algorithm() { + Map argumentMap = new HashMap<>() {{ + put(ACTION, new Literal(PREDICT, DataType.STRING)); + put(ALGO, new Literal("unsupported", DataType.STRING)); + }}; + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> analyze(AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields()))); + assertEquals( + "Unsupported algorithm: unsupported", + exception.getMessage()); + } + + @Test + public void ml_relation_train_sync() { + Map argumentMap = new HashMap<>() {{ + put(ACTION, new Literal(TRAIN, DataType.STRING)); + put(ALGO, new Literal(KMEANS, DataType.STRING)); + }}; + + LogicalPlan actual = analyze(AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + assertTrue(((LogicalProject) actual).getProjectList().size() >= 2); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(STATUS, DSL.ref(STATUS, STRING)))); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(MODELID, DSL.ref(MODELID, STRING)))); + } + + @Test + public void ml_relation_train_async() { + Map argumentMap = new HashMap<>() {{ + put(ACTION, new Literal(TRAIN, DataType.STRING)); + put(ALGO, new Literal(KMEANS, DataType.STRING)); + put(ASYNC, new Literal(true, DataType.BOOLEAN)); + }}; + + LogicalPlan actual = analyze(AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + assertTrue(((LogicalProject) actual).getProjectList().size() >= 2); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(STATUS, DSL.ref(STATUS, STRING)))); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(TASKID, DSL.ref(TASKID, STRING)))); + } + + @Test + public void ml_relation_predict_kmeans() { + Map argumentMap = new HashMap<>() {{ + put(ACTION, new Literal(PREDICT, DataType.STRING)); + put(ALGO, new Literal(KMEANS, DataType.STRING)); + }}; + + LogicalPlan actual = analyze(AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + assertTrue(((LogicalProject) actual).getProjectList().size() >= 1); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(CLUSTERID, DSL.ref(CLUSTERID, INTEGER)))); + } + + @Test + public void ml_relation_predict_rcf_with_time_field() { + Map argumentMap = new HashMap<>() {{ + put(ACTION, new Literal(PREDICT, DataType.STRING)); + put(ALGO, new Literal(RCF, DataType.STRING)); + put(RCF_TIME_FIELD, new Literal("ts", DataType.STRING)); + }}; + + LogicalPlan actual = analyze(AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + assertTrue(((LogicalProject) actual).getProjectList().size() >= 3); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(RCF_SCORE, DSL.ref(RCF_SCORE, DOUBLE)))); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(RCF_ANOMALY_GRADE, DSL.ref(RCF_ANOMALY_GRADE, DOUBLE)))); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named("ts", DSL.ref("ts", TIMESTAMP)))); + } + + @Test + public void ml_relation_predict_rcf_without_time_field() { + Map argumentMap = new HashMap<>() {{ + put(ACTION, new Literal(PREDICT, DataType.STRING)); + put(ALGO, new Literal(RCF, DataType.STRING)); + }}; + + LogicalPlan actual = analyze(AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + assertTrue(((LogicalProject) actual).getProjectList().size() >= 2); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(RCF_SCORE, DSL.ref(RCF_SCORE, DOUBLE)))); + assertTrue(((LogicalProject) actual).getProjectList() + .contains(DSL.named(RCF_ANOMALOUS, DSL.ref(RCF_ANOMALOUS, BOOLEAN)))); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 329708b7d8..03eeb9c626 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -143,6 +143,18 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { }); assertNull(ad.accept(new LogicalPlanNodeVisitor() { }, null)); + + LogicalPlan ml = new LogicalML(LogicalPlanDSL.relation("schema", table), + new HashMap() {{ + put("action", new Literal("train", DataType.STRING)); + put("algorithm", new Literal("rcf", DataType.STRING)); + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + }); + assertNull(ml.accept(new LogicalPlanNodeVisitor() { + }, null)); } private static class NodesCount extends LogicalPlanNodeVisitor { diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index cd561f3c09..8780177c88 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -150,6 +150,14 @@ public void test_visitAD() { assertNull(physicalPlanNodeVisitor.visitAD(plan, null)); } + @Test + public void test_visitML() { + PhysicalPlanNodeVisitor physicalPlanNodeVisitor = + new PhysicalPlanNodeVisitor() {}; + + assertNull(physicalPlanNodeVisitor.visitML(plan, null)); + } + public static class PhysicalPlanPrinter extends PhysicalPlanNodeVisitor { public String print(PhysicalPlan node) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 45d2b12620..f06ecb8576 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -10,6 +10,7 @@ import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; +import org.opensearch.sql.opensearch.planner.physical.MLOperator; import org.opensearch.sql.planner.physical.AggregationOperator; import org.opensearch.sql.planner.physical.DedupeOperator; import org.opensearch.sql.planner.physical.EvalOperator; @@ -150,6 +151,16 @@ public PhysicalPlan visitAD(PhysicalPlan node, Object context) { ); } + @Override + public PhysicalPlan visitML(PhysicalPlan node, Object context) { + MLOperator mlOperator = (MLOperator) node; + return doProtect( + new MLOperator(visitInput(mlOperator.getInput(), context), + mlOperator.getArguments(), + mlOperator.getNodeClient()) + ); + } + PhysicalPlan visitInput(PhysicalPlan node, Object context) { if (null == node) { return node; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java index 9003d2ec47..e1f12fb8a7 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java @@ -6,6 +6,10 @@ package org.opensearch.sql.opensearch.planner.physical; +import static org.opensearch.sql.utils.MLCommonsConstants.MODELID; +import static org.opensearch.sql.utils.MLCommonsConstants.STATUS; +import static org.opensearch.sql.utils.MLCommonsConstants.TASKID; + import com.google.common.collect.ImmutableMap; import java.util.HashMap; import java.util.Iterator; @@ -28,7 +32,10 @@ import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprFloatValue; @@ -216,6 +223,64 @@ protected MLPredictionOutput getMLPredictionResult(FunctionName functionName, .actionGet(30, TimeUnit.SECONDS); } + /** + * get ml-commons train, predict and trainandpredict result. + * @param inputDataFrame input data frame + * @param arguments ml parameters + * @param nodeClient node client + * @return ml-commons result + */ + protected MLOutput getMLOutput(DataFrame inputDataFrame, + Map arguments, + NodeClient nodeClient) { + MLInput mlinput = MLInput.builder() + .inputDataset(new DataFrameInputDataset(inputDataFrame)) + //Just the placeholders for algorithm and parameters which must be initialized. + //They will be overridden in ml client. + .algorithm(FunctionName.SAMPLE_ALGO) + .parameters(new SampleAlgoParams(0)) + .build(); + + MachineLearningNodeClient machineLearningClient = + MLClient.getMLClient(nodeClient); + + return machineLearningClient + .run(mlinput, arguments) + .actionGet(30, TimeUnit.SECONDS); + } + + /** + * iterate result and built it into ExprTupleValue. + * @param inputRowIter input row iterator + * @param inputDataFrame input data frame + * @param mlResult train/predict result + * @param resultRowIter predict result iterator + * @return result in ExprTupleValue format + */ + protected ExprTupleValue buildPPLResult(boolean isPredict, + Iterator inputRowIter, + DataFrame inputDataFrame, + MLOutput mlResult, + Iterator resultRowIter) { + if (isPredict) { + return buildResult(inputRowIter, + inputDataFrame, + (MLPredictionOutput) mlResult, + resultRowIter); + } else { + return buildTrainResult((MLTrainingOutput) mlResult); + } + } + + protected ExprTupleValue buildTrainResult(MLTrainingOutput trainResult) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.put(MODELID, new ExprStringValue(trainResult.getModelId())); + resultBuilder.put(TASKID, new ExprStringValue(trainResult.getTaskId())); + resultBuilder.put(STATUS, new ExprStringValue(trainResult.getStatus())); + + return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + } + private static class MLInputRows extends LinkedList> { /** * Add tuple value to input map, skip if any value is null. diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java new file mode 100644 index 0000000000..938ff60157 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.physical; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +/** + * ml-commons Physical operator to call machine learning interface to get results for + * algorithm execution. + */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class MLOperator extends MLCommonsOperatorActions { + @Getter + private final PhysicalPlan input; + + @Getter + private final Map arguments; + + @Getter + private final NodeClient nodeClient; + + @EqualsAndHashCode.Exclude + private Iterator iterator; + + @Override + public void open() { + super.open(); + DataFrame inputDataFrame = generateInputDataset(input); + Map args = processArgs(arguments); + + MLOutput mlOutput = getMLOutput(inputDataFrame, args, nodeClient); + final Iterator inputRowIter = inputDataFrame.iterator(); + // Only need to check train here, as action should be already checked in ml client. + final boolean isPrediction = ((String) args.get("action")).equals("train") ? false : true; + //For train, only one row to return. + final Iterator trainIter = new ArrayList() { + { + add("train"); + } + }.iterator(); + final Iterator resultRowIter = isPrediction + ? ((MLPredictionOutput) mlOutput).getPredictionResult().iterator() + : null; + iterator = new Iterator() { + @Override + public boolean hasNext() { + if (isPrediction) { + return inputRowIter.hasNext(); + } else { + boolean res = trainIter.hasNext(); + if (res) { + trainIter.next(); + } + return res; + } + } + + @Override + public ExprValue next() { + return buildPPLResult(isPrediction, inputRowIter, inputDataFrame, mlOutput, resultRowIter); + } + }; + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitML(this, context); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public ExprValue next() { + return iterator.next(); + } + + @Override + public List getChild() { + return Collections.singletonList(input); + } + + protected Map processArgs(Map arguments) { + Map res = new HashMap<>(); + arguments.forEach((k, v) -> res.put(k, v.getValue())); + return res; + } +} + diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 2849fbbec9..9ebdc12ba2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -24,6 +24,7 @@ import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; +import org.opensearch.sql.opensearch.planner.physical.MLOperator; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -34,6 +35,7 @@ import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalRelation; @@ -205,6 +207,12 @@ public PhysicalPlan visitAD(LogicalAD node, OpenSearchIndexScan context) { node.getArguments(), client.getNodeClient()); } + @Override + public PhysicalPlan visitML(LogicalML node, OpenSearchIndexScan context) { + return new MLOperator(visitChild(node, context), + node.getArguments(), client.getNodeClient()); + } + @Override public PhysicalPlan visitHighlight(LogicalHighlight node, OpenSearchIndexScan context) { context.getRequestBuilder().pushDownHighlight( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index fded7848b6..857ff601e1 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -56,6 +56,7 @@ import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; +import org.opensearch.sql.opensearch.planner.physical.MLOperator; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -293,6 +294,26 @@ public void testVisitAD() { executionProtector.visitAD(adOperator, null)); } + @Test + public void testVisitML() { + NodeClient nodeClient = mock(NodeClient.class); + MLOperator mlOperator = + new MLOperator( + values(emptyList()), + new HashMap() {{ + put("action", new Literal("train", DataType.STRING)); + put("algorithm", new Literal("rcf", DataType.STRING)); + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + }}, + nodeClient + ); + + assertEquals(executionProtector.doProtect(mlOperator), + executionProtector.visitML(mlOperator, null)); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java new file mode 100644 index 0000000000..7a73468391 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.physical; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.utils.MLCommonsConstants.ACTION; +import static org.opensearch.sql.utils.MLCommonsConstants.ALGO; +import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS; +import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT; +import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN; + +import com.google.common.collect.ImmutableMap; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.runner.RunWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.output.MLTrainingOutput; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.opensearch.client.MLClient; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +@RunWith(MockitoJUnitRunner.Silent.class) +public class MLOperatorTest { + @Mock + private PhysicalPlan input; + + @Mock + PlainActionFuture actionFuture; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private NodeClient nodeClient; + + private MLOperator mlOperator; + Map arguments = new HashMap<>(); + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private MachineLearningNodeClient machineLearningNodeClient; + + void setUp(boolean isPredict) { + arguments.put("k1",AstDSL.intLiteral(3)); + arguments.put("k2",AstDSL.stringLiteral("v1")); + arguments.put("k3",AstDSL.booleanLiteral(true)); + arguments.put("k4",AstDSL.doubleLiteral(2.0D)); + arguments.put("k5",AstDSL.shortLiteral((short)2)); + arguments.put("k6",AstDSL.longLiteral(2L)); + arguments.put("k7",AstDSL.floatLiteral(2F)); + + mlOperator = new MLOperator(input, arguments, nodeClient); + when(input.hasNext()).thenReturn(true).thenReturn(false); + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.put("k1", new ExprIntegerValue(2)); + when(input.next()).thenReturn(ExprTupleValue.fromExprValueMap(resultBuilder.build())); + + DataFrame dataFrame = DataFrameBuilder + .load(Collections.singletonList( + ImmutableMap.builder().put("result-k1", 2D) + .put("result-k2", 1) + .put("result-k3", "v3") + .put("result-k4", true) + .put("result-k5", (short)2) + .put("result-k6", 2L) + .put("result-k7", 2F) + .build()) + ); + + MLOutput mlOutput; + if (isPredict) { + mlOutput = MLPredictionOutput.builder() + .taskId("test_task_id") + .status("test_status") + .predictionResult(dataFrame) + .build(); + } else { + mlOutput = MLTrainingOutput.builder() + .taskId("test_task_id") + .status("test_status") + .modelId("test_model_id") + .build(); + } + + when(actionFuture.actionGet(anyLong(), eq(TimeUnit.SECONDS))) + .thenReturn(mlOutput); + when(machineLearningNodeClient.run(any(MLInput.class), any())) + .thenReturn(actionFuture); + } + + void setUpPredict() { + arguments.put(ACTION,AstDSL.stringLiteral(PREDICT)); + arguments.put(ALGO,AstDSL.stringLiteral(KMEANS)); + arguments.put("modelid",AstDSL.stringLiteral("dummyID")); + setUp(true); + } + + void setUpTrain() { + arguments.put(ACTION,AstDSL.stringLiteral(TRAIN)); + arguments.put(ALGO,AstDSL.stringLiteral(KMEANS)); + setUp(false); + } + + @Test + public void testOpenPredict() { + setUpPredict(); + try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { + when(MLClient.getMLClient(any(NodeClient.class))).thenReturn(machineLearningNodeClient); + mlOperator.open(); + assertTrue(mlOperator.hasNext()); + assertNotNull(mlOperator.next()); + assertFalse(mlOperator.hasNext()); + } + } + + @Test + public void testOpenTrain() { + setUpTrain(); + try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { + when(MLClient.getMLClient(any(NodeClient.class))).thenReturn(machineLearningNodeClient); + mlOperator.open(); + assertTrue(mlOperator.hasNext()); + assertNotNull(mlOperator.next()); + assertFalse(mlOperator.hasNext()); + } + } + + @Test + public void testAccept() { + setUpPredict(); + try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { + when(MLClient.getMLClient(any(NodeClient.class))).thenReturn(machineLearningNodeClient); + PhysicalPlanNodeVisitor physicalPlanNodeVisitor + = new PhysicalPlanNodeVisitor() {}; + assertNull(mlOperator.accept(physicalPlanNodeVisitor, null)); + } + } + +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index ced87a7d31..a74c5fcbd4 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -23,6 +23,7 @@ import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.storage.Table; @@ -77,6 +78,16 @@ public void visitAD() { assertNotNull(implementor.visitAD(node, indexScan)); } + @Test + public void visitML() { + LogicalML node = Mockito.mock(LogicalML.class, + Answers.RETURNS_DEEP_STUBS); + Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); + OpenSearchIndex.OpenSearchDefaultImplementor implementor = + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + assertNotNull(implementor.visitML(node, indexScan)); + } + @Test public void visitHighlight() { LogicalHighlight node = Mockito.mock(LogicalHighlight.class, diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index ffb9780faf..79c812949f 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -34,6 +34,7 @@ PATTERNS: 'PATTERNS'; NEW_FIELD: 'NEW_FIELD'; KMEANS: 'KMEANS'; AD: 'AD'; +ML: 'ML'; // COMMAND ASSIST KEYWORDS AS: 'AS'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 7a42c7fa2b..11e1c4c71f 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -26,7 +26,7 @@ pplCommands commands : whereCommand | fieldsCommand | renameCommand | statsCommand | dedupCommand | sortCommand | evalCommand | headCommand - | topCommand | rareCommand | grokCommand | parseCommand | patternsCommand | kmeansCommand | adCommand; + | topCommand | rareCommand | grokCommand | parseCommand | patternsCommand | kmeansCommand | adCommand | mlCommand; searchCommand : (SEARCH)? fromClause #searchFrom @@ -149,6 +149,14 @@ adParameter | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) ; +mlCommand + : ML (mlArg)* + ; + +mlArg + : (argName=ident EQUAL argValue=literalValue) + ; + /** clauses */ fromClause : SOURCE EQUAL tableSourceClause diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index b0d17940a4..13d8f8cddf 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -53,6 +53,7 @@ import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -411,6 +412,20 @@ public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { return new AD(builder.build()); } + /** + * ml command. + */ + @Override + public UnresolvedPlan visitMlCommand(OpenSearchPPLParser.MlCommandContext ctx) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.mlArg() + .forEach(x -> { + builder.put(x.argName.getText(), + (Literal) internalVisitExpression(x.argValue)); + }); + return new ML(builder.build()); + } + /** * Get original text in query. */ diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 9bcbe66330..eb2f651ac7 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -55,6 +55,7 @@ import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; @@ -712,6 +713,20 @@ public void testKmeansCommandWithoutParameter() { new Kmeans(relation("t"), ImmutableMap.of())); } + @Test + public void testMLCommand() { + assertEqual("source=t | ml action='trainandpredict' " + + "algorithm='kmeans' centroid=3 iteration=2 dist_type='l1'", + new ML(relation("t"), ImmutableMap.builder() + .put("action", new Literal("trainandpredict", DataType.STRING)) + .put("algorithm", new Literal("kmeans", DataType.STRING)) + .put("centroid", new Literal(3, DataType.INTEGER)) + .put("iteration", new Literal(2, DataType.INTEGER)) + .put("dist_type", new Literal("l1", DataType.STRING)) + .build() + )); + } + @Test public void testDescribeCommand() { assertEqual("describe t",