Skip to content

Commit

Permalink
Fix a hang for Pandas UDFs on DB 13.3[databricks] (NVIDIA#9833)
Browse files Browse the repository at this point in the history
fix NVIDIA#9493
fix NVIDIA#9844

The python runner uses two separate threads to write and read data with Python processes, 
however on DB13.3, it becomes single-threaded, which means reading and writing run on the same thread.
Now the first reading is always ahead of the first writing. But the original BatchQueue will wait
on the first reading until the first writing is done. Then it will wait forever.

Change made:

- Update the BatchQueue to support asking for a batch instead of waiting unitl one is inserted into the queue. 
   This can eliminate the order requirement of reading and writing.
- Introduce a new class named BatchProducer to work with the new BatchQueue to support rows number
   peek on demand for the reading.
- Apply this new BatchQueue to relevant plans.
- Update the Python runners to support writing one batch one time for the singled-threaded model.
- Found an issue about PythonUDAF and RunningWindoFunctionExec, it may be a bug specific to DB 13.3,
   and add a test (test_window_aggregate_udf_on_cpu) for it.
- Other small refactors
---------

Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman authored and razajafri committed Jan 25, 2024
1 parent 764a923 commit c6496bd
Show file tree
Hide file tree
Showing 16 changed files with 491 additions and 302 deletions.
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/spark_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def is_spark_330_or_later():
def is_spark_340_or_later():
return spark_version() >= "3.4.0"

def is_spark_341():
return spark_version() == "3.4.1"

def is_spark_350_or_later():
return spark_version() >= "3.5.0"

Expand Down
17 changes: 7 additions & 10 deletions integration_tests/src/main/python/udf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from conftest import is_at_least_precommit_run, is_not_utc
from spark_session import is_databricks_runtime, is_before_spark_330, is_before_spark_350, is_spark_340_or_later
from spark_session import is_databricks_runtime, is_before_spark_330, is_before_spark_350, is_spark_341

from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version

Expand Down Expand Up @@ -43,12 +43,6 @@
import pyarrow
from typing import Iterator, Tuple


if is_databricks_runtime() and is_spark_340_or_later():
# Databricks 13.3 does not use separate reader/writer threads for Python UDFs
# which can lead to hangs. Skipping these tests until the Python UDF handling is updated.
pytestmark = pytest.mark.skip(reason="https://github.com/NVIDIA/spark-rapids/issues/9493")

arrow_udf_conf = {
'spark.sql.execution.arrow.pyspark.enabled': 'true',
'spark.rapids.sql.exec.WindowInPandasExec': 'true',
Expand Down Expand Up @@ -182,7 +176,10 @@ def group_size_udf(to_process: pd.Series) -> int:

low_upper_win = Window.partitionBy('a').orderBy('b').rowsBetween(-3, 3)

udf_windows = [no_part_win, unbounded_win, cur_follow_win, pre_cur_win, low_upper_win]
running_win_param = pytest.param(pre_cur_win, marks=pytest.mark.xfail(
condition=is_databricks_runtime() and is_spark_341(),
reason='DB13.3 wrongly uses RunningWindowFunctionExec to evaluate a PythonUDAF and it will fail even on CPU'))
udf_windows = [no_part_win, unbounded_win, cur_follow_win, running_win_param, low_upper_win]
window_ids = ['No_Partition', 'Unbounded', 'Unbounded_Following', 'Unbounded_Preceding',
'Lower_Upper']

Expand Down Expand Up @@ -338,8 +335,8 @@ def create_df(spark, data_gen, left_length, right_length):
@ignore_order
@pytest.mark.parametrize('data_gen', [ShortGen(nullable=False)], ids=idfn)
def test_cogroup_apply_udf(data_gen):
def asof_join(l, r):
return pd.merge_asof(l, r, on='a', by='b')
def asof_join(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
return pd.merge_ordered(left, right)

def do_it(spark):
left, right = create_df(spark, data_gen, 500, 500)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ import ai.rapids.cudf
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.rapids.execution.GpuSubPartitionHashJoin
import org.apache.spark.sql.rapids.execution.python.shims.GpuPythonArrowOutput
import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -398,34 +400,68 @@ class CombiningIterator(
numOutputRows: GpuMetric,
numOutputBatches: GpuMetric) extends Iterator[ColumnarBatch] {

// For `hasNext` we are waiting on the queue to have something inserted into it
// instead of waiting for a result to be ready from Python. The reason for this
// is to let us know the target number of rows in the batch that we want when reading.
// It is a bit hacked up but it works. In the future when we support spilling we should
// store the number of rows separate from the batch. That way we can get the target batch
// size out without needing to grab the GpuSemaphore which we cannot do if we might block
// on a read operation.
override def hasNext: Boolean = inputBatchQueue.hasNext || pythonOutputIter.hasNext
// This is only for the input.
private var pendingInput: Option[SpillableColumnarBatch] = None
Option(TaskContext.get()).foreach(onTaskCompletion(_)(pendingInput.foreach(_.close())))

// The Python output should line up row for row so we only look at the Python output
// iterator and no need to check the `inputPending` who will be consumed when draining
// the Python output.
override def hasNext: Boolean = pythonOutputIter.hasNext

override def next(): ColumnarBatch = {
val numRows = inputBatchQueue.peekBatchSize
val numRows = inputBatchQueue.peekBatchNumRows()
// Updates the expected batch size for next read
pythonArrowReader.setMinReadTargetBatchSize(numRows)
pythonArrowReader.setMinReadTargetNumRows(numRows)
// Reads next batch from Python and combines it with the input batch by the left side.
withResource(pythonOutputIter.next()) { cbFromPython =>
assert(cbFromPython.numRows() == numRows)
withResource(inputBatchQueue.remove()) { origBatch =>
// Here may get a batch has a larger rows number than the current input batch.
assert(cbFromPython.numRows() >= numRows,
s"Expects >=$numRows rows but got ${cbFromPython.numRows()} from the Python worker")
withResource(concatInputBatch(cbFromPython.numRows())) { concated =>
numOutputBatches += 1
numOutputRows += numRows
combine(origBatch, cbFromPython)
GpuColumnVector.combineColumns(concated, cbFromPython)
}
}
}

private def combine(lBatch: ColumnarBatch, rBatch: ColumnarBatch): ColumnarBatch = {
val lColumns = GpuColumnVector.extractColumns(lBatch).map(_.incRefCount())
val rColumns = GpuColumnVector.extractColumns(rBatch).map(_.incRefCount())
new ColumnarBatch(lColumns ++ rColumns, lBatch.numRows())
private def concatInputBatch(targetNumRows: Int): ColumnarBatch = {
withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { buf =>
var curNumRows = pendingInput.map(_.numRows()).getOrElse(0)
pendingInput.foreach(buf.append(_))
pendingInput = None
while (curNumRows < targetNumRows) {
val scb = inputBatchQueue.remove()
if (scb != null) {
buf.append(scb)
curNumRows = curNumRows + scb.numRows()
}
}
assert(buf.nonEmpty, "The input queue is empty")

if (curNumRows > targetNumRows) {
// Need to split the last batch
val Array(first, second) = withRetryNoSplit(buf.remove(buf.size - 1)) { lastScb =>
val splitIdx = lastScb.numRows() - (curNumRows - targetNumRows)
withResource(lastScb.getColumnarBatch()) { lastCb =>
val batchTypes = GpuColumnVector.extractTypes(lastCb)
withResource(GpuColumnVector.from(lastCb)) { table =>
table.contiguousSplit(splitIdx).safeMap(
SpillableColumnarBatch(_, batchTypes, SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
}
}
}
buf.append(first)
pendingInput = Some(second)
}

val ret = GpuSubPartitionHashJoin.concatSpillBatchesAndClose(buf.toSeq)
// "ret" should be non empty because we checked the buf is not empty ahead.
withResource(ret.get) { concatedScb =>
concatedScb.getColumnarBatch()
}
} // end of withResource(mutable.ArrayBuffer)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
import com.nvidia.spark.rapids.shims.ShimUnaryExecNode

Expand Down Expand Up @@ -141,9 +140,7 @@ case class GpuAggregateInPandasExec(

// Start processing
child.executeColumnar().mapPartitionsInternal { inputIter =>
val queue: BatchQueue = new BatchQueue()
val context = TaskContext.get()
onTaskCompletion(queue.close())

if (isPythonOnGpuEnabled) {
GpuPythonHelper.injectGpuInfo(pyFuncs, isPythonOnGpuEnabled)
Expand All @@ -164,51 +161,56 @@ case class GpuAggregateInPandasExec(
}

// Second splits into separate group batches.
val miniAttrs = gpuGroupingExpressions ++ allInputs
val pyInputIter = BatchGroupedIterator(miniIter, miniAttrs.asInstanceOf[Seq[Attribute]],
groupingRefs.indices)
.map { groupedBatch =>
// Resolves the group key and the python input from a grouped batch. Then
// - Caches the key to be combined with the Python output later. And
// - Returns the python input to be sent to Python later.
withResource(groupedBatch) { grouped =>
// key batch.
// No `safeMap` because here does not increase the ref count.
// (`Seq.indices.map()` is NOT lazy, so it is safe to be used to slice the columns.)
val keyCudfColumns = groupingRefs.indices.map(
grouped.column(_).asInstanceOf[GpuColumnVector].getBase)
val keyBatch = if (keyCudfColumns.isEmpty) {
// No grouping columns, then the whole batch is a group. Returns the dedicated batch
// as the group key.
// This batch means there is only one empty row, just like the 'new UnsafeRow()'
// used in Spark. The row number setting to 1 is because Python returns only one row
// as the aggregate result for the whole batch, and 'CombiningIterator' requires the
// the same row number for both the key batch and the result batch to be combined.
new ColumnarBatch(Array(), 1)
} else {
// Uses `cudf.Table.gather` to pick the first row in each group as the group key.
// Doing this is because
// - The Python worker produces only one row as the aggregate result,
// - The key rows in a group are equal to each other.
//
// (Now this is done group by group, so the performance would not be good when
// there are too many small groups.)
withResource(new cudf.Table(keyCudfColumns: _*)) { table =>
withResource(cudf.ColumnVector.fromInts(0)) { gatherMap =>
withResource(table.gather(gatherMap)) { oneRowTable =>
GpuColumnVector.from(oneRowTable, groupingRefs.map(_.dataType).toArray)
}
}
val miniAttrs = (gpuGroupingExpressions ++ allInputs).asInstanceOf[Seq[Attribute]]
val keyConverter = (groupedBatch: ColumnarBatch) => {
// No `safeMap` because here does not increase the ref count.
// (`Seq.indices.map()` is NOT lazy, so it is safe to be used to slice the columns.)
val keyCudfColumns = groupingRefs.indices.map(
groupedBatch.column(_).asInstanceOf[GpuColumnVector].getBase)
if (keyCudfColumns.isEmpty) {
// No grouping columns, then the whole batch is a group. Returns the dedicated batch
// as the group key.
// This batch means there is only one empty row, just like the 'new UnsafeRow()'
// used in Spark. The row number setting to 1 is because Python returns only one row
// as the aggregate result for the whole batch, and 'CombiningIterator' requires the
// the same row number for both the key batch and the result batch to be combined.
new ColumnarBatch(Array(), 1)
} else {
// Uses `cudf.Table.gather` to pick the first row in each group as the group key.
// Doing this is because
// - The Python worker produces only one row as the aggregate result,
// - The key rows in a group are equal to each other.
//
// (Now this is done group by group, so the performance would not be good when
// there are too many small groups.)
withResource(new cudf.Table(keyCudfColumns: _*)) { table =>
withResource(cudf.ColumnVector.fromInts(0)) { gatherMap =>
withResource(table.gather(gatherMap)) { oneRowTable =>
GpuColumnVector.from(oneRowTable, groupingRefs.map(_.dataType).toArray)
}
}
queue.add(keyBatch)
}
}
}

// Python input batch
val pyInputColumns = pyInputRefs.indices.safeMap { idx =>
grouped.column(idx + groupingRefs.size).asInstanceOf[GpuColumnVector].incRefCount()
}
new ColumnarBatch(pyInputColumns.toArray, groupedBatch.numRows())
val batchProducer = new BatchProducer(
BatchGroupedIterator(miniIter, miniAttrs, groupingRefs.indices))
val queue = new BatchQueue(batchProducer, Some(keyConverter))
val pyInputIter = batchProducer.asIterator.map { case (batch, isForPeek) =>
val inputBatch = closeOnExcept(batch) { _ =>
val pyInputColumns = pyInputRefs.indices.safeMap { idx =>
batch.column(idx + groupingRefs.size).asInstanceOf[GpuColumnVector].incRefCount()
}
new ColumnarBatch(pyInputColumns.toArray, batch.numRows())
}
if (isForPeek) {
batch.close()
} else {
// When adding batch to the queue, queue will convert it to a key batch because this
// queue is constructed with the key converter.
queue.add(batch)
}
inputBatch
}

// Third, sends to Python to execute the aggregate and returns the result.
Expand All @@ -223,8 +225,7 @@ case class GpuAggregateInPandasExec(
pythonRunnerConf,
// The whole group data should be written in a single call, so here is unlimited
Int.MaxValue,
DataTypeUtilsShim.fromAttributes(pyOutAttributes),
() => queue.finish())
DataTypeUtilsShim.fromAttributes(pyOutAttributes))

val pyOutputIterator = pyRunner.compute(pyInputIter, context.partitionId(), context)

Expand Down
Loading

0 comments on commit c6496bd

Please sign in to comment.