Skip to content

Commit

Permalink
Support aggregate functions in Eval expressions
Browse files Browse the repository at this point in the history
Signed-off-by: Lantao Jin <[email protected]>
  • Loading branch information
LantaoJin committed Oct 9, 2024
1 parent e3a19dd commit c5706be
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,61 @@ class FlintSparkPPLEvalITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test eval comma separated expressions with stats functions") {
  // Every eval expression is an aggregate function and `fields` lists only the new aliases;
  // the expectation below is a single row holding the five aggregate values.
  val frame = sql(s"""
       | source = $testTable | eval col1 = max(age), col2 = avg(age), col3 = min(age), col4 = sum(age), col5 = count(age) | fields col1, col2, col3, col4, col5
       | """.stripMargin)
  val results: Array[Row] = frame.collect()
  val expectedResults: Array[Row] = Array(Row(70, 36.25, 20, 145, 4))

  // Compare directly, without sorting. The expectation is a single row, so sorting cannot
  // change anything; the previous ordering also read column 0 (max(age), expected as the
  // Int 70) via getAs[Double], which would fail with a ClassCastException as soon as two
  // rows were actually compared.
  assert(results.sameElements(expectedResults))

  val logicalPlan: LogicalPlan = frame.queryExecution.logical

  // Expected inner projection: only the aliased aggregates, with no leading UnresolvedStar.
  val evalProjectList = Seq(
    Alias(
      UnresolvedFunction("max", Seq(UnresolvedAttribute("age")), isDistinct = false),
      "col1")(),
    Alias(
      UnresolvedFunction("avg", Seq(UnresolvedAttribute("age")), isDistinct = false),
      "col2")(),
    Alias(
      UnresolvedFunction("min", Seq(UnresolvedAttribute("age")), isDistinct = false),
      "col3")(),
    Alias(
      UnresolvedFunction("sum", Seq(UnresolvedAttribute("age")), isDistinct = false),
      "col4")(),
    Alias(
      UnresolvedFunction("count", Seq(UnresolvedAttribute("age")), isDistinct = false),
      "col5")())
  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val project = Project(evalProjectList, table)

  // Expected outer projection: the `fields` command selecting the five aliases.
  val expectedPlan = Project(
    seq(
      UnresolvedAttribute("col1"),
      UnresolvedAttribute("col2"),
      UnresolvedAttribute("col3"),
      UnresolvedAttribute("col4"),
      UnresolvedAttribute("col5")),
    project)
  comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("eval stats functions adding other field list should throw exception") {
  // The `fields` command asks for `age` in addition to the aggregate aliases; the query is
  // expected to be rejected with an UNRESOLVED_COLUMN analysis error.
  val query = s"""
       | source = $testTable | eval col1 = max(age), col2 = avg(age), col3 = min(age), col4 = sum(age), col5 = count(age) | fields age, col1, col2, col3, col4, col5
       | """.stripMargin
  val thrown = intercept[AnalysisException] {
    sql(query)
  }
  assert(thrown.getMessage().contains("UNRESOLVED_COLUMN"))
}

test("eval stats functions without fields command should throw exception") {
  // Aggregate-only eval with no trailing `fields` command; the query is expected to be
  // rejected with a MISSING_GROUP_BY analysis error.
  val query = s"""
       | source = $testTable | eval col1 = max(age), col2 = avg(age), col3 = min(age), col4 = sum(age), col5 = count(age)
       | """.stripMargin
  val thrown = intercept[AnalysisException] {
    sql(query)
  }
  assert(thrown.getMessage().contains("MISSING_GROUP_BY"))
}

