Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-1371][VL] Support First/Last aggregate functions #1581

Merged
merged 7 commits on
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ object CHExpressionUtil {
BIT_AND_AGG -> Set(EMPTY_TYPE),
BIT_XOR_AGG -> Set(EMPTY_TYPE),
CORR -> Set(EMPTY_TYPE),
FIRST -> Set(EMPTY_TYPE),
LAST -> Set(EMPTY_TYPE),
COVAR_POP -> Set(EMPTY_TYPE),
COVAR_SAMP -> Set(EMPTY_TYPE)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,38 +49,49 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
.set("spark.sql.sources.useV1SourceList", "avro")
}


test("count") {
val df = runQueryAndCompare(
"select count(*) from lineitem where l_partkey in (1552, 674, 1062)") {
checkOperatorMatch[GlutenHashAggregateExecTransformer] }
checkOperatorMatch[GlutenHashAggregateExecTransformer]
}
runQueryAndCompare(
"select count(l_quantity), count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
}

test("avg") {
val df = runQueryAndCompare(
"select avg(l_partkey) from lineitem where l_partkey < 1000") {
checkOperatorMatch[GlutenHashAggregateExecTransformer] }
checkOperatorMatch[GlutenHashAggregateExecTransformer]
}
runQueryAndCompare(
"select avg(l_quantity), count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
runQueryAndCompare(
"select avg(cast (l_quantity as DECIMAL(12, 2))), " +
"count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
runQueryAndCompare(
"select avg(cast (l_quantity as DECIMAL(22, 2))), " +
"count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
}

test("sum") {
Expand All @@ -91,8 +102,10 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
runQueryAndCompare(
"select sum(l_quantity), count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
runQueryAndCompare(
"select sum(cast (l_quantity as DECIMAL(22, 2))) from lineitem") {
checkOperatorMatch[GlutenHashAggregateExecTransformer]
Expand All @@ -101,14 +114,18 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
"select sum(cast (l_quantity as DECIMAL(12, 2))), " +
"count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
runQueryAndCompare(
"select sum(cast (l_quantity as DECIMAL(22, 2))), " +
"count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
}

test("min and max") {
Expand All @@ -119,8 +136,10 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
runQueryAndCompare(
"select min(l_partkey), max(l_partkey), count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
}

test("groupby") {
Expand Down Expand Up @@ -154,15 +173,17 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
runQueryAndCompare(
"select stddev_samp(l_quantity), count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
}

test("stddev_pop") {
runQueryAndCompare(
"""
|select stddev_pop(l_quantity) from lineitem;
|""".stripMargin) {
|""".stripMargin) {
checkOperatorMatch[GlutenHashAggregateExecTransformer]
}
runQueryAndCompare(
Expand All @@ -175,8 +196,10 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
runQueryAndCompare(
"select stddev_pop(l_quantity), count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
}

test("var_samp") {
Expand All @@ -196,8 +219,10 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
runQueryAndCompare(
"select var_samp(l_quantity), count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
}

test("var_pop") {
Expand All @@ -217,18 +242,20 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
runQueryAndCompare(
"select var_pop(l_quantity), count(distinct l_partkey) from lineitem") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4)
}}
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}
}
}

test("bit_and bit_or bit_xor") {
val bitAggs = Seq("bit_and", "bit_or", "bit_xor")
for (func <- bitAggs) {
runQueryAndCompare(
s"""
|select ${func}(l_linenumber) from lineitem
|group by l_orderkey;
|""".stripMargin) {
|select ${func}(l_linenumber) from lineitem
|group by l_orderkey;
|""".stripMargin) {
checkOperatorMatch[GlutenHashAggregateExecTransformer]
}
runQueryAndCompare(
Expand Down Expand Up @@ -283,6 +310,56 @@ class VeloxAggregateFunctionsSuite extends WholeStageTransformerSuite {
}
}

test("first") {
  // first(x) / first(x, ignoreNulls = true) with no grouping: the whole
  // aggregation is expected to be offloaded to the native transformer.
  // NOTE: the `s` interpolator was dropped — these literals contain no `$`.
  runQueryAndCompare(
    """
      |select first(l_linenumber), first(l_linenumber, true) from lineitem;
      |""".stripMargin) {
    checkOperatorMatch[GlutenHashAggregateExecTransformer]
  }
  // The first_value alias, with a group-by key.
  runQueryAndCompare(
    """
      |select first_value(l_linenumber), first_value(l_linenumber, true) from lineitem
      |group by l_orderkey;
      |""".stripMargin) {
    checkOperatorMatch[GlutenHashAggregateExecTransformer]
  }
  // Mixed with count(distinct ...): the executed plan is expected to contain
  // exactly 4 native hash-aggregate operators (matching the other suites'
  // distinct-aggregation checks above).
  runQueryAndCompare(
    """
      |select first(l_linenumber), first(l_linenumber, true), count(distinct l_partkey) from lineitem
      |""".stripMargin) { df =>
    assert(getExecutedPlan(df).count(_.isInstanceOf[GlutenHashAggregateExecTransformer]) == 4)
  }
}

test("last") {
  // last(x) / last(x, ignoreNulls = true) with no grouping: the whole
  // aggregation is expected to be offloaded to the native transformer.
  // NOTE: the `s` interpolator was dropped — these literals contain no `$`.
  runQueryAndCompare(
    """
      |select last(l_linenumber), last(l_linenumber, true) from lineitem;
      |""".stripMargin) {
    checkOperatorMatch[GlutenHashAggregateExecTransformer]
  }
  // The last_value alias, with a group-by key.
  runQueryAndCompare(
    """
      |select last_value(l_linenumber), last_value(l_linenumber, true) from lineitem
      |group by l_orderkey;
      |""".stripMargin) {
    checkOperatorMatch[GlutenHashAggregateExecTransformer]
  }
  // Mixed with count(distinct ...): the executed plan is expected to contain
  // exactly 4 native hash-aggregate operators (matching the other suites'
  // distinct-aggregation checks above).
  runQueryAndCompare(
    """
      |select last(l_linenumber), last(l_linenumber, true), count(distinct l_partkey) from lineitem
      |""".stripMargin) { df =>
    assert(getExecutedPlan(df).count(_.isInstanceOf[GlutenHashAggregateExecTransformer]) == 4)
  }
}


test("distinct functions") {
runQueryAndCompare("SELECT sum(DISTINCT l_partkey), count(*) FROM lineitem") { df => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,10 @@ abstract class HashAggregateExecBaseTransformer(
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
case Average(_, _) =>
case _: Average | _: First | _: Last =>
mode match {
case Partial | PartialMerge =>
val avg = aggregateFunc.asInstanceOf[Average]
val aggBufferAttr = avg.inputAggBufferAttributes
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object AggregateFunctionsBuilder {
def create(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = {
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]

val substraitAggFuncName = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass)
var substraitAggFuncName = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass)
if (substraitAggFuncName.isEmpty) {
throw new UnsupportedOperationException(s"Could not find valid a substrait mapping name for $aggregateFunc.")
}
Expand All @@ -38,6 +38,14 @@ object AggregateFunctionsBuilder {
throw new UnsupportedOperationException(s"Aggregate function not supported for $aggregateFunc.")
}

aggregateFunc match {
case first @ First(_, ignoreNull) =>
if (ignoreNull) substraitAggFuncName = Some(ExpressionMappings.FIRST_IGNORE_NULL)
case last @ Last(_, ignoreNulls) =>
if (ignoreNulls) substraitAggFuncName = Some(ExpressionMappings.LAST_IGNORE_NULL)
case _ =>
}

val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType)

ExpressionBuilder.newScalarFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ object ExpressionMappings {
final val COVAR_POP = "covar_pop"
final val COVAR_SAMP = "covar_samp"
final val LAST = "last"
final val LAST_IGNORE_NULL = "last_ignore_null"
final val FIRST = "first"
final val FIRST_IGNORE_NULL = "first_ignore_null"

// Function names used by Substrait plan.
final val ADD = "add"
Expand Down Expand Up @@ -423,7 +426,8 @@ object ExpressionMappings {
Sig[Corr](CORR),
Sig[CovPopulation](COVAR_POP),
Sig[CovSample](COVAR_SAMP),
Sig[Last](LAST)
Sig[Last](LAST),
Sig[First](FIRST)
)

/** Mapping Spark window expression to Substrait function name */
Expand Down
Loading