diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala index dc6b658abd6..2757d95cc7e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala @@ -64,13 +64,16 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { ProjectExec(exprs, c2r) }.getOrElse(c2r) p.withNewChildren(Array(newChild)) + case exec: GpuShuffleExchangeExecBase => + addPostShuffleCoalesce( + exec.withNewChildren(Seq(optimizeGpuPlanTransitions(exec.child)))) case p => p.withNewChildren(p.children.map(optimizeGpuPlanTransitions)) } /** Adds the appropriate coalesce after a shuffle depending on the type of shuffle configured */ private def addPostShuffleCoalesce(plan: SparkPlan): SparkPlan = { - if (GpuShuffleEnv.useGPUShuffle(rapidsConf)) { + if (GpuShuffleEnv.useGPUShuffle(rapidsConf) || GpuShuffleEnv.serializingOnGpu(rapidsConf)) { GpuCoalesceBatches(plan, TargetSize(rapidsConf.gpuTargetBatchSizeBytes)) } else { GpuShuffleCoalesceExec(plan, rapidsConf.gpuTargetBatchSizeBytes) @@ -511,19 +514,6 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { p.withNewChildren(p.children.map(c => insertCoalesce(c, shouldDisable))) } - /** - * Inserts a shuffle coalesce after every shuffle to coalesce the serialized tables - * on the host before copying the data to the GPU. - * @note This should not be used in combination with the RAPIDS shuffle. - */ - private def insertShuffleCoalesce(plan: SparkPlan): SparkPlan = plan match { - case exec: GpuShuffleExchangeExecBase => - // always follow a GPU shuffle with a shuffle coalesce - GpuShuffleCoalesceExec(exec.withNewChildren(exec.children.map(insertShuffleCoalesce)), - rapidsConf.gpuTargetBatchSizeBytes) - case exec => exec.withNewChildren(plan.children.map(insertShuffleCoalesce)) - } - /** * Inserts a transition to be running on the CPU columnar */ @@ -796,10 +786,6 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { } updatedPlan = insertColumnarFromGpu(updatedPlan) updatedPlan = insertCoalesce(updatedPlan) - // only insert shuffle coalesces when using normal shuffle - if (!GpuShuffleEnv.useGPUShuffle(rapidsConf)) { - updatedPlan = insertShuffleCoalesce(updatedPlan) - } if (plan.conf.adaptiveExecutionEnabled) { updatedPlan = optimizeAdaptiveTransitions(updatedPlan, None) } else {