Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PPL integration with AD and ml-commons #468

Merged
merged 10 commits into from
Mar 9, 2022
4 changes: 4 additions & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ test.finalizedBy(project.tasks.jacocoTestReport)
jacocoTestCoverageVerification {
violationRules {
rule {
element = 'CLASS'
excludes = [
'org.opensearch.sql.utils.MLCommonsConstants'
]
limit {
counter = 'LINE'
minimum = 1.0
Expand Down
44 changes: 44 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand All @@ -31,11 +37,13 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
Expand All @@ -58,11 +66,13 @@
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalDedupe;
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalProject;
import org.opensearch.sql.planner.logical.LogicalRareTopN;
Expand Down Expand Up @@ -395,6 +405,40 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) {
return new LogicalValues(valueExprs);
}

/**
* Build {@link LogicalMLCommons} for Kmeans command.
*/
@Override
public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
List<Argument> options = node.getOptions();

TypeEnvironment currentEnv = context.peek();
currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER);

return new LogicalMLCommons(child, "kmeans", options);
}

/**
* Build {@link LogicalAD} for AD command.
*/
@Override
public LogicalPlan visitAD(AD node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
java.util.Map<String, Literal> options = node.getArguments();

TypeEnvironment currentEnv = context.peek();

currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE);
if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN);
} else {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE);
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_TIMESTAMP), ExprCoreType.TIMESTAMP);
}
return new LogicalAD(child, options);
}

/**
* The first argument is always "asc", others are optional.
* Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
Expand Down Expand Up @@ -239,4 +241,12 @@ public T visitLimit(Limit node, C context) {
public T visitSpan(Span node, C context) {
return visitChildren(node, context);
}

public T visitKmeans(Kmeans node, C context) {
return visitChildren(node, context);
}

public T visitAD(AD node, C context) {
return visitChildren(node, context);
}
}
8 changes: 8 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ public static Literal longLiteral(Long value) {
return literal(value, DataType.LONG);
}

public static Literal shortLiteral(Short value) {
return literal(value, DataType.SHORT);
}

public static Literal floatLiteral(Float value) {
return literal(value, DataType.FLOAT);
}

public static Literal dateLiteral(String value) {
return literal(value, DataType.DATE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public enum DataType {

INTEGER(ExprCoreType.INTEGER),
LONG(ExprCoreType.LONG),
SHORT(ExprCoreType.SHORT),
FLOAT(ExprCoreType.FLOAT),
DOUBLE(ExprCoreType.DOUBLE),
STRING(ExprCoreType.STRING),
BOOLEAN(ExprCoreType.BOOLEAN),
Expand Down
47 changes: 47 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/AD.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Literal;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class AD extends UnresolvedPlan {
private UnresolvedPlan child;

private final Map<String, Literal> arguments;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitAD(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}
}
46 changes: 46 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Argument;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class Kmeans extends UnresolvedPlan {
private UnresolvedPlan child;

private final List<Argument> options;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitKmeans(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;

/*
* AD logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalAD extends LogicalPlan {
private final Map<String, Literal> arguments;

/**
* Constructor of LogicalAD.
* @param child child logical plan
* @param arguments arguments of the algorithm
*/
public LogicalAD(LogicalPlan child, Map<String, Literal> arguments) {
super(Collections.singletonList(child));
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitAD(this, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Argument;

/**
* ml-commons logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalMLCommons extends LogicalPlan {
private final String algorithm;

private final List<Argument> arguments;

/**
* Constructor of LogicalMLCommons.
* @param child child logical plan
* @param algorithm algorithm name
* @param arguments arguments of the algorithm
*/
public LogicalMLCommons(LogicalPlan child, String algorithm,
List<Argument> arguments) {
super(Collections.singletonList(child));
this.algorithm = algorithm;
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitMLCommons(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,12 @@ public R visitRareTopN(LogicalRareTopN plan, C context) {
public R visitLimit(LogicalLimit plan, C context) {
return visitNode(plan, context);
}

public R visitMLCommons(LogicalMLCommons plan, C context) {
return visitNode(plan, context);
}

public R visitAD(LogicalAD plan, C context) {
return visitNode(plan, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,13 @@ public R visitLimit(LimitOperator node, C context) {
return visitNode(node, context);
}

public R visitMLCommons(PhysicalPlan node, C context) {
return visitNode(node, context);
}

public R visitAD(PhysicalPlan node, C context) {
return visitNode(node, context);
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.opensearch.sql.utils;

public class MLCommonsConstants {

public static final String SHINGLE_SIZE = "shingle_size";
public static final String TIME_DECAY = "time_decay";
public static final String TIME_FIELD = "time_field";

public static final String RCF_SCORE = "score";
public static final String RCF_ANOMALOUS = "anomalous";
public static final String RCF_ANOMALY_GRADE = "anomaly_grade";
public static final String RCF_TIMESTAMP = "timestamp";
}
Loading