Update GpuRunningWindowExec to use OOM retry framework (#8170)
* stub out CheckpointRestore methods for fixers

* implement checkpoint/restore for BatchedRunningWindowFixer impls

* implement checkpoint/restore for BatchedRunningWindowFixer impls

Signed-off-by: Andy Grove <[email protected]>

* retry around fixUpAll

* remove redundant class

* revert intellij auto formatting of imports

* increase section of code contained within withRetryNoSplit

* add some comments

* minor cleanup

* save interim progress

* save interim progress

* close resources in checkpoint restore code

* fix

* add comment

* remove retry from doAggsAndClose

* move retry from doAgg to computeBasicWindow, fix test failures

* fix double close, remove retry from computeBasicWindow

* fix resource leak

* remove comment that is no longer relevant

* fix one resource leak

* Add retry to GpuWindowIterator.hasNext

* defensively reset checkpoints to None during restore

* address feedback

* re-implement first unit test to use iterator

* re-implement unit tests to call GpuWindowIterator

* fix error in test, add more assertions

* fix segfault

* fix test

* revert column order

* remove TODO comment

---------

Signed-off-by: Andy Grove <[email protected]>
andygrove authored May 8, 2023
1 parent ea89dd5 commit bf58d90
Showing 3 changed files with 245 additions and 173 deletions.
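
The central pattern in this change: each BatchedRunningWindowFixer now implements CheckpointRestore, so the retry framework can snapshot a fixer's state before an attempt and roll it back if that attempt fails with a retryable OOM, while the retried work itself runs against a SpillableColumnarBatch. What follows is a minimal, CPU-only sketch of the checkpoint/restore half of that pattern. SimpleCheckpointRestore, ToyFixer, and withRestoreOnRetrySketch are hypothetical names for illustration, not the spark-rapids API; the real fixers hold cudf Scalars, and their restore() also closes any interim result that superseded the checkpointed value.

object CheckpointRestoreSketch {

  // Hypothetical stand-in for the CheckpointRestore trait the fixers implement.
  trait SimpleCheckpointRestore {
    def checkpoint(): Unit
    def restore(): Unit
  }

  // Toy fixer mirroring how BatchedRunningWindowBinaryFixer snapshots previousResult.
  class ToyFixer extends SimpleCheckpointRestore {
    private var previousResult: Option[Long] = None
    private var checkpointPreviousResult: Option[Long] = None

    override def checkpoint(): Unit = {
      checkpointPreviousResult = previousResult
    }

    override def restore(): Unit = {
      // Roll back to the snapshot; the real fixers also close the superseded cudf Scalar here.
      previousResult = checkpointPreviousResult
      checkpointPreviousResult = None
    }

    def updateState(value: Long): Unit = previousResult = Some(value)
  }

  // Hypothetical stand-in for withRestoreOnRetry: snapshot every fixer, run the body,
  // and roll the fixers back before rethrowing so a later attempt starts from the
  // state it expects.
  def withRestoreOnRetrySketch[T](fixers: Seq[SimpleCheckpointRestore])(body: => T): T = {
    fixers.foreach(_.checkpoint())
    try {
      body
    } catch {
      case t: Throwable =>
        fixers.foreach(_.restore())
        throw t
    }
  }
}

In the commit itself this shows up as withRestoreOnRetry(fixers.values.toSeq) inside computeRunning, with the whole section additionally wrapped in withRetryNoSplit over a spillable batch.
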
172 changes: 94 additions & 78 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala
@@ -24,8 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf
import ai.rapids.cudf.{AggregationOverWindow, DType, GroupByOptions, GroupByScanAggregation, NullPolicy, NvtxColor, ReplacePolicy, ReplacePolicyWithColumn, Scalar, ScanAggregation, ScanType, Table, WindowOptions}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{withRestoreOnRetry, withRetryNoSplit}
import com.nvidia.spark.rapids.shims.{GpuWindowUtil, ShimUnaryExecNode}

import org.apache.spark.TaskContext
@@ -1126,36 +1125,15 @@ class GroupedAggregations {
* Do all of the aggregations and put them in the output columns. There may be extra processing
* after this before you get to a final result.
*/
def doAggsAndClose(isRunningBatched: Boolean,
def doAggs(isRunningBatched: Boolean,
boundOrderSpec: Seq[SortOrder],
orderByPositions: Array[Int],
partByPositions: Array[Int],
inputSpillable: SpillableColumnarBatch,
inputCb: ColumnarBatch,
outputColumns: Array[cudf.ColumnVector]): Unit = {
withRetryNoSplit(inputSpillable) { attempt =>
// when there are exceptions in this body, we always want to close
// `outputColumns` before a likely retry.
try {
withResource(attempt.getColumnarBatch()) { attemptCb =>
doRunningWindowOptimizedAggs(
isRunningBatched, partByPositions, attemptCb, outputColumns)
doRowAggs(
boundOrderSpec, orderByPositions, partByPositions, attemptCb, outputColumns)
doRangeAggs(
boundOrderSpec, orderByPositions, partByPositions, attemptCb, outputColumns)
}
} catch {
case t: Throwable =>
// on exceptions we want to throw away any columns in outputColumns that
// are not pass-through
val columnsToClose = outputColumns.filter(_ != null)
outputColumns.indices.foreach { col =>
outputColumns(col) = null
}
columnsToClose.safeClose(t)
throw t
}
}
doRunningWindowOptimizedAggs(isRunningBatched, partByPositions, inputCb, outputColumns)
doRowAggs(boundOrderSpec, orderByPositions, partByPositions, inputCb, outputColumns)
doRangeAggs(boundOrderSpec, orderByPositions, partByPositions, inputCb, outputColumns)
}

/**
@@ -1252,18 +1230,16 @@ trait BasicWindowCalc {
*/
def computeBasicWindow(cb: ColumnarBatch): Array[cudf.ColumnVector] = {
closeOnExcept(new Array[cudf.ColumnVector](boundWindowOps.length)) { outputColumns =>
val inputSpillable = SpillableColumnarBatch(
GpuProjectExec.project(cb, initialProjections),
SpillPriorities.ACTIVE_BATCHING_PRIORITY)

// this takes ownership of `inputSpillable`
aggregations.doAggsAndClose(
isRunningBatched,
boundOrderSpec,
orderByPositions,
partByPositions,
inputSpillable,
outputColumns)

withResource(GpuProjectExec.project(cb, initialProjections)) { proj =>
aggregations.doAggs(
isRunningBatched,
boundOrderSpec,
orderByPositions,
partByPositions,
proj,
outputColumns)
}

// if the window aggregates were successful, lets splice the passThrough
// columns
@@ -1299,20 +1275,36 @@ class GpuWindowIterator(

override def isRunningBatched: Boolean = false

override def hasNext: Boolean = input.hasNext
override def hasNext: Boolean = onDeck.isDefined || input.hasNext

var onDeck: Option[SpillableColumnarBatch] = None

override def next(): ColumnarBatch = {
withResource(input.next()) { cb =>
withResource(new NvtxWithMetrics("window", NvtxColor.CYAN, opTime)) { _ =>
val ret = withResource(computeBasicWindow(cb)) { cols =>
convertToBatch(outputTypes, cols)
val cbSpillable = onDeck match {
case Some(x) =>
onDeck = None
x
case _ =>
getNext()
}
withRetryNoSplit(cbSpillable) { _ =>
withResource(cbSpillable.getColumnarBatch()) { cb =>
withResource(new NvtxWithMetrics("window", NvtxColor.CYAN, opTime)) { _ =>
val ret = withResource(computeBasicWindow(cb)) { cols =>
convertToBatch(outputTypes, cols)
}
numOutputBatches += 1
numOutputRows += ret.numRows()
ret
}
numOutputBatches += 1
numOutputRows += ret.numRows()
ret
}
}
}

def getNext(): SpillableColumnarBatch = {
SpillableColumnarBatch(input.next(), SpillPriorities.ACTIVE_BATCHING_PRIORITY)
}

}

object GpuBatchedWindowIterator {
@@ -1494,34 +1486,59 @@ class GpuRunningWindowIterator(
}
}

def computeRunning(cb: ColumnarBatch): ColumnarBatch = {
def computeRunning(input: ColumnarBatch): ColumnarBatch = {
val fixers = fixerIndexMap
val numRows = cb.numRows()

withResource(computeBasicWindow(cb)) { basic =>
withResource(GpuProjectExec.project(cb, boundPartitionSpec)) { parts =>
val partColumns = GpuColumnVector.extractBases(parts)
withResourceIfAllowed(arePartsEqual(lastParts, partColumns)) { partsEqual =>
val fixedUp = if (fixerNeedsOrderMask) {
withResource(GpuProjectExec.project(cb, boundOrderColumns)) { order =>
val orderColumns = GpuColumnVector.extractBases(order)
// We need to fix up the rows that are part of the same batch as the end of the
// last batch
withResourceIfAllowed(areOrdersEqual(lastOrder, orderColumns, partsEqual)) {
orderEqual =>
closeOnExcept(fixUpAll(basic, fixers, partsEqual, Some(orderEqual))) { fixed =>
saveLastOrder(getScalarRow(numRows - 1, orderColumns))
fixed
val numRows = input.numRows()
val cbSpillable = SpillableColumnarBatch(input, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
withRetryNoSplit(cbSpillable) { _ =>
withResource(cbSpillable.getColumnarBatch()) { cb =>
withResource(computeBasicWindow(cb)) { basic =>
var newOrder: Option[Array[Scalar]] = None
var newParts: Option[Array[Scalar]] = None
val fixedUp = try {
// we backup the fixers state and restore it in the event of a retry
withRestoreOnRetry(fixers.values.toSeq) {
withResource(GpuProjectExec.project(cb,
boundPartitionSpec)) { parts =>
val partColumns = GpuColumnVector.extractBases(parts)
withResourceIfAllowed(arePartsEqual(lastParts, partColumns)) { partsEqual =>
val fixedUp = if (fixerNeedsOrderMask) {
withResource(GpuProjectExec.project(cb,
boundOrderColumns)) { order =>
val orderColumns = GpuColumnVector.extractBases(order)
// We need to fix up the rows that are part of the same batch as the end of
// the last batch
withResourceIfAllowed(areOrdersEqual(lastOrder, orderColumns, partsEqual)) {
orderEqual =>
closeOnExcept(fixUpAll(basic, fixers, partsEqual, Some(orderEqual))) {
fixedUp =>
newOrder = Some(getScalarRow(numRows - 1, orderColumns))
fixedUp
}
}
}
} else {
// No ordering needed
fixUpAll(basic, fixers, partsEqual, None)
}
newParts = Some(getScalarRow(numRows - 1, partColumns))
fixedUp
}
}
}
} else {
// No ordering needed
fixUpAll(basic, fixers, partsEqual, None)
} catch {
case t: Throwable =>
// avoid leaking unused interim results
newOrder.foreach(_.foreach(_.close()))
newParts.foreach(_.foreach(_.close()))
throw t
}
withResource(fixedUp) { fixed =>
saveLastParts(getScalarRow(numRows - 1, partColumns))
convertToBatch(outputTypes, fixed)
// this section is outside of the retry logic because the calls to saveLastParts
// and saveLastOrders can potentially close GPU resources
withResource(fixedUp) { _ =>
newOrder.foreach(saveLastOrder)
newParts.foreach(saveLastParts)
convertToBatch(outputTypes, fixedUp)
}
}
}
@@ -1555,13 +1572,12 @@
}

override def next(): ColumnarBatch = {
withResource(readNextInputBatch()) { cb =>
withResource(new NvtxWithMetrics("RunningWindow", NvtxColor.CYAN, opTime)) { _ =>
val ret = computeRunning(cb)
numOutputBatches += 1
numOutputRows += ret.numRows()
ret
}
val cb = readNextInputBatch()
withResource(new NvtxWithMetrics("RunningWindow", NvtxColor.CYAN, opTime)) { _ =>
val ret = computeRunning(cb) // takes ownership of cb
numOutputBatches += 1
numOutputRows += ret.numRows()
ret
}
}
}
(Second changed file — filename not captured.)
@@ -786,7 +786,7 @@ trait GpuRunningWindowFunction extends GpuWindowFunction {
* </code>
* which can be output.
*/
trait BatchedRunningWindowFixer extends AutoCloseable {
trait BatchedRunningWindowFixer extends AutoCloseable with CheckpointRestore {
/**
* Fix up `windowedColumnOutput` with any stored state from previous batches.
* Like all window operations the input data will have been sorted by the partition
Expand Down Expand Up @@ -976,6 +976,26 @@ class BatchedRunningWindowBinaryFixer(val binOp: BinaryOp, val name: String)
extends BatchedRunningWindowFixer with Logging {
private var previousResult: Option[Scalar] = None

// checkpoint
private var checkpointPreviousResult: Option[Scalar] = None

override def checkpoint(): Unit = {
checkpointPreviousResult = previousResult
}

override def restore(): Unit = {
if (checkpointPreviousResult.isDefined) {
// close previous result
previousResult match {
case Some(r) if r != checkpointPreviousResult.get =>
r.close()
case _ =>
}
previousResult = checkpointPreviousResult
checkpointPreviousResult = None
}
}

def getPreviousResult: Option[Scalar] = previousResult

def updateState(finalOutputColumn: cudf.ColumnVector): Unit = {
Expand Down Expand Up @@ -1025,6 +1045,38 @@ class SumBinaryFixer(toType: DataType, isAnsi: Boolean)
private var previousResult: Option[Scalar] = None
private var previousOverflow: Option[Scalar] = None

// checkpoint
private var checkpointResult: Option[Scalar] = None
private var checkpointOverflow: Option[Scalar] = None

override def checkpoint(): Unit = {
checkpointOverflow = previousOverflow
checkpointResult = previousResult
}

override def restore(): Unit = {
if (checkpointOverflow.isDefined) {
// close previous result
previousOverflow match {
case Some(r) if r != checkpointOverflow.get =>
r.close()
case _ =>
}
previousOverflow = checkpointOverflow
checkpointOverflow = None
}
if (checkpointResult.isDefined) {
// close previous result
previousResult match {
case Some(r) if r != checkpointResult.get =>
r.close()
case _ =>
}
previousResult = checkpointResult
checkpointResult = None
}
}

def updateState(finalOutputColumn: cudf.ColumnVector,
wasOverflow: Option[cudf.ColumnVector]): Unit = {
val lastIndex = finalOutputColumn.getRowCount.toInt - 1
Expand Down Expand Up @@ -1255,6 +1307,28 @@ class RankFixer extends BatchedRunningWindowFixer with Logging {
// The previous rank value
private[this] var previousRank: Option[Scalar] = None

// checkpoint
private[this] var checkpointRank: Option[Scalar] = None

override def checkpoint(): Unit = {
rowNumFixer.checkpoint()
checkpointRank = previousRank
}

override def restore(): Unit = {
rowNumFixer.restore()
if (checkpointRank.isDefined) {
// close previous result
previousRank match {
case Some(r) if r != checkpointRank.get =>
r.close()
case _ =>
}
previousRank = checkpointRank
checkpointRank = None
}
}

override def needsOrderMask: Boolean = true

override def fixUp(
Expand Down Expand Up @@ -1353,6 +1427,26 @@ class DenseRankFixer extends BatchedRunningWindowFixer with Logging {

private var previousRank: Option[Scalar] = None

// checkpoint
private var checkpointRank: Option[Scalar] = None

override def checkpoint(): Unit = {
checkpointRank = previousRank
}

override def restore(): Unit = {
if (checkpointRank.isDefined) {
// close previous result
previousRank match {
case Some(r) if r != checkpointRank.get =>
r.close()
case _ =>
}
previousRank = checkpointRank
checkpointRank = None
}
}

override def needsOrderMask: Boolean = true

override def fixUp(
(Third changed file — diff not loaded.)
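
One design note worth calling out, echoing the comment in computeRunning about keeping saveLastParts and saveLastOrder outside the retry logic: the work that may be re-executed runs inside the retry block over a spillable batch, while state updates that close previously held GPU resources happen only after that block has succeeded. Below is a hedged, CPU-only illustration of that shape; Resource, retrySketch, and processBatch are hypothetical stand-ins, not the plugin's API.

object RetryShapeSketch {

  // Hypothetical resource standing in for GPU-backed state such as lastParts.
  final class Resource(val value: Int) extends AutoCloseable {
    override def close(): Unit = ()
  }

  // Hypothetical retry helper: re-run the body up to `attempts` times. The body
  // must be safe to execute again from scratch.
  def retrySketch[T](attempts: Int)(body: => T): T = {
    try {
      body
    } catch {
      case _: Throwable if attempts > 1 => retrySketch(attempts - 1)(body)
    }
  }

  private var lastParts: Option[Resource] = None // state carried between batches

  def processBatch(batch: Seq[Int]): Int = {
    // The retried section only produces results; it does not touch carried-over state.
    val (result, newParts) = retrySketch(attempts = 3) {
      (batch.sum, new Resource(batch.lastOption.getOrElse(0)))
    }
    // Releasing the old state and installing the new one happens outside the retry,
    // mirroring how saveLastParts/saveLastOrder run after the retried section completes.
    lastParts.foreach(_.close())
    lastParts = Some(newParts)
    result
  }
}
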