Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-207] Fix aggregate and refresh UT test script (#426)
Browse files Browse the repository at this point in the history
* fix expr id difference in Partial Aggregate

* add fallback to window

* refresh ut

* use stol instead of stoi

* enable ut full test
  • Loading branch information
rui-mo authored Jul 30, 2021
1 parent 9f75f6f commit 06bd109
Show file tree
Hide file tree
Showing 17 changed files with 286 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ case class ColumnarHashAggregateExec(
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
var resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends BaseAggregateExec
with ColumnarCodegenSupport
Expand All @@ -76,14 +76,28 @@ case class ColumnarHashAggregateExec(
val numaBindingInfo = ColumnarPluginConfig.getConf.numaBindingInfo
override def supportsColumnar = true

var resAttributes: Seq[Attribute] = resultExpressions.map(_.toAttribute)
if (aggregateExpressions != null && aggregateExpressions.nonEmpty) {
aggregateExpressions.head.mode match {
case Partial =>
// To fix the expression ids in result expressions being different with those from
// inputAggBufferAttributes, in Partial Aggregate,
// result attributes are recalculated to set the result expressions.
resAttributes = groupingExpressions.map(_.toAttribute) ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
resultExpressions = resAttributes
case _ =>
}
}

// Members declared in org.apache.spark.sql.execution.AliasAwareOutputPartitioning
override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

// Members declared in org.apache.spark.sql.execution.CodegenSupport
protected def doProduce(ctx: CodegenContext): String = throw new UnsupportedOperationException()

// Members declared in org.apache.spark.sql.catalyst.plans.QueryPlan
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
override def output: Seq[Attribute] = resAttributes

// Members declared in org.apache.spark.sql.execution.SparkPlan
protected override def doExecute()
Expand Down Expand Up @@ -398,30 +412,7 @@ case class ColumnarHashAggregateExec(
expr.mode match {
case Final =>
val out_res = 0
resultColumnVectors(idx).dataType match {
case t: IntegerType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].intValue)
case t: LongType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].longValue)
case t: DoubleType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].doubleValue())
case t: FloatType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].floatValue())
case t: ByteType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].byteValue())
case t: ShortType =>
resultColumnVectors(idx)
.put(0, out_res.asInstanceOf[Number].shortValue())
case t: StringType =>
val values = (out_res :: Nil).map(_.toByte).toArray
resultColumnVectors(idx)
.putBytes(0, 1, values, 0)
}
putDataIntoVector(resultColumnVectors, out_res, idx)
idx += 1
case _ =>
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import scala.collection.mutable.ListBuffer
import scala.util.Random

import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils
import util.control.Breaks._

case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
Expand All @@ -60,6 +61,8 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],

override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute)

buildCheck()