test("test complex eval expressions with fields command") {
val frame = sql(s"""
| source = $testTable | eval new_name = upper(name) | eval compound_field = concat('Hello ', if(like(new_name, 'HEL%'), 'World', name)) | fields new_name, compound_field
Expand Down Expand Up @@ -672,8 +727,7 @@ class FlintSparkPPLEvalITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

// Todo excluded fields not support yet
ignore("test single eval expression with excluded fields") {
test("test single eval expression with excluded fields") {
val frame = sql(s"""
| source = $testTable | eval new_field = "New Field" | fields - age
| """.stripMargin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ evalFunctionName
| systemFunctionName
| positionFunctionName
| coalesceFunctionName
| statsFunctionName
;

functionArgs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import org.opensearch.sql.ast.tree.TopAggregation;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
Expand Down Expand Up @@ -439,9 +440,16 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression());
aliases.add(alias);
}
if (context.getNamedParseExpressions().isEmpty()) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
long statsFunctionsCount = node.getExpressionList().stream().map(Let::getExpression)
.filter(e -> e instanceof Function).map(f -> ((Function) f).getFuncName())
.filter(n -> BuiltinFunctionName.ofAggregation(n).isPresent()).count();
// An eval expression adds a projection on top of the existing project list,
// so the list must start with an UnresolvedStar -- unless every eval expression is an
// aggregation function and the projected fields contain no all-fields projection.
if (statsFunctionsCount == node.getExpressionList().size() &&
context.getProjectedFields().stream().noneMatch(f -> f instanceof AllFields)) {
// do nothing
} else {
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
}
List<Expression> expressionList = visitExpressionList(aliases, context);
Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort}
import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project, Sort}

class PPLLogicalPlanEvalTranslatorTestSuite
extends SparkFunSuite
Expand Down Expand Up @@ -150,6 +150,96 @@ class PPLLogicalPlanEvalTranslatorTestSuite
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test complex eval expressions - stats function") {
  // Builds the alias `name = fn(l)` in the unresolved form used by the expected plan.
  def statsAlias(fn: String, name: String) =
    Alias(UnresolvedFunction(fn, Seq(UnresolvedAttribute("l")), isDistinct = false), name)()

  val actualPlan =
    planTransformer.visit(
      plan(
        pplParser,
        "source=t | eval a = max(l) | eval b = avg(l) | eval c = min(l) | eval d = sum(l) | eval e = count(l) | fields a, b, c, d, e"),
      new CatalystPlanContext)

  // One Project per piped eval, each stacked on the previous one and holding only its
  // own aggregate alias (no UnresolvedStar).
  val relation = UnresolvedRelation(Seq("t"))
  val maxProject = Project(Seq(statsAlias("max", "a")), relation)
  val avgProject = Project(Seq(statsAlias("avg", "b")), maxProject)
  val minProject = Project(Seq(statsAlias("min", "c")), avgProject)
  val sumProject = Project(Seq(statsAlias("sum", "d")), minProject)
  val countProject = Project(Seq(statsAlias("count", "e")), sumProject)

  // Outermost Project corresponds to `fields a, b, c, d, e`.
  val selectedFields = seq(
    UnresolvedAttribute("a"),
    UnresolvedAttribute("b"),
    UnresolvedAttribute("c"),
    UnresolvedAttribute("d"),
    UnresolvedAttribute("e"))
  val expectedPlan = Project(selectedFields, countProject)
  comparePlans(expectedPlan, actualPlan, checkAnalysis = false)
}

test("test complex eval comma separated expressions - stats function") {
  // Builds the alias `name = fn(l)` in the unresolved form used by the expected plan.
  def statsAlias(fn: String, name: String) =
    Alias(UnresolvedFunction(fn, Seq(UnresolvedAttribute("l")), isDistinct = false), name)()

  val actualPlan =
    planTransformer.visit(
      plan(
        pplParser,
        "source=t | eval a = max(l), b = avg(l), c = min(l), d = sum(l), e = count(l) | fields a, b, c, d, e"),
      new CatalystPlanContext)

  // A single comma-separated eval maps to one Project carrying all five aggregate
  // aliases, with no UnresolvedStar in front.
  val aggregateAliases = Seq(
    statsAlias("max", "a"),
    statsAlias("avg", "b"),
    statsAlias("min", "c"),
    statsAlias("sum", "d"),
    statsAlias("count", "e"))
  val evalProject = Project(aggregateAliases, UnresolvedRelation(Seq("t")))

  // Outermost Project corresponds to `fields a, b, c, d, e`.
  val selectedFields = seq(
    UnresolvedAttribute("a"),
    UnresolvedAttribute("b"),
    UnresolvedAttribute("c"),
    UnresolvedAttribute("d"),
    UnresolvedAttribute("e"))
  val expectedPlan = Project(selectedFields, evalProject)
  comparePlans(expectedPlan, actualPlan, checkAnalysis = false)
}

test(
  "test complex eval comma separated expressions - stats function - without fields command") {
  // Builds the alias `name = fn(l)` in the unresolved form used by the expected plan.
  def statsAlias(fn: String, name: String) =
    Alias(UnresolvedFunction(fn, Seq(UnresolvedAttribute("l")), isDistinct = false), name)()

  val actualPlan =
    planTransformer.visit(
      plan(
        pplParser,
        "source=t | eval a = max(l), b = avg(l), c = min(l), d = sum(l), e = count(l)"),
      new CatalystPlanContext)

  // Without a `fields` command the eval projection keeps its leading UnresolvedStar,
  // followed by the five aggregate aliases.
  val evalProjectList: Seq[NamedExpression] = Seq(
    UnresolvedStar(None),
    statsAlias("max", "a"),
    statsAlias("avg", "b"),
    statsAlias("min", "c"),
    statsAlias("sum", "d"),
    statsAlias("count", "e"))
  val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t")))

  // The outermost Project is the all-fields projection.
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
  comparePlans(expectedPlan, actualPlan, checkAnalysis = false)
}

test("test complex eval expressions - compound function") {
val context = new CatalystPlanContext
val logPlan =
Expand Down Expand Up @@ -177,27 +267,30 @@ class PPLLogicalPlanEvalTranslatorTestSuite
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

// Todo fields-excluded command not supported
ignore("test eval expressions with fields-excluded command") {
test("test eval expressions with fields-excluded command") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source=t | eval a = 1, b = 2 | fields - b"), context)

val projectList: Seq[NamedExpression] =
Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(2), "b")())
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t")))
val expectedPlan = Project(
Seq(UnresolvedStar(None)),
DataFrameDropColumns(
Seq(UnresolvedAttribute("b")),
Project(projectList, UnresolvedRelation(Seq("t")))))
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

// Todo fields-included command not supported
ignore("test eval expressions with fields-included command") {
test("test eval expressions with fields-included command") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source=t | eval a = 1, b = 2 | fields + b"), context)

val projectList: Seq[NamedExpression] =
Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(2), "b")())
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t")))
val expectedPlan =
Project(Seq(UnresolvedAttribute("b")), Project(projectList, UnresolvedRelation(Seq("t"))))
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}
}

0 comments on commit c5706be

Please sign in to comment.