diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 2b6bccdbc1..b2eef5d09b 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -39,13 +39,19 @@ import org.apache.comet.vector.NativeUtil
  *   The input iterators producing sequence of batches of Arrow Arrays.
  * @param protobufQueryPlan
  *   The serialized bytes of Spark execution plan.
+ * @param numParts
+ *   The number of partitions.
+ * @param partitionIndex
+ *   The index of the partition.
  */
 class CometExecIterator(
     val id: Long,
     inputs: Seq[Iterator[ColumnarBatch]],
     numOutputCols: Int,
     protobufQueryPlan: Array[Byte],
-    nativeMetrics: CometMetricNode)
+    nativeMetrics: CometMetricNode,
+    numParts: Int,
+    partitionIndex: Int)
     extends Iterator[ColumnarBatch] {
 
   private val nativeLib = new Native()
@@ -55,13 +61,12 @@ class CometExecIterator(
   }.toArray
   private val plan = {
     val configs = createNativeConf
-    TaskContext.get().numPartitions()
     nativeLib.createPlan(
       id,
       configs,
       cometBatchIterators,
       protobufQueryPlan,
-      TaskContext.get().numPartitions(),
+      numParts,
       nativeMetrics,
       new CometTaskMemoryManager(id))
   }
@@ -103,10 +108,12 @@ class CometExecIterator(
   }
 
   def getNextBatch(): Option[ColumnarBatch] = {
+    assert(partitionIndex >= 0 && partitionIndex < numParts)
+
     nativeUtil.getNextBatch(
       numOutputCols,
       (arrayAddrs, schemaAddrs) => {
-        nativeLib.executePlan(plan, TaskContext.get().partitionId(), arrayAddrs, schemaAddrs)
+        nativeLib.executePlan(plan, partitionIndex, arrayAddrs, schemaAddrs)
       })
   }
 
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index d88f129a38..1028d04660 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -1029,12 +1029,20 @@ class CometSparkSessionExtensions
       var firstNativeOp = true
       newPlan.transformDown {
         case op: CometNativeExec =>
-          if (firstNativeOp) {
+          val newPlan = if (firstNativeOp) {
             firstNativeOp = false
             op.convertBlock()
           } else {
             op
           }
+
+          // If reaching leaf node, reset `firstNativeOp` to true
+          // because it will start a new block in next iteration.
+          if (op.children.isEmpty) {
+            firstNativeOp = true
+          }
+
+          newPlan
         case op =>
           firstNativeOp = true
           op
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
index 952515d5fc..2fd7f12c24 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
@@ -29,11 +29,11 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 private[spark] class CometExecRDD(
     sc: SparkContext,
     partitionNum: Int,
-    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch])
+    var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
     extends RDD[ColumnarBatch](sc, Nil) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-    f(Seq.empty)
+    f(Seq.empty, partitionNum, s.index)
   }
 
   override protected def getPartitions: Array[Partition] = {
@@ -46,7 +46,8 @@ private[spark] class CometExecRDD(
 
 object CometExecRDD {
   def apply(sc: SparkContext, partitionNum: Int)(
-      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
+      f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+      : RDD[ColumnarBatch] =
     withScope(sc) {
       new CometExecRDD(sc, partitionNum, f)
     }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index 8cc03856c2..9698dc98b8 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -51,9 +51,10 @@ object CometExecUtils {
       childPlan: RDD[ColumnarBatch],
       outputAttribute: Seq[Attribute],
       limit: Int): RDD[ColumnarBatch] = {
-    childPlan.mapPartitionsInternal { iter =>
+    val numParts = childPlan.getNumPartitions
+    childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
       val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
+      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
     }
   }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 6220c809da..5582f4d687 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -77,12 +77,13 @@ case class CometTakeOrderedAndProjectExec(
         val localTopK = if (orderingSatisfies) {
           CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
         } else {
-          childRDD.mapPartitionsInternal { iter =>
+          val numParts = childRDD.getNumPartitions
+          childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
             val topK = CometExecUtils
               .getTopKNativePlan(child.output, sortOrder, child, limit)
               .get
-            CometExec.getCometIterator(Seq(iter), child.output.length, topK)
+            CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
           }
         }
 
@@ -102,7 +103,7 @@ case class CometTakeOrderedAndProjectExec(
         val topKAndProjection = CometExecUtils
           .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
           .get
-        val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
+        val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
         setSubqueries(it.id, this)
 
         Option(TaskContext.get()).foreach { context =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
index 6db8c67d58..fdf8bf393d 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
@@ -31,16 +31,20 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  */
 private[spark] class ZippedPartitionsRDD(
     sc: SparkContext,
-    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch],
+    var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch],
     var zipRdds: Seq[RDD[ColumnarBatch]],
     preservesPartitioning: Boolean = false)
     extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) {
 
+  // We need to get the number of partitions in `compute` but `getNumPartitions` is not available
+  // on the executors. So we need to capture it here.
+  private val numParts: Int = this.getNumPartitions
+
   override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     val iterators =
       zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context))
-    f(iterators)
+    f(iterators, numParts, s.index)
   }
 
   override def clearDependencies(): Unit = {
@@ -52,7 +56,8 @@ private[spark] class ZippedPartitionsRDD(
 
 object ZippedPartitionsRDD {
   def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
-      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
+      f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+      : RDD[ColumnarBatch] =
     withScope(sc) {
       new ZippedPartitionsRDD(sc, f, rdds)
     }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 6430a7899f..388c07a27e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -224,13 +224,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
       outputPartitioning: Partitioning,
       serializer: Serializer,
       metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+    val numParts = rdd.getNumPartitions
     val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
       rdd.map(
         (0, _)
       ), // adding fake partitionId that is always 0 because ShuffleDependency requires it
       serializer = serializer,
       shuffleWriterProcessor =
-        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics),
+        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts),
       shuffleType = CometNativeShuffle,
       partitioner = new Partitioner {
         override def numPartitions: Int = outputPartitioning.numPartitions
@@ -446,7 +447,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
 class CometShuffleWriteProcessor(
     outputPartitioning: Partitioning,
     outputAttributes: Seq[Attribute],
-    metrics: Map[String, SQLMetric])
+    metrics: Map[String, SQLMetric],
+    numParts: Int)
     extends ShimCometShuffleWriteProcessor {
 
   private val OFFSET_LENGTH = 8
@@ -489,7 +491,9 @@ class CometShuffleWriteProcessor(
       Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
       outputAttributes.length,
       nativePlan,
-      nativeMetrics)
+      nativeMetrics,
+      numParts,
+      context.partitionId())
 
     while (cometIter.hasNext) {
       cometIter.next()
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 8b50ad191d..9ca1f94de9 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -120,20 +120,37 @@ object CometExec {
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
-      nativePlan: Operator): CometExecIterator = {
-    getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty))
+      nativePlan: Operator,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    getCometIterator(
+      inputs,
+      numOutputCols,
+      nativePlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx)
   }
 
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
       nativePlan: Operator,
-      nativeMetrics: CometMetricNode): CometExecIterator = {
+      nativeMetrics: CometMetricNode,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
     val outputStream = new ByteArrayOutputStream()
     nativePlan.writeTo(outputStream)
     outputStream.close()
     val bytes = outputStream.toByteArray
-    new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics)
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      bytes,
+      nativeMetrics,
+      numParts,
+      partitionIdx)
   }
 
   /**
@@ -214,13 +231,18 @@ abstract class CometNativeExec extends CometExec {
     // TODO: support native metrics for all operators.
     val nativeMetrics = CometMetricNode.fromCometPlan(this)
 
-    def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
+    def createCometExecIter(
+        inputs: Seq[Iterator[ColumnarBatch]],
+        numParts: Int,
+        partitionIndex: Int): CometExecIterator = {
       val it = new CometExecIterator(
         CometExec.newIterId,
         inputs,
         output.length,
         serializedPlanCopy,
-        nativeMetrics)
+        nativeMetrics,
+        numParts,
+        partitionIndex)
 
       setSubqueries(it.id, this)
 
@@ -315,10 +337,10 @@ abstract class CometNativeExec extends CometExec {
       }
 
       if (inputs.nonEmpty) {
-        ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
+        ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
       } else {
         val partitionNum = firstNonBroadcastPlanNumPartitions.get
-        CometExecRDD(sparkContext, partitionNum)(createCometExecIter(_))
+        CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
       }
     }
   }
diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
index ef0485dfe6..6ca38e8319 100644
--- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -37,7 +37,9 @@ class CometNativeSuite extends CometTestBase {
           override def next(): ColumnarBatch = throw new NullPointerException()
         }),
         1,
-        limitOp)
+        limitOp,
+        1,
+        0)
       cometIter.next()
       cometIter.close()
       value
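
Note (not part of the patch): the sketch below is illustrative only. It assumes an operator that already has a `childRDD` and a serialized `nativePlan` in scope, and shows how callers now resolve the partition count on the driver and pass it, together with the partition index, to `CometExec.getCometIterator`, instead of reading both from `TaskContext` inside the executor.

    // Illustrative sketch; `childRDD`, `nativePlan` and `output` are placeholders for
    // whatever the calling operator already has in scope.
    val numParts = childRDD.getNumPartitions // resolved on the driver, captured by the closure
    childRDD.mapPartitionsWithIndexInternal { case (partitionIndex, iter) =>
      // Spark supplies the partition index directly to the executor-side closure.
      CometExec.getCometIterator(Seq(iter), output.length, nativePlan, numParts, partitionIndex)
    }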