fix: Use RDD partition index (apache#1112)
* fix: Use RDD partition index

* fix

* fix

* fix
viirya committed Nov 25, 2024
1 parent c3ad26e commit 05c1cc5
Showing 9 changed files with 79 additions and 28 deletions.
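
Background for this change (an illustration, not part of the commit): inside a task, TaskContext.get().partitionId() is the task's index within the stage's final RDD, which is not necessarily the partition index of the RDD whose iterator is being computed. Narrow transformations such as union keep both RDDs in one stage while offsetting the child's partition indices, so keying a native plan on the task's partitionId can hand it an index that is out of range for its own RDD; the commit instead threads the RDD's own partition index (and partition count) through to the native side. The standalone sketch below, with hypothetical names, shows the two values diverging:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object PartitionIndexDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("partition-index-demo"))
    val a = sc.parallelize(1 to 4, 2).map(i => (i, -1))
    val b = sc.parallelize(5 to 8, 2).mapPartitionsWithIndex { (rddIdx, iter) =>
      // rddIdx is b's own partition index (0 or 1); TaskContext.partitionId() is the
      // index in the unioned RDD (2 or 3), because union is a narrow dependency and
      // both inputs are evaluated in the same stage.
      iter.map(_ => (rddIdx, TaskContext.get().partitionId()))
    }
    a.union(b).collect().foreach(println)
    sc.stop()
  }
}
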
15 changes: 11 additions & 4 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -39,13 +39,19 @@ import org.apache.comet.vector.NativeUtil
  *   The input iterators producing sequence of batches of Arrow Arrays.
  * @param protobufQueryPlan
  *   The serialized bytes of Spark execution plan.
+ * @param numParts
+ *   The number of partitions.
+ * @param partitionIndex
+ *   The index of the partition.
  */
 class CometExecIterator(
     val id: Long,
     inputs: Seq[Iterator[ColumnarBatch]],
     numOutputCols: Int,
     protobufQueryPlan: Array[Byte],
-    nativeMetrics: CometMetricNode)
+    nativeMetrics: CometMetricNode,
+    numParts: Int,
+    partitionIndex: Int)
     extends Iterator[ColumnarBatch] {

   private val nativeLib = new Native()
@@ -55,13 +61,12 @@ class CometExecIterator(
   }.toArray
   private val plan = {
     val configs = createNativeConf
-    TaskContext.get().numPartitions()
     nativeLib.createPlan(
       id,
       configs,
       cometBatchIterators,
       protobufQueryPlan,
-      TaskContext.get().numPartitions(),
+      numParts,
       nativeMetrics,
       new CometTaskMemoryManager(id))
   }
@@ -103,10 +108,12 @@
   }

   def getNextBatch(): Option[ColumnarBatch] = {
+    assert(partitionIndex >= 0 && partitionIndex < numParts)
+
     nativeUtil.getNextBatch(
       numOutputCols,
       (arrayAddrs, schemaAddrs) => {
-        nativeLib.executePlan(plan, TaskContext.get().partitionId(), arrayAddrs, schemaAddrs)
+        nativeLib.executePlan(plan, partitionIndex, arrayAddrs, schemaAddrs)
       })
   }

@@ -1029,12 +1029,20 @@ class CometSparkSessionExtensions
     var firstNativeOp = true
     newPlan.transformDown {
       case op: CometNativeExec =>
-        if (firstNativeOp) {
+        val newPlan = if (firstNativeOp) {
           firstNativeOp = false
           op.convertBlock()
         } else {
           op
         }
+
+        // If reaching leaf node, reset `firstNativeOp` to true
+        // because it will start a new block in next iteration.
+        if (op.children.isEmpty) {
+          firstNativeOp = true
+        }
+
+        newPlan
       case op =>
         firstNativeOp = true
         op
@@ -29,11 +29,11 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 private[spark] class CometExecRDD(
     sc: SparkContext,
     partitionNum: Int,
-    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch])
+    var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
     extends RDD[ColumnarBatch](sc, Nil) {

   override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-    f(Seq.empty)
+    f(Seq.empty, partitionNum, s.index)
   }

   override protected def getPartitions: Array[Partition] = {
@@ -46,7 +46,8 @@

 object CometExecRDD {
   def apply(sc: SparkContext, partitionNum: Int)(
-      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
+      f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+      : RDD[ColumnarBatch] =
     withScope(sc) {
       new CometExecRDD(sc, partitionNum, f)
     }
@@ -51,9 +51,10 @@ object CometExecUtils {
       childPlan: RDD[ColumnarBatch],
       outputAttribute: Seq[Attribute],
       limit: Int): RDD[ColumnarBatch] = {
-    childPlan.mapPartitionsInternal { iter =>
+    val numParts = childPlan.getNumPartitions
+    childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
       val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
+      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
     }
   }

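The getNativeLimitRDD change above follows a general pattern: read the partition count once on the driver, then let mapPartitionsWithIndex hand each closure its own RDD partition index instead of consulting TaskContext inside the task. A minimal sketch of that pattern with Spark's public API and a hypothetical helper (Comet itself uses the internal mapPartitionsWithIndexInternal variant):

import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

object PartitionAwareOps { // hypothetical helper object, not part of Comet
  def withPartitionInfo[T, U: ClassTag](rdd: RDD[T])(
      f: (Int, Int, Iterator[T]) => Iterator[U]): RDD[U] = {
    val numParts = rdd.getNumPartitions // evaluated on the driver, captured by the closure
    rdd.mapPartitionsWithIndex { (idx, iter) =>
      assert(idx >= 0 && idx < numParts) // mirrors the check added to CometExecIterator
      f(numParts, idx, iter)
    }
  }
}
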
@@ -77,12 +77,13 @@ case class CometTakeOrderedAndProjectExec(
     val localTopK = if (orderingSatisfies) {
       CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
     } else {
-      childRDD.mapPartitionsInternal { iter =>
+      val numParts = childRDD.getNumPartitions
+      childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
         val topK =
           CometExecUtils
             .getTopKNativePlan(child.output, sortOrder, child, limit)
             .get
-        CometExec.getCometIterator(Seq(iter), child.output.length, topK)
+        CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
       }
     }

@@ -102,7 +103,7 @@
       val topKAndProjection = CometExecUtils
         .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
         .get
-      val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
+      val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
       setSubqueries(it.id, this)

       Option(TaskContext.get()).foreach { context =>
@@ -31,16 +31,20 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  */
 private[spark] class ZippedPartitionsRDD(
     sc: SparkContext,
-    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch],
+    var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch],
     var zipRdds: Seq[RDD[ColumnarBatch]],
     preservesPartitioning: Boolean = false)
     extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) {

+  // We need to get the number of partitions in `compute` but `getNumPartitions` is not available
+  // on the executors. So we need to capture it here.
+  private val numParts: Int = this.getNumPartitions
+
   override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     val iterators =
       zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context))
-    f(iterators)
+    f(iterators, numParts, s.index)
   }

   override def clearDependencies(): Unit = {
@@ -52,7 +56,8 @@

 object ZippedPartitionsRDD {
   def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
-      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
+      f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+      : RDD[ColumnarBatch] =
     withScope(sc) {
       new ZippedPartitionsRDD(sc, f, rdds)
     }
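The numParts field added to ZippedPartitionsRDD above leans on the fact that an RDD is constructed on the driver: a val in the class body is evaluated there once and travels to the executors with the serialized RDD, whereas calling getNumPartitions from compute would run on an executor where, as the new comment notes, it is not available. A stripped-down sketch of the same idea with hypothetical names:

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical single-parent RDD, not Comet's class: the partition count is
// captured while the RDD is built on the driver and reused inside compute().
class IndexAwareRDD[T: ClassTag](prev: RDD[T], f: (Int, Int, Iterator[T]) => Iterator[T])
    extends RDD[T](prev) {

  private val numParts: Int = prev.getNumPartitions // driver-side capture

  override protected def getPartitions: Array[Partition] = prev.partitions

  override def compute(split: Partition, context: TaskContext): Iterator[T] =
    f(numParts, split.index, prev.iterator(split, context))
}
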
@@ -224,13 +224,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
       outputPartitioning: Partitioning,
       serializer: Serializer,
       metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+    val numParts = rdd.getNumPartitions
     val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
       rdd.map(
         (0, _)
       ), // adding fake partitionId that is always 0 because ShuffleDependency requires it
       serializer = serializer,
       shuffleWriterProcessor =
-        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics),
+        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts),
       shuffleType = CometNativeShuffle,
       partitioner = new Partitioner {
         override def numPartitions: Int = outputPartitioning.numPartitions
@@ -446,7 +447,8 @@
 class CometShuffleWriteProcessor(
     outputPartitioning: Partitioning,
     outputAttributes: Seq[Attribute],
-    metrics: Map[String, SQLMetric])
+    metrics: Map[String, SQLMetric],
+    numParts: Int)
     extends ShimCometShuffleWriteProcessor {

   private val OFFSET_LENGTH = 8
@@ -489,7 +491,9 @@ class CometShuffleWriteProcessor(
         Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
         outputAttributes.length,
         nativePlan,
-        nativeMetrics)
+        nativeMetrics,
+        numParts,
+        context.partitionId())

       while (cometIter.hasNext) {
         cometIter.next()
38 changes: 30 additions & 8 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -120,20 +120,37 @@ object CometExec {
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
-      nativePlan: Operator): CometExecIterator = {
-    getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty))
+      nativePlan: Operator,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    getCometIterator(
+      inputs,
+      numOutputCols,
+      nativePlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx)
   }

   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
       nativePlan: Operator,
-      nativeMetrics: CometMetricNode): CometExecIterator = {
+      nativeMetrics: CometMetricNode,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
     val outputStream = new ByteArrayOutputStream()
     nativePlan.writeTo(outputStream)
     outputStream.close()
     val bytes = outputStream.toByteArray
-    new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics)
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      bytes,
+      nativeMetrics,
+      numParts,
+      partitionIdx)
   }

   /**
@@ -214,13 +231,18 @@ abstract class CometNativeExec extends CometExec {
     // TODO: support native metrics for all operators.
     val nativeMetrics = CometMetricNode.fromCometPlan(this)

-    def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
+    def createCometExecIter(
+        inputs: Seq[Iterator[ColumnarBatch]],
+        numParts: Int,
+        partitionIndex: Int): CometExecIterator = {
       val it = new CometExecIterator(
         CometExec.newIterId,
         inputs,
         output.length,
         serializedPlanCopy,
-        nativeMetrics)
+        nativeMetrics,
+        numParts,
+        partitionIndex)

       setSubqueries(it.id, this)

@@ -315,10 +337,10 @@
     }

     if (inputs.nonEmpty) {
-      ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
+      ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
     } else {
       val partitionNum = firstNonBroadcastPlanNumPartitions.get
-      CometExecRDD(sparkContext, partitionNum)(createCometExecIter(_))
+      CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
     }
   }
 }
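A small consequence of the new three-parameter signature shows up in the last hunk above: the old createCometExecIter(_) was a one-argument lambda, so once the method takes (inputs, numParts, partitionIndex) it is passed by its bare name and eta-expanded into the (Seq[Iterator[ColumnarBatch]], Int, Int) => CometExecIterator function that ZippedPartitionsRDD and CometExecRDD now expect. A toy illustration with hypothetical names:

object EtaExpansionDemo { // standalone illustration, not Comet code
  def makeIter(inputs: Seq[Iterator[Int]], numParts: Int, partitionIndex: Int): Iterator[Int] =
    inputs.headOption.getOrElse(Iterator.empty)

  // The bare method name eta-expands to a Function3 value; the old one-argument
  // placeholder form `makeIter(_)` would not fit a three-parameter method.
  val f: (Seq[Iterator[Int]], Int, Int) => Iterator[Int] = makeIter
}
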
4 changes: 3 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -37,7 +37,9 @@ class CometNativeSuite extends CometTestBase {
           override def next(): ColumnarBatch = throw new NullPointerException()
         }),
         1,
-        limitOp)
+        limitOp,
+        1,
+        0)
       cometIter.next()
       cometIter.close()
       value