Skip to content

Commit

Permalink
A Generic ML Command in PPL (#971)
Browse files Browse the repository at this point in the history
* Add generic ml command in ppl.

Signed-off-by: Jing Zhang <[email protected]>

* Recover ml client dependency.

Signed-off-by: Jing Zhang <[email protected]>

* Address the comments I.

Signed-off-by: Jing Zhang <[email protected]>

Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es authored Oct 31, 2022
1 parent 634e2ff commit c6b234c
Show file tree
Hide file tree
Showing 22 changed files with 822 additions and 3 deletions.
23 changes: 23 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.MODELID;
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.STATUS;
import static org.opensearch.sql.utils.MLCommonsConstants.TASKID;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT;
import static org.opensearch.sql.utils.SystemIndexUtils.CATALOGS_TABLE_NAME;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -49,6 +57,7 @@
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.ML;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareTopN;
Expand Down Expand Up @@ -82,6 +91,7 @@
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalML;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalProject;
Expand Down Expand Up @@ -505,6 +515,19 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) {
return new LogicalAD(child, options);
}

/**
 * Analyze an {@link ML} node and produce the matching {@link LogicalML} plan.
 *
 * <p>The child plan is analyzed first. The node's output schema is then computed against the
 * current type environment and every field of that schema is registered in the environment
 * under the field-name namespace.
 */
@Override
public LogicalPlan visitML(ML node, AnalysisContext context) {
  LogicalPlan analyzedChild = node.getChild().get(0).accept(this, context);

  TypeEnvironment env = context.peek();
  // Computing the schema may itself clear the environment (train action), so it must run
  // before any of the new fields are defined.
  node.getOutputSchema(env)
      .forEach((name, type) -> env.define(new Symbol(Namespace.FIELD_NAME, name), type));

  return new LogicalML(analyzedChild, node.getArguments());
}

/**
* The first argument is always "asc", others are optional.
* Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package org.opensearch.sql.analysis;

import static org.opensearch.sql.analysis.symbol.Namespace.FIELD_NAME;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -82,7 +84,7 @@ public void define(Symbol symbol, ExprType type) {
* @param ref {@link ReferenceExpression}
*/
public void define(ReferenceExpression ref) {
  // Register the reference once, under the field-name namespace, using its attribute name and
  // type. The previous body issued the same define call twice (qualified and statically
  // imported form of the same constant); a single call is sufficient.
  define(new Symbol(FIELD_NAME, ref.getAttr()), ref.type());
}

