Clean up GpuCollectLimitMeta and add in metrics (NVIDIA#289)
* Clean up GpuCollectLimitMeta and add in metrics

* Addressed review comments
revans2 authored Jun 25, 2020
1 parent b9ca74c commit ae9f794
Showing 1 changed file with 19 additions and 40 deletions.
59 changes: 19 additions & 40 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids
 
 import scala.collection.mutable.ArrayBuffer
 
-import ai.rapids.cudf.Table
+import ai.rapids.cudf.{NvtxColor, Table}
 import com.nvidia.spark.rapids.GpuMetricNames._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 
@@ -47,6 +47,10 @@ trait GpuBaseLimitExec extends LimitExec with GpuExec {
     throw new IllegalStateException(s"Row-based execution should not occur for $this")
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val numOutputRows = longMetric(NUM_OUTPUT_ROWS)
+    val numOutputBatches = longMetric(NUM_OUTPUT_BATCHES)
+    val totalTime = longMetric(TOTAL_TIME)
+
     val crdd = child.executeColumnar()
     crdd.mapPartitions { cbIter =>
       new Iterator[ColumnarBatch] {
@@ -56,13 +60,17 @@ trait GpuBaseLimitExec extends LimitExec with GpuExec {
 
         override def next(): ColumnarBatch = {
           val batch = cbIter.next()
-          val result = if (batch.numRows() > remainingLimit) {
-            sliceBatch(batch)
-          } else {
-            batch
+          withResource(new NvtxWithMetrics("limit", NvtxColor.ORANGE, totalTime)) { _ =>
+            val result = if (batch.numRows() > remainingLimit) {
+              sliceBatch(batch)
+            } else {
+              batch
+            }
+            numOutputBatches += 1
+            numOutputRows += result.numRows()
+            remainingLimit -= result.numRows()
+            result
           }
-          remainingLimit -= result.numRows()
-          result
         }
 
         def sliceBatch(batch: ColumnarBatch): ColumnarBatch = {
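
The hunk above moves the per-batch work inside an NVTX-instrumented range tied to the total-time metric and bumps the output row and batch counters alongside the existing remainingLimit bookkeeping. Below is a minimal, self-contained sketch of that accounting using plain Scala stand-ins (a fake Batch and LongMetric, and System.nanoTime in place of the plugin's NvtxWithMetrics timer); it is illustrative only, not the plugin's actual classes.

object LimitMetricsSketch {
  // Plain stand-ins for Spark's SQLMetric and ColumnarBatch, for illustration only.
  final class LongMetric { var value: Long = 0L; def +=(v: Long): Unit = value += v }
  final case class Batch(numRows: Int)

  // Mirrors the shape of the hunk above: pass a batch through untouched, or "slice" it
  // down to the rows that are still allowed, updating metrics as we go.
  def limitIterator(
      batches: Iterator[Batch],
      limit: Int,
      numOutputRows: LongMetric,
      numOutputBatches: LongMetric,
      totalTime: LongMetric): Iterator[Batch] = new Iterator[Batch] {
    private var remainingLimit = limit

    override def hasNext: Boolean = remainingLimit > 0 && batches.hasNext

    override def next(): Batch = {
      val start = System.nanoTime()               // crude analogue of the NVTX-timed range
      val batch = batches.next()
      val result =
        if (batch.numRows > remainingLimit) Batch(remainingLimit) // slice to what is left
        else batch
      numOutputBatches += 1
      numOutputRows += result.numRows
      remainingLimit -= result.numRows
      totalTime += System.nanoTime() - start
      result
    }
  }
}

Feeding it Iterator(Batch(3), Batch(4)) with limit = 5 yields batches of 3 and 2 rows, stops iteration, and leaves numOutputRows.value == 5, matching the remainingLimit accounting in the real iterator.
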
@@ -123,37 +131,8 @@ class GpuCollectLimitMeta(
     Seq(GpuOverrides.wrapPart(collectLimit.outputPartitioning, conf, Some(this)))
 
   override def convertToGpu(): GpuExec =
-    GpuCollectLimitExec(collectLimit.limit, childParts(0).convertToGpu(),
-      GpuLocalLimitExec(collectLimit.limit,
-        GpuShuffleExchangeExec(GpuSinglePartitioning(Seq.empty), childPlans(0).convertIfNeeded())))
-}
-
-case class GpuCollectLimitExec(
-    limit: Int, partitioning: GpuPartitioning,
-    child: SparkPlan) extends GpuBaseLimitExec {
-
-  private val serializer: Serializer = new GpuColumnarBatchSerializer(child.output.size)
-
-  private lazy val writeMetrics =
-    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
-  private lazy val readMetrics =
-    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+    GpuGlobalLimitExec(collectLimit.limit,
+      GpuShuffleExchangeExec(GpuSinglePartitioning(Seq.empty),
+        GpuLocalLimitExec(collectLimit.limit, childPlans(0).convertIfNeeded())))
-
-  lazy val shuffleMetrics = readMetrics ++ writeMetrics
-
-  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-    val locallyLimited: RDD[ColumnarBatch] = super.doExecuteColumnar()
-
-    val shuffleDependency = GpuShuffleExchangeExec.prepareBatchShuffleDependency(
-      locallyLimited,
-      child.output,
-      partitioning,
-      serializer,
-      metrics ++ shuffleMetrics,
-      metrics ++ writeMetrics)
-
-    val shuffled = new ShuffledBatchRDD(shuffleDependency, metrics ++ shuffleMetrics, None)
-    shuffled.mapPartitions(_.take(limit))
-  }
-
-}
 }
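
With the bespoke GpuCollectLimitExec removed, convertToGpu now composes existing operators: a per-partition GpuLocalLimitExec, a single-partition GpuShuffleExchangeExec, and a GpuGlobalLimitExec on top. The sketch below models that shape with plain Scala collections only (no Spark; the function names are illustrative analogues, not the plugin's API) to show why the composition still returns at most limit rows while shuffling at most limit rows per input partition.

object CollectLimitShapeSketch {
  type Partition = Seq[Int]

  // GpuLocalLimitExec analogue: cap each partition at `limit` rows before the shuffle.
  def localLimit(parts: Seq[Partition], limit: Int): Seq[Partition] =
    parts.map(_.take(limit))

  // GpuShuffleExchangeExec(GpuSinglePartitioning) analogue: gather everything into one partition.
  def singlePartitionShuffle(parts: Seq[Partition]): Partition =
    parts.flatten

  // GpuGlobalLimitExec analogue: enforce the overall limit on the single gathered partition.
  def globalLimit(rows: Partition, limit: Int): Partition =
    rows.take(limit)

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1, 2, 3, 4), Seq(5, 6), Seq(7, 8, 9))
    val limit = 5
    val collected = globalLimit(singlePartitionShuffle(localLimit(partitions, limit)), limit)
    println(collected) // List(1, 2, 3, 4, 5): at most `limit` rows overall
  }
}

The local limit keeps any single task from sending more than limit rows to the exchange, and the global limit trims the concatenated result, which is the same semantics the deleted doExecuteColumnar implemented by hand with its own shuffle dependency.
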
