diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala index c79bd573260..6870cc203b2 100644 --- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala +++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala @@ -173,6 +173,7 @@ object BenchUtils { var df: DataFrame = null val queryStatus = new ListBuffer[String]() val queryTimes = new ListBuffer[Long]() + val rowCounts = new ListBuffer[Long]() for (i <- 0 until iterations) { spark.sparkContext.setJobDescription(s"Benchmark Run: query=$queryDescription; iteration=$i") @@ -198,7 +199,9 @@ object BenchUtils { df = createDataFrame(spark) resultsAction match { - case Collect() => df.collect() + case Collect() => + val rows = df.collect() + rowCounts.append(rows.length) case WriteCsv(path, mode, options) => ensureValidColumnNames(df).write.mode(mode).options(options).csv(path) case WriteOrc(path, mode, options) => @@ -296,6 +299,7 @@ object BenchUtils { queryDescription, queryPlan, queryPlansWithMetrics, + rowCounts, queryTimes, queryStatus, exceptions) @@ -796,6 +800,7 @@ case class BenchmarkReport( query: String, queryPlan: QueryPlan, queryPlans: Seq[SparkPlanNode], + rowCounts: Seq[Long], queryTimes: Seq[Long], queryStatus: Seq[String], exceptions: Seq[String]) diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala index 866d1213ef5..4fafaab06f4 100644 --- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala +++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala @@ -52,6 +52,7 @@ class BenchUtilsSuite extends FunSuite with BeforeAndAfterEach { query = "q1", queryPlan = QueryPlan("logical", "physical"), Seq.empty, + rowCounts = Seq(10, 10, 10), queryTimes = Seq(99, 88, 77), queryStatus = Seq("Completed", "Completed", "Completed"), exceptions = Seq.empty)