public void remove(Symbol symbol) {
Expand All @@ -93,6 +95,14 @@ public void remove(Symbol symbol) {
* Remove ref.
*/
public void remove(ReferenceExpression ref) {
  // Remove the symbol named after the reference's attribute from the field-name namespace.
  // The previous body issued the same remove call twice; the second call was redundant.
  remove(new Symbol(FIELD_NAME, ref.getAttr()));
}

/**
 * Remove every symbol reported by {@code lookupAllFields} for the field-name namespace
 * from this environment.
 */
public void clearAllFields() {
  lookupAllFields(FIELD_NAME)
      .keySet()
      .forEach(fieldName -> remove(new Symbol(FIELD_NAME, fieldName)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.ML;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareTopN;
Expand Down Expand Up @@ -266,6 +267,10 @@ public T visitAD(AD node, C context) {
return visitChildren(node, context);
}

/**
* Visit an {@link ML} node.
* The default implementation delegates to {@code visitChildren}; visitors that need
* ml-specific handling override this method.
*
* @param node ml command node
* @param context visitor context
* @return result of {@code visitChildren}
*/
public T visitML(ML node, C context) {
return visitChildren(node, context);
}

/**
* Visit a {@link HighlightFunction} node.
* The default implementation delegates to {@code visitChildren}.
*
* @param node highlight function node
* @param context visitor context
* @return result of {@code visitChildren}
*/
public T visitHighlightFunction(HighlightFunction node, C context) {
return visitChildren(node, context);
}
Expand Down
135 changes: 135 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/ML.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.tree;

import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.ALGO;
import static org.opensearch.sql.utils.MLCommonsConstants.ASYNC;
import static org.opensearch.sql.utils.MLCommonsConstants.CLUSTERID;
import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS;
import static org.opensearch.sql.utils.MLCommonsConstants.MODELID;
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIME_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.STATUS;
import static org.opensearch.sql.utils.MLCommonsConstants.TASKID;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT;

import com.google.common.collect.ImmutableList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.analysis.TypeEnvironment;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.data.type.ExprCoreType;

/**
 * AST node of the generic ml command.
 *
 * <p>The node carries the raw command arguments keyed by name and derives the output schema of
 * the command from the {@code action} argument (and, for predictions, the {@code algorithm}
 * argument).
 */
@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class ML extends UnresolvedPlan {
  private UnresolvedPlan child;

  private final Map<String, Literal> arguments;

  @Override
  public UnresolvedPlan attach(UnresolvedPlan child) {
    this.child = child;
    return this;
  }

  @Override
  public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
    return nodeVisitor.visitML(this, context);
  }

  @Override
  public List<UnresolvedPlan> getChild() {
    return ImmutableList.of(this.child);
  }

  /**
   * Read the {@code action} argument.
   *
   * @return the action name
   * @throws IllegalArgumentException if the argument is missing or is not a string
   */
  private String getAction() {
    Literal action = arguments.get(ACTION);
    Object value = action == null ? null : action.getValue();
    if (!(value instanceof String)) {
      // A missing action used to surface as a bare NullPointerException.
      throw invalidAction();
    }
    return (String) value;
  }

  /** Build the exception reported for a missing or unknown action. */
  private static IllegalArgumentException invalidAction() {
    return new IllegalArgumentException(
        "Action error. Please indicate train, predict or trainandpredict.");
  }

  /**
   * Generate the ml output schema.
   *
   * <p>For the train action this method also clears all fields of {@code env} before
   * returning the train schema; the predict actions leave {@code env} untouched.
   *
   * @param env the current environment
   * @return the schema
   * @throws IllegalArgumentException if the action is missing or not one of train, predict,
   *     trainandpredict, or if a predict action names a missing or unsupported algorithm
   */
  public Map<String, ExprCoreType> getOutputSchema(TypeEnvironment env) {
    switch (getAction()) {
      case TRAIN:
        env.clearAllFields();
        return getTrainOutputSchema();
      case PREDICT:
      case TRAINANDPREDICT:
        return getPredictOutputSchema();
      default:
        throw invalidAction();
    }
  }

  /**
   * Generate the ml predict output schema.
   *
   * @return the schema
   * @throws IllegalArgumentException if the algorithm argument is missing, not a string, or
   *     names an algorithm other than kmeans or rcf
   */
  public Map<String, ExprCoreType> getPredictOutputSchema() {
    Literal algoLiteral = arguments.get(ALGO);
    Object algo = algoLiteral == null ? null : algoLiteral.getValue();
    if (!(algo instanceof String)) {
      // Switching on a null string throws NullPointerException; report bad input instead.
      throw new IllegalArgumentException("Unsupported algorithm: " + algo);
    }

    Map<String, ExprCoreType> res = new HashMap<>();
    switch ((String) algo) {
      case KMEANS:
        res.put(CLUSTERID, ExprCoreType.INTEGER);
        break;
      case RCF:
        res.put(RCF_SCORE, ExprCoreType.DOUBLE);
        if (arguments.containsKey(RCF_TIME_FIELD)) {
          // With a time field the schema carries an anomaly grade and the time field itself.
          res.put(RCF_ANOMALY_GRADE, ExprCoreType.DOUBLE);
          res.put((String) arguments.get(RCF_TIME_FIELD).getValue(), ExprCoreType.TIMESTAMP);
        } else {
          res.put(RCF_ANOMALOUS, ExprCoreType.BOOLEAN);
        }
        break;
      default:
        throw new IllegalArgumentException("Unsupported algorithm: " + algo);
    }
    return res;
  }

  /**
   * Generate the ml train output schema.
   *
   * <p>The schema always contains the status field; it additionally contains the task id
   * when the {@code async} argument is true, and the model id otherwise.
   *
   * @return the schema
   */
  public Map<String, ExprCoreType> getTrainOutputSchema() {
    boolean isAsync = arguments.containsKey(ASYNC) && (boolean) arguments.get(ASYNC).getValue();
    Map<String, ExprCoreType> res = new HashMap<>();
    res.put(STATUS, ExprCoreType.STRING);
    res.put(isAsync ? TASKID : MODELID, ExprCoreType.STRING);
    return res;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;

/**
* ML logical plan.
* Holds the arguments of the ml command on top of a single child plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalML extends LogicalPlan {
// Command arguments keyed by argument name; stored as passed in, without copying.
private final Map<String, Literal> arguments;

/**
* Constructor of LogicalML.
* @param child child logical plan, wrapped into a singleton child list
* @param arguments arguments of the algorithm
*/
public LogicalML(LogicalPlan child, Map<String, Literal> arguments) {
super(Collections.singletonList(child));
this.arguments = arguments;
}

/**
* Dispatch to {@code visitML} of the given visitor.
*/
@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitML(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ public R visitMLCommons(LogicalMLCommons plan, C context) {
return visitNode(plan, context);
}

/**
* Visit a {@link LogicalML} plan.
* The default implementation delegates to {@code visitNode}.
*
* @param plan ml logical plan
* @param context visitor context
* @return result of {@code visitNode}
*/
public R visitML(LogicalML plan, C context) {
return visitNode(plan, context);
}

/**
* Visit a {@link LogicalAD} plan.
* The default implementation delegates to {@code visitNode}.
*
* @param plan ad logical plan
* @param context visitor context
* @return result of {@code visitNode}
*/
public R visitAD(LogicalAD plan, C context) {
return visitNode(plan, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,8 @@ public R visitMLCommons(PhysicalPlan node, C context) {
/**
* Visit the physical plan node of the ad command.
* The default implementation delegates to {@code visitNode}.
*
* @param node physical plan node
* @param context visitor context
* @return result of {@code visitNode}
*/
public R visitAD(PhysicalPlan node, C context) {
return visitNode(node, context);
}

/**
* Visit the physical plan node of the ml command.
* The default implementation delegates to {@code visitNode}.
*
* @param node physical plan node
* @param context visitor context
* @return result of {@code visitNode}
*/
public R visitML(PhysicalPlan node, C context) {
return visitNode(node, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,20 @@ public class MLCommonsConstants {
public static final String CENTROIDS = "centroids";
public static final String ITERATIONS = "iterations";
public static final String DISTANCE_TYPE = "distance_type";

// Argument key of the generic ml command that selects the action to run.
public static final String ACTION = "action";
// Action values accepted by the ml command: train, predict, or both in one step.
public static final String TRAIN = "train";
public static final String PREDICT = "predict";
public static final String TRAINANDPREDICT = "trainandpredict";
// Argument key; when its value is true the train output schema carries a task id
// instead of a model id.
public static final String ASYNC = "async";
// Argument key naming the algorithm used to build the predict output schema.
public static final String ALGO = "algorithm";
// Algorithm value for k-means and the name of its integer output field.
public static final String KMEANS = "kmeans";
public static final String CLUSTERID = "ClusterID";
// Algorithm value for rcf and the argument key naming its time field.
public static final String RCF = "rcf";
public static final String RCF_TIME_FIELD = "timeField";
// Output field names of the train action.
public static final String MODELID = "model_id";
public static final String TASKID = "task_id";
public static final String STATUS = "status";
// NOTE(review): not referenced by the ml command code visible in this change; presumably the
// linear regression algorithm name and its target argument key - confirm against callers.
public static final String LIR = "linear_regression";
public static final String LIR_TARGET = "target";
}
Loading

0 comments on commit c6b234c

Please sign in to comment.