diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala index 445aac99990..c309a3efa8a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala @@ -64,9 +64,6 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { ProjectExec(exprs, c2r) }.getOrElse(c2r) p.withNewChildren(Array(newChild)) - case exec: GpuShuffleExchangeExecBase => - addPostShuffleCoalesce( - exec.withNewChildren(Seq(optimizeGpuPlanTransitions(exec.child)))) case p => p.withNewChildren(p.children.map(optimizeGpuPlanTransitions)) } @@ -515,6 +512,24 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { p.withNewChildren(p.children.map(c => insertCoalesce(c, shouldDisable))) } + /** + * Inserts a shuffle coalesce after every shuffle to coalesce the serialized tables + * on the host before copying the data to the GPU. + * @note This should not be used in combination with the RAPIDS shuffle. + */ + private def insertShuffleCoalesce(plan: SparkPlan): SparkPlan = plan match { + case exec: GpuShuffleExchangeExecBase => + // always follow a GPU shuffle with a shuffle coalesce + if (GpuShuffleEnv.serializingOnGpu(rapidsConf)) { + GpuCoalesceBatches(exec.withNewChildren(exec.children.map(insertShuffleCoalesce)), + TargetSize(rapidsConf.gpuTargetBatchSizeBytes)) + } else { + GpuShuffleCoalesceExec(exec.withNewChildren(exec.children.map(insertShuffleCoalesce)), + rapidsConf.gpuTargetBatchSizeBytes) + } + case exec => exec.withNewChildren(plan.children.map(insertShuffleCoalesce)) + } + /** * Inserts a transition to be running on the CPU columnar */ @@ -787,6 +802,10 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { } updatedPlan = insertColumnarFromGpu(updatedPlan) updatedPlan = insertCoalesce(updatedPlan) + // only insert shuffle coalesces when using normal shuffle + if (!GpuShuffleEnv.useGPUShuffle(rapidsConf)) { + updatedPlan = insertShuffleCoalesce(updatedPlan) + } if (plan.conf.adaptiveExecutionEnabled) { updatedPlan = optimizeAdaptiveTransitions(updatedPlan, None) } else {