Fix a Pandas UDF slowness issue (#11395)
Close #10770

In CombiningIterator, the call to hasNext on pythonOutputIter may trigger a read before the target row count has been set. The default target is Int.MaxValue, so when a partition holds enough data the GpuArrowReader tries to read one very large batch, which causes excessive data copying in DirectByteBufferOutputStream on the writer side and shows up as slowness.

This PR changes the default read row count to arrowMaxRecordsPerBatch to align with Spark's Arrow batching behavior, and also sets the target read row count in the hasNext function.
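
For illustration, here is a minimal, self-contained sketch of the pattern the fix relies on. It is not the plugin code: FakeReader, CombiningLikeIterator, and the hard-coded 1000-row target are hypothetical stand-ins for GpuArrowReader, CombiningIterator, and inputBatchQueue.peekBatchNumRows(). The point is that the read target must be initialized in both hasNext and next(), because hasNext alone can trigger a read.

```scala
// Sketch only: lazily set the reader's target row count before any call that may read.
object LazyReadTargetSketch {
  // Stand-in for GpuArrowReader: without an explicit target it would read "as much as possible".
  final class FakeReader(var targetRows: Int = Int.MaxValue)

  final class CombiningLikeIterator(reader: FakeReader, pythonOutput: Iterator[Int])
      extends Iterator[Int] {
    private var nextReadRowsNum: Option[Int] = None

    // Set the target once per upcoming read, mirroring initRowsNumForNextRead() in the PR.
    private def initRowsNumForNextRead(): Unit = if (nextReadRowsNum.isEmpty) {
      val numRows = 1000 // would come from inputBatchQueue.peekBatchNumRows()
      reader.targetRows = numRows
      nextReadRowsNum = Some(numRows)
    }

    override def hasNext: Boolean = {
      initRowsNumForNextRead() // hasNext may trigger a read, so set the target first
      pythonOutput.hasNext
    }

    override def next(): Int = {
      initRowsNumForNextRead()
      nextReadRowsNum = None // consumed by this read; the next read sets it again
      pythonOutput.next()
    }
  }

  def main(args: Array[String]): Unit = {
    val reader = new FakeReader()
    val it = new CombiningLikeIterator(reader, Iterator(1, 2, 3))
    while (it.hasNext) println(s"row=${it.next()}, targetRows=${reader.targetRows}")
  }
}
```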

---------

Signed-off-by: Firestarman <[email protected]>
firestarman authored Aug 30, 2024
1 parent dbd92d2 commit db1d580
Showing 2 changed files with 28 additions and 12 deletions.
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -403,23 +403,37 @@ class CombiningIterator(
   private var pendingInput: Option[SpillableColumnarBatch] = None
   Option(TaskContext.get()).foreach(onTaskCompletion(_)(pendingInput.foreach(_.close())))
 
-  // The Python output should line up row for row so we only look at the Python output
-  // iterator and no need to check the `inputPending` who will be consumed when draining
-  // the Python output.
-  override def hasNext: Boolean = pythonOutputIter.hasNext
+  private var nextReadRowsNum: Option[Int] = None
 
-  override def next(): ColumnarBatch = {
+  private def initRowsNumForNextRead(): Unit = if (nextReadRowsNum.isEmpty){
     val numRows = inputBatchQueue.peekBatchNumRows()
     // Updates the expected batch size for next read
     pythonArrowReader.setMinReadTargetNumRows(numRows)
+    nextReadRowsNum = Some(numRows)
+  }
+
+  // The Python output should line up row for row so we only look at the Python output
+  // iterator and no need to check the `pendingInput` who will be consumed when draining
+  // the Python output.
+  override def hasNext: Boolean = {
+    // pythonOutputIter.hasNext may trigger a read, so init the read rows number here.
+    initRowsNumForNextRead()
+    pythonOutputIter.hasNext
+  }
+
+  override def next(): ColumnarBatch = {
+    initRowsNumForNextRead()
     // Reads next batch from Python and combines it with the input batch by the left side.
     withResource(pythonOutputIter.next()) { cbFromPython =>
+      // nextReadRowsNum should be set here after a read.
+      val nextRowsNum = nextReadRowsNum.get
+      nextReadRowsNum = None
       // Here may get a batch has a larger rows number than the current input batch.
-      assert(cbFromPython.numRows() >= numRows,
-        s"Expects >=$numRows rows but got ${cbFromPython.numRows()} from the Python worker")
+      assert(cbFromPython.numRows() >= nextRowsNum,
+        s"Expects >=$nextRowsNum rows but got ${cbFromPython.numRows()} from the Python worker")
       withResource(concatInputBatch(cbFromPython.numRows())) { concated =>
         numOutputBatches += 1
-        numOutputRows += numRows
+        numOutputRows += concated.numRows()
         GpuColumnVector.combineColumns(concated, cbFromPython)
       }
     }
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.GpuSemaphore
 
 import org.apache.spark.TaskContext
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch


@@ -62,10 +63,11 @@ trait GpuArrowOutput {
   protected def toBatch(table: Table): ColumnarBatch
 
   /**
-   * Default to `Int.MaxValue` to try to read as many as possible.
+   * Default to minimum one between "arrowMaxRecordsPerBatch" and 10000.
    * Change it by calling `setMinReadTargetNumRows` before a reading.
    */
-  private var minReadTargetNumRows: Int = Int.MaxValue
+  private var minReadTargetNumRows: Int = math.min(
+    SQLConf.get.arrowMaxRecordsPerBatch, 10000)

def newGpuArrowReader: GpuArrowReader = new GpuArrowReader