override def requiredChildDistribution: Seq[Distribution] = {
if (partitionSpec.isEmpty) {
// Only show warning when the number of bytes is larger than 100 MiB?
Expand Down Expand Up @@ -91,6 +94,29 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
val sparkConf = sparkContext.getConf
val numaBindingInfo = ColumnarPluginConfig.getConf.numaBindingInfo

def buildCheck(): Unit = {
var allLiteral = true
try {
breakable {
for (func <- validateWindowFunctions()) {
for (child <- func._2.children) {
if (!child.isInstanceOf[Literal]) {
allLiteral = false
break
}
}
}
}
} catch {
case e: Throwable =>
throw new UnsupportedOperationException(s"${e.getMessage}")
}
if (allLiteral) {
throw new UnsupportedOperationException(
s"Window functions' children all being Literal is not supported.")
}
}

def checkAggFunctionSpec(windowSpec: WindowSpecDefinition): Unit = {
if (windowSpec.orderSpec.nonEmpty) {
throw new UnsupportedOperationException("unsupported operation for " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class DateTimeSuite extends QueryTest with SharedSparkSession {
}

// FIXME ZONE issue
ignore("date type - cast from timestamp") {
test("date type - cast from timestamp") {
withTempView("dates") {
val dates = (0L to 3L).map(i => i * 24 * 1000 * 3600)
.map(i => Tuple1(new Timestamp(i)))
Expand Down Expand Up @@ -248,7 +248,7 @@ class DateTimeSuite extends QueryTest with SharedSparkSession {
}

// todo: fix field/literal implicit conversion in ColumnarExpressionConverter
ignore("date type - join on, bhj") {
test("date type - join on, bhj") {
withTempView("dates1", "dates2") {
val dates1 = (0L to 3L).map(i => i * 1000 * 3600 * 24)
.map(i => Tuple1(new Date(i))).toDF("time1")
Expand Down Expand Up @@ -750,7 +750,7 @@ class DateTimeSuite extends QueryTest with SharedSparkSession {
}
}

ignore("datetime function - to_date with format") { // todo GetTimestamp IS PRIVATE ?
test("datetime function - to_date with format") { // todo GetTimestamp IS PRIVATE ?
withTempView("dates") {

val dates = Seq("2009-07-30", "2009-07-31", "2009-08-01")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
super.afterAll()
}

ignore("window queries") {
test("window queries") {
runner.runTPCQuery("q12", 1, true)
runner.runTPCQuery("q20", 1, true)
runner.runTPCQuery("q36", 1, true)
Expand Down Expand Up @@ -103,7 +103,7 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
df.show()
}

ignore("window function with decimal input") {
test("window function with decimal input") {
val df = spark.sql("SELECT i_item_sk, i_class_id, SUM(i_current_price)" +
" OVER (PARTITION BY i_class_id) FROM item LIMIT 1000")
df.explain()
Expand All @@ -118,7 +118,7 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
df.show()
}

ignore("window function with decimal input 2") {
test("window function with decimal input 2") {
val df = spark.sql("SELECT i_item_sk, i_class_id, RANK()" +
" OVER (PARTITION BY i_class_id ORDER BY i_current_price) FROM item LIMIT 1000")
df.explain()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ArrowColumnarBatchSerializerSuite extends SparkFunSuite with SharedSparkSe
SQLMetrics.createAverageMetric(spark.sparkContext, "test serializer number of output rows")
}

ignore("deserialize all null") {
test("deserialize all null") {
val input = getTestResourcePath("test-data/native-splitter-output-all-null")
val serializer =
new ArrowColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance()
Expand Down Expand Up @@ -64,7 +64,7 @@ class ArrowColumnarBatchSerializerSuite extends SparkFunSuite with SharedSparkSe
deserializedStream.close()
}

ignore("deserialize nullable string") {
test("deserialize nullable string") {
val input = getTestResourcePath("test-data/native-splitter-output-nullable-string")
val serializer =
new ArrowColumnarBatchSerializer(avgBatchNumRows, outputNumRows).newInstance()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ class DataFrameAggregateSuite extends QueryTest
}

Seq(true, false).foreach { value =>
ignore(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
withTempView("t1", "t2") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest
Row("b", 2, 4, 8)))
}

ignore("null inputs") {
test("null inputs") {
val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2))
.toDF("key", "value")
val window = Window.orderBy()
Expand Down Expand Up @@ -908,7 +908,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest
}
}

ignore("NaN and -0.0 in window partition keys") {
test("NaN and -0.0 in window partition keys") {
val df = Seq(
(Float.NaN, Double.NaN),
(0.0f/0.0f, 0.0/0.0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession {
assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
}

ignore("function current_timestamp and now") {
test("function current_timestamp and now") {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ import org.apache.spark.sql.internal.SQLConf
* }}}
*/
// scalastyle:on line.size.limit
@deprecated("This test suite is not suitable for native sql engine.", "Mo Rui")
// This test suite is not suitable for native sql engine. (Mo Rui)
/*
trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite {
private val originalMaxToStringFields = conf.maxToStringFields
Expand Down Expand Up @@ -338,3 +339,4 @@ class TPCDSModifiedPlanStabilityWithStatsSuite extends PlanStabilitySuite {
}
}
}
*/
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class ReuseExchangeSuite extends RepartitionSuite {

override lazy val input = spark.read.parquet(filePath)

ignore("columnar exchange same result") {
test("columnar exchange same result") {
val df1 = input.groupBy("n_regionkey").agg(Map("n_nationkey" -> "sum"))
val hashAgg1 = df1.queryExecution.executedPlan.collectFirst {
case agg: ColumnarHashAggregateExec => agg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ class NativeDataFrameAggregateSuite extends QueryTest
}

Seq(true, false).foreach { value =>
ignore(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
withTempView("t1", "t2") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class NativeTPCHTableRepartitionSuite extends NativeRepartitionSuite {

override lazy val input = spark.read.format("arrow").load(filePath)

/*
ignore("tpch table round robin partitioning") {
withRepartition(df => df.repartition(2))
}
Expand All @@ -95,6 +96,7 @@ class NativeTPCHTableRepartitionSuite extends NativeRepartitionSuite {
df => df.groupBy("n_regionkey").agg(Map("n_nationkey" -> "sum")),
df => df.repartition(2))
}
*/
}

class NativeDisableColumnarShuffleSuite extends NativeRepartitionSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class NativeSQLConvertedSuite extends QueryTest
Row(null, 9)))
}

ignore("SMJ") {
test("SMJ") {
Seq[(String, Integer, Integer, Long, Double, Double, Double, Timestamp, Date)](
("val1a", 6, 8, 10L, 15.0, 20D, 20E2, Timestamp.valueOf("2014-04-04 00:00:00.000"), Date.valueOf("2014-04-04")),
("val1b", 8, 16, 19L, 17.0, 25D, 26E2, Timestamp.valueOf("2014-05-04 01:01:00.000"), Date.valueOf("2014-05-04")),
Expand Down Expand Up @@ -227,7 +227,7 @@ class NativeSQLConvertedSuite extends QueryTest
Row(0.0)))
}

ignore("int4 and int8 exception") {
test("int4 and int8 exception") {
Seq(0, 123456, -123456, 2147483647, -2147483647)
.toDF("f1").createOrReplaceTempView("INT4_TBL")
val df = sql("SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i")
Expand All @@ -250,7 +250,7 @@ class NativeSQLConvertedSuite extends QueryTest
df.show()
}

ignore("two inner joins with condition") {
test("two inner joins with condition") {
spark
.read
.format("csv")
Expand Down Expand Up @@ -285,10 +285,16 @@ class NativeSQLConvertedSuite extends QueryTest
"where b.f1 = t.thousand and a.f1 = b.f1 and (a.f1+b.f1+999) = t.tenthous")
checkAnswer(df, Seq())

/** window_part1 -- window has incorrect result */
/** join -- SMJ left semi */
val df2 = sql("select count(*) from tenk1 a where unique1 in" +
" (select unique1 from tenk1 b join tenk1 c using (unique1) where b.unique2 = 42)")
checkAnswer(df2, Seq(Row(1)))
}

ignore("window incorrect result") {
/** window_part1 */
val df1 = sql("SELECT sum(unique1) over (rows between current row and unbounded following)," +
"unique1, four FROM tenk1 WHERE unique1 < 10")
"unique1, four FROM tenk1 WHERE unique1 < 10")
checkAnswer(df1, Seq(
Row(0, 0, 0),
Row(10, 3, 3),
Expand All @@ -300,12 +306,6 @@ class NativeSQLConvertedSuite extends QueryTest
Row(41, 2, 2),
Row(45, 4, 0),
Row(7, 7, 3)))

/** join -- SMJ left semi has segfault */

val df2 = sql("select count(*) from tenk1 a where unique1 in" +
" (select unique1 from tenk1 b join tenk1 c using (unique1) where b.unique2 = 42)")
checkAnswer(df2, Seq(Row(1)))
}

test("min_max") {
Expand Down Expand Up @@ -592,33 +592,6 @@ class NativeSQLConvertedSuite extends QueryTest
}

test("groupingsets") {
spark
.read
.format("csv")
.options(Map("delimiter" -> "\t", "header" -> "false"))
.schema(
"""
|unique1 int,
|unique2 int,
|two int,
|four int,
|ten int,
|twenty int,
|hundred int,
|thousand int,
|twothousand int,
|fivethous int,
|tenthous int,
|odd int,
|even int,
|stringu1 string,
|stringu2 string,
|string4 string
""".stripMargin)
.load(testFile("test-data/postgresql/tenk.data"))
.write
.format("parquet")
.saveAsTable("tenk1")
val df = sql("select four, x from (select four, ten, 'foo' as x from tenk1) as t" +
" group by grouping sets (four, x) having x = 'foo'")
checkAnswer(df, Seq(Row(null, "foo")))
Expand Down Expand Up @@ -692,7 +665,7 @@ class NativeSQLConvertedSuite extends QueryTest
checkAnswer(df, Seq(Row(1, 1)))
}

ignore("scalar-subquery-select -- SMJ LeftAnti has incorrect result") {
test("scalar-subquery-select -- SMJ LeftAnti has incorrect result") {
Seq[(String, Integer, Integer, Long, Double, Double, Double, Timestamp, Date)](
("val1a", 6, 8, 10L, 15.0, 20D, 20E2, Timestamp.valueOf("2014-04-04 00:00:00.000"), Date.valueOf("2014-04-04")),
("val1b", 8, 16, 19L, 17.0, 25D, 26E2, Timestamp.valueOf("2014-05-04 01:01:00.000"), Date.valueOf("2014-05-04")),
Expand Down Expand Up @@ -756,7 +729,7 @@ class NativeSQLConvertedSuite extends QueryTest
Row("val1e", 10)))
}

test("join") {
// test("join") {
// Seq[(Integer, Integer, String)](
// (1, 4, "one"),
// (2, 3, "two"),
Expand Down Expand Up @@ -814,7 +787,5 @@ class NativeSQLConvertedSuite extends QueryTest
// (4, null))
// .toDF("y1", "y2")
// .createOrReplaceTempView("y")

}

// }
}
Loading

0 comments on commit 06bd109

Please sign in to comment.