From 5f468cc21ef621151c200edfeea0411342c6d8bb Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Fri, 11 Sep 2020 09:11:35 +0900
Subject: [PATCH] [SPARK-32822][SQL] Change the number of partitions to zero
 when a range is empty with WholeStageCodegen disabled or fallen back

### What changes were proposed in this pull request?

This PR changes the behavior of RangeExec with WholeStageCodegen disabled or fallen back so that the number of partitions becomes zero when a range is empty.

In the current master, if WholeStageCodegen takes effect, the number of partitions of an empty range is already changed to zero.
```
spark.range(1, 1, 1, 1000).rdd.getNumPartitions
res0: Int = 0
```
But that does not happen if WholeStageCodegen is disabled or fallen back.
```
spark.conf.set("spark.sql.codegen.wholeStage", false)
spark.range(1, 1, 1, 1000).rdd.getNumPartitions
res2: Int = 1000
```

### Why are the changes needed?

To achieve better performance even when WholeStageCodegen is disabled or fallen back.

### Does this PR introduce _any_ user-facing change?

Yes. The number of partitions returned by `getNumPartitions` for an empty range changes when WholeStageCodegen is disabled.

### How was this patch tested?

New test.

Closes #29681 from sarutak/zero-size-range.

Authored-by: Kousuke Saruta
Signed-off-by: Takeshi Yamamuro
---
 .../execution/basicPhysicalOperators.scala    | 105 ++++++++++--------
 .../spark/sql/execution/PlannerSuite.scala    |   7 ++
 2 files changed, 63 insertions(+), 49 deletions(-)
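
Reviewer note (not part of the commit message): the new `isEmptyRange` flag factors out the emptiness check that `inputRDDs` already performed inline. A range is considered empty when `start == end`, or when the sign of `step` points away from `end`. Below is a standalone sketch of that predicate with hypothetical values, purely for illustration:
```
// Sketch of the predicate behind RangeExec.isEmptyRange.
// A range is empty when start equals end, or when the direction of the
// step (0 < step) disagrees with the direction from start to end
// (start < end); the XOR is true exactly when the two disagree.
def isEmptyRange(start: Long, end: Long, step: Long): Boolean =
  start == end || (start < end ^ 0 < step)

assert(isEmptyRange(1, 1, 1))    // start == end: nothing to produce
assert(isEmptyRange(5, 1, 1))    // ascending step, but end is below start
assert(isEmptyRange(1, 5, -1))   // descending step, but end is above start
assert(!isEmptyRange(1, 5, 1))   // ordinary ascending range
assert(!isEmptyRange(5, 1, -1))  // ordinary descending range
```
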
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index c240a182d32bb..1f70fde3f7654 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -371,6 +371,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
   val step: Long = range.step
   val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
   val numElements: BigInt = range.numElements
+  val isEmptyRange: Boolean = start == end || (start < end ^ 0 < step)
 
   override val output: Seq[Attribute] = range.output
 
@@ -396,7 +397,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
   }
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    val rdd = if (start == end || (start < end ^ 0 < step)) {
+    val rdd = if (isEmptyRange) {
       new EmptyRDD[InternalRow](sqlContext.sparkContext)
     } else {
       sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
@@ -562,58 +563,64 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    sqlContext
-      .sparkContext
-      .parallelize(0 until numSlices, numSlices)
-      .mapPartitionsWithIndex { (i, _) =>
-        val partitionStart = (i * numElements) / numSlices * step + start
-        val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
-        def getSafeMargin(bi: BigInt): Long =
-          if (bi.isValidLong) {
-            bi.toLong
-          } else if (bi > 0) {
-            Long.MaxValue
-          } else {
-            Long.MinValue
-          }
-        val safePartitionStart = getSafeMargin(partitionStart)
-        val safePartitionEnd = getSafeMargin(partitionEnd)
-        val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
-        val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
-        val taskContext = TaskContext.get()
-
-        val iter = new Iterator[InternalRow] {
-          private[this] var number: Long = safePartitionStart
-          private[this] var overflow: Boolean = false
-          private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics
-
-          override def hasNext =
-            if (!overflow) {
-              if (step > 0) {
-                number < safePartitionEnd
-              } else {
-                number > safePartitionEnd
-              }
-            } else false
-
-          override def next() = {
-            val ret = number
-            number += step
-            if (number < ret ^ step < 0) {
-              // we have Long.MaxValue + Long.MaxValue < Long.MaxValue
-              // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step
-              // back, we are pretty sure that we have an overflow.
-              overflow = true
+    if (isEmptyRange) {
+      new EmptyRDD[InternalRow](sqlContext.sparkContext)
+    } else {
+      sqlContext
+        .sparkContext
+        .parallelize(0 until numSlices, numSlices)
+        .mapPartitionsWithIndex { (i, _) =>
+          val partitionStart = (i * numElements) / numSlices * step + start
+          val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
+
+          def getSafeMargin(bi: BigInt): Long =
+            if (bi.isValidLong) {
+              bi.toLong
+            } else if (bi > 0) {
+              Long.MaxValue
+            } else {
+              Long.MinValue
             }
-            numOutputRows += 1
-            inputMetrics.incRecordsRead(1)
-            unsafeRow.setLong(0, ret)
-            unsafeRow
+          val safePartitionStart = getSafeMargin(partitionStart)
+          val safePartitionEnd = getSafeMargin(partitionEnd)
+          val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
+          val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
+          val taskContext = TaskContext.get()
+
+          val iter = new Iterator[InternalRow] {
+            private[this] var number: Long = safePartitionStart
+            private[this] var overflow: Boolean = false
+            private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics
+
+            override def hasNext =
+              if (!overflow) {
+                if (step > 0) {
+                  number < safePartitionEnd
+                } else {
+                  number > safePartitionEnd
+                }
+              } else false
+
+            override def next() = {
+              val ret = number
+              number += step
+              if (number < ret ^ step < 0) {
+                // we have Long.MaxValue + Long.MaxValue < Long.MaxValue
+                // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step
+                // back, we are pretty sure that we have an overflow.
+                overflow = true
+              }
+
+              numOutputRows += 1
+              inputMetrics.incRecordsRead(1)
+              unsafeRow.setLong(0, ret)
+              unsafeRow
+            }
           }
+          new InterruptibleIterator(taskContext, iter)
         }
-        new InterruptibleIterator(taskContext, iter)
-      }
+    }
   }
 
   override def simpleString(maxFields: Int): String = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index d428b7ebc0e91..ca52e51c87ea7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -994,6 +994,13 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  testWithWholeStageCodegenOnAndOff("Change the number of partitions to zero " +
+    "when a range is empty") { _ =>
+    val range = spark.range(1, 1, 1, 1000)
+    val numPartitions = range.rdd.getNumPartitions
+    assert(numPartitions == 0)
+  }
 }
 
 // Used for unit-testing EnsureRequirements
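
Reviewer note: a quick way to check the new behavior from spark-shell with this patch applied (the expected result below follows from the change and the new test; it is not a pasted session):
```
spark.conf.set("spark.sql.codegen.wholeStage", false)
spark.range(1, 1, 1, 1000).rdd.getNumPartitions
// expected with this patch: Int = 0
```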