Skip to content

Commit

Permalink
[GLUTEN-8216][CH] Fix OOM when cartesian product with empty data (#8219)
Browse files Browse the repository at this point in the history
* [GLUTEN-8216][CH] Fix OOM when cartesian product with empty data

* [GLUTEN-8216][CH] Fix CI
  • Loading branch information
lwz9103 authored Dec 15, 2024
1 parent b1211a8 commit 498efc7
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,28 +210,31 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V](
arrayNode
}

private def buildRangeBoundsJson(jsonMapper: ObjectMapper, arrayNode: ArrayNode): Unit = {
private def buildRangeBoundsJson(jsonMapper: ObjectMapper, arrayNode: ArrayNode): Int = {
val bounds = getRangeBounds
bounds.foreach {
bound =>
val row = bound.asInstanceOf[UnsafeRow]
arrayNode.add(buildRangeBoundJson(row, ordering, jsonMapper))
}
bounds.length
}

// Make a json structure that can be passed to native engine
def getRangeBoundsJsonString: String = {
def getRangeBoundsJsonString: RangeBoundsInfo = {
val context = new SubstraitContext()
val mapper = new ObjectMapper
val rootNode = mapper.createObjectNode
val orderingArray = rootNode.putArray("ordering")
buildOrderingJson(context, ordering, inputAttributes, mapper, orderingArray)
val boundArray = rootNode.putArray("range_bounds")
buildRangeBoundsJson(mapper, boundArray)
mapper.writeValueAsString(rootNode)
val boundLength = buildRangeBoundsJson(mapper, boundArray)
RangeBoundsInfo(mapper.writeValueAsString(rootNode), boundLength)
}
}

case class RangeBoundsInfo(json: String, boundsSize: Int)

object RangePartitionerBoundsGenerator {
def supportedFieldType(dataType: DataType): Boolean = {
dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ object CHExecUtil extends Logging {
rddForSampling,
sortingExpressions,
childOutputAttributes)
val orderingAndRangeBounds = generator.getRangeBoundsJsonString
val rangeBoundsInfo = generator.getRangeBoundsJsonString
val attributePos = if (projectOutputAttributes != null) {
projectOutputAttributes.map(
attr =>
Expand All @@ -324,10 +324,11 @@ object CHExecUtil extends Logging {
}
new NativePartitioning(
GlutenShuffleUtils.RangePartitioningShortName,
numPartitions,
rangeBoundsInfo.boundsSize + 1,
Array.empty[Byte],
orderingAndRangeBounds.getBytes(),
attributePos.mkString(",").getBytes)
rangeBoundsInfo.json.getBytes,
attributePos.mkString(",").getBytes
)
case p =>
throw new IllegalStateException(s"Unknow partition type: ${p.getClass.toString}")
}
Expand Down Expand Up @@ -368,7 +369,7 @@ object CHExecUtil extends Logging {
val dependency =
new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
rddWithPartitionKey,
new PartitionIdPassthrough(newPartitioning.numPartitions),
new PartitionIdPassthrough(nativePartitioning.getNumPartitions),
serializer,
shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
nativePartitioning = nativePartitioning,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class GlutenClickHouseColumnarShuffleAQESuite
val coalescedPartitionSpec0 = colCustomShuffleReaderExecs.head.partitionSpecs.head
.asInstanceOf[CoalescedPartitionSpec]
assert(coalescedPartitionSpec0.startReducerIndex == 0)
assert(coalescedPartitionSpec0.endReducerIndex == 5)
assert(coalescedPartitionSpec0.endReducerIndex == 4)
val coalescedPartitionSpec1 = colCustomShuffleReaderExecs(1).partitionSpecs.head
.asInstanceOf[CoalescedPartitionSpec]
assert(coalescedPartitionSpec1.startReducerIndex == 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ import org.apache.gluten.backendsapi.clickhouse.CHConf
import org.apache.gluten.utils.UTSystemParameters

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig

import java.util.concurrent.atomic.AtomicInteger

class GlutenClickHouseJoinSuite extends GlutenClickHouseWholeStageTransformerSuite {

protected val tablesPath: String = basePath + "/tpch-data"
Expand Down Expand Up @@ -141,4 +144,38 @@ class GlutenClickHouseJoinSuite extends GlutenClickHouseWholeStageTransformerSui
sql("drop table if exists tj2")
}

test("GLUTEN-8216 Fix OOM when cartesian product with empty data") {
// prepare
spark.sql("create table test_join(a int, b int, c int) using parquet")
var overrideConfs = Map(
"spark.sql.autoBroadcastJoinThreshold" -> "-1",
"spark.sql.shuffle.partitions" -> "1"
)
if (isSparkVersionGE("3.5")) {
// Range partitions will not be reduced if EliminateSorts is enabled in spark35.
overrideConfs += "spark.sql.optimizer.excludedRules" ->
"org.apache.spark.sql.catalyst.optimizer.EliminateSorts"
}

withSQLConf(overrideConfs.toSeq: _*) {
val taskCount = new AtomicInteger(0)
val taskListener = new SparkListener {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskCount.incrementAndGet()
logDebug(s"Task ${taskEnd.taskInfo.id} finished. Total tasks completed: $taskCount")
}
}
spark.sparkContext.addSparkListener(taskListener)
spark
.sql(
"select * from " +
"(select a from test_join group by a order by a), " +
"(select b from test_join group by b order by b)" +
" limit 10000"
)
.collect()
assert(taskCount.get() < 500)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class GlutenClickHouseRSSColumnarShuffleAQESuite
.partitionSpecs(0)
.asInstanceOf[CoalescedPartitionSpec]
assert(coalescedPartitionSpec0.startReducerIndex == 0)
assert(coalescedPartitionSpec0.endReducerIndex == 5)
assert(coalescedPartitionSpec0.endReducerIndex == 4)
val coalescedPartitionSpec1 = colCustomShuffleReaderExecs(1)
.partitionSpecs(0)
.asInstanceOf[CoalescedPartitionSpec]
Expand Down

0 comments on commit 498efc7

Please sign in to comment.