diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala
index 1861a9f2515..72ed0e79504 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -41,14 +41,14 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ThreadUtils
 
-class GpuSubqueryBroadcastMeta(
+abstract class GpuSubqueryBroadcastMetaBase(
     s: SubqueryBroadcastExec,
     conf: RapidsConf,
     p: Option[RapidsMeta[_, _, _]],
     r: DataFromReplacementRule)
   extends SparkPlanMeta[SubqueryBroadcastExec](s, conf, p, r) {
 
-  private var broadcastBuilder: () => SparkPlan = _
+  protected var broadcastBuilder: () => SparkPlan = _
 
   override val childExprs: Seq[BaseExprMeta[_]] = Nil
 
@@ -140,13 +140,8 @@ class GpuSubqueryBroadcastMeta(
    */
   override def convertToCpu(): SparkPlan = s
 
-  override def convertToGpu(): GpuExec = {
-    GpuSubqueryBroadcastExec(s.name, s.index, s.buildKeys, broadcastBuilder())(
-      getBroadcastModeKeyExprs)
-  }
-
   /** Extract the broadcast mode key expressions if there are any. */
-  private def getBroadcastModeKeyExprs: Option[Seq[Expression]] = {
+  protected def getBroadcastModeKeyExprs: Option[Seq[Expression]] = {
     val broadcastMode = s.child match {
       case b: BroadcastExchangeExec =>
         b.mode
@@ -170,7 +165,7 @@ class GpuSubqueryBroadcastMeta(
 
 case class GpuSubqueryBroadcastExec(
     name: String,
-    index: Int,
+    indices: Seq[Int],
     buildKeys: Seq[Expression],
     child: SparkPlan)(modeKeys: Option[Seq[Expression]])
   extends ShimBaseSubqueryExec with GpuExec with ShimUnaryExecNode {
@@ -182,16 +177,18 @@ case class GpuSubqueryBroadcastExec(
   // correctly report the output length, so that `InSubqueryExec` can know it's the single-column
   // execution mode, not multi-column.
   override def output: Seq[Attribute] = {
-    val key = buildKeys(index)
-    val name = key match {
-      case n: NamedExpression =>
-        n.name
-      case cast: Cast if cast.child.isInstanceOf[NamedExpression] =>
-        cast.child.asInstanceOf[NamedExpression].name
-      case _ =>
-        "key"
+    indices.map { index =>
+      val key = buildKeys(index)
+      val name = key match {
+        case n: NamedExpression =>
+          n.name
+        case cast: Cast if cast.child.isInstanceOf[NamedExpression] =>
+          cast.child.asInstanceOf[NamedExpression].name
+        case _ =>
+          "key"
+      }
+      AttributeReference(name, key.dataType, key.nullable)()
     }
-    Seq(AttributeReference(name, key.dataType, key.nullable)())
   }
 
   override lazy val additionalMetrics: Map[String, GpuMetric] = Map(
@@ -200,7 +197,7 @@ case class GpuSubqueryBroadcastExec(
 
   override def doCanonicalize(): SparkPlan = {
     val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output))
-    GpuSubqueryBroadcastExec("dpp", index, keys, child.canonicalized)(modeKeys)
+    GpuSubqueryBroadcastExec("dpp", indices, keys, child.canonicalized)(modeKeys)
   }
 
   @transient
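Annotation: the `output` change above derives one AttributeReference per entry in `indices` rather than a single attribute for one `index`. A minimal standalone sketch of the naming rule, illustrative only; KeyExpr, Named, and CastOf are hypothetical stand-ins for Spark's Expression, NamedExpression, and Cast:

    sealed trait KeyExpr
    case class Named(name: String) extends KeyExpr
    case class CastOf(child: KeyExpr) extends KeyExpr
    case object Anonymous extends KeyExpr

    // Mirrors the match in output: a named key keeps its name, a cast of a
    // named key unwraps to the child's name, anything else falls back to "key".
    def outputNames(buildKeys: Seq[KeyExpr], indices: Seq[Int]): Seq[String] =
      indices.map { index =>
        buildKeys(index) match {
          case Named(n)         => n
          case CastOf(Named(n)) => n
          case _                => "key"
        }
      }

    // outputNames(Seq(Named("a"), CastOf(Named("b")), Anonymous), Seq(1, 2))
    //   == Seq("b", "key")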
@@ -235,28 +232,30 @@ case class GpuSubqueryBroadcastExec(
     // are being extracted. The CPU already has the key projections applied in the broadcast
     // data and thus does not have similar logic here.
     val broadcastModeProject = modeKeys.map { keyExprs =>
-      val keyExpr = if (GpuHashJoin.canRewriteAsLongType(buildKeys)) {
         // in this case, there is only 1 key expression since it's a packed version that encompasses
         // multiple integral values into a single long using bit logic. In CPU Spark, the broadcast
         // would create a LongHashedRelation instead of a standard HashedRelation.
+      val exprs = if (GpuHashJoin.canRewriteAsLongType(buildKeys)) {
-        keyExprs.head
+        indices.map { _ => keyExprs.head }
       } else {
-        keyExprs(index)
+        indices.map { idx => keyExprs(idx) }
       }
-      UnsafeProjection.create(keyExpr)
+      UnsafeProjection.create(exprs)
     }
 
-    // Use the single output of the broadcast mode projection if it exists
-    val rowProjectIndex = if (broadcastModeProject.isDefined) 0 else index
-    val rowExpr = if (GpuHashJoin.canRewriteAsLongType(buildKeys)) {
+    val rowExprs = if (GpuHashJoin.canRewriteAsLongType(buildKeys)) {
       // Since this is the expected output for a LongHashedRelation, we can extract the key from the
-      // long packed key using bit logic, using this method available in HashJoin to give us the
-      // correct key expression.
-      HashJoin.extractKeyExprAt(buildKeys, index)
+      // long packed key using bit logic, using this method available in HashJoin to give us the
+      // correct key expression.
+      indices.map { idx => HashJoin.extractKeyExprAt(buildKeys, idx) }
     } else {
-      BoundReference(rowProjectIndex, buildKeys(index).dataType, buildKeys(index).nullable)
+      indices.map { idx =>
+        // Use the single output of the broadcast mode projection if it exists
+        val rowProjectIndex = if (broadcastModeProject.isDefined) 0 else idx
+        BoundReference(rowProjectIndex, buildKeys(idx).dataType, buildKeys(idx).nullable)
+      }
     }
-    val rowProject = UnsafeProjection.create(rowExpr)
+    val rowProject = UnsafeProjection.create(rowExprs)
 
     // Deserializes the batch on the host. Then, transforms it to rows and performs row-wise
     // projection. We should NOT run any device operation on the driver node.
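Annotation: the canRewriteAsLongType branch above relies on Spark packing several small integral join keys into one Long, so a single packed key expression stands in for all of them, and HashJoin.extractKeyExprAt slices each original key back out with bit logic. A sketch of that idea with a hypothetical fixed 32-bits-per-key layout; Spark's actual packing in HashJoin.rewriteKeyExpr depends on the key widths:

    // Pack 32-bit keys into one Long, earlier keys in the higher bits
    // (illustrative layout only, at most two keys fit here).
    def pack(keys: Seq[Int]): Long =
      keys.foldLeft(0L)((acc, k) => (acc << 32) | (k & 0xffffffffL))

    // Slice the key at `index` back out of the packed Long.
    def extractKeyAt(packed: Long, index: Int, numKeys: Int): Int =
      ((packed >>> (32 * (numKeys - 1 - index))) & 0xffffffffL).toInt

    // extractKeyAt(pack(Seq(7, 42)), 0, 2) == 7
    // extractKeyAt(pack(Seq(7, 42)), 1, 2) == 42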
diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastMeta.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastMeta.scala
new file mode 100644
index 00000000000..9bcfa33ab87
--- /dev/null
+++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastMeta.scala
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*** spark-rapids-shim-json-lines
+{"spark": "311"}
+{"spark": "312"}
+{"spark": "313"}
+{"spark": "320"}
+{"spark": "321"}
+{"spark": "321cdh"}
+{"spark": "322"}
+{"spark": "323"}
+{"spark": "324"}
+{"spark": "330"}
+{"spark": "330cdh"}
+{"spark": "331"}
+{"spark": "332"}
+{"spark": "332cdh"}
+{"spark": "333"}
+{"spark": "334"}
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "350"}
+{"spark": "351"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.execution
+
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuExec, RapidsConf, RapidsMeta}
+
+import org.apache.spark.sql.execution.SubqueryBroadcastExec
+
+class GpuSubqueryBroadcastMeta(
+    s: SubqueryBroadcastExec,
+    conf: RapidsConf,
+    p: Option[RapidsMeta[_, _, _]],
+    r: DataFromReplacementRule) extends
+    GpuSubqueryBroadcastMetaBase(s, conf, p, r) {
+  override def convertToGpu(): GpuExec = {
+    GpuSubqueryBroadcastExec(s.name, Seq(s.index), s.buildKeys, broadcastBuilder())(
+      getBroadcastModeKeyExprs)
+  }
+}
diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastMeta.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastMeta.scala
new file mode 100644
index 00000000000..c16564f523e
--- /dev/null
+++ b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastMeta.scala
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*** spark-rapids-shim-json-lines
+{"spark": "400"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.execution
+
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuExec, RapidsConf, RapidsMeta}
+
+import org.apache.spark.sql.execution.SubqueryBroadcastExec
+
+class GpuSubqueryBroadcastMeta(
+    s: SubqueryBroadcastExec,
+    conf: RapidsConf,
+    p: Option[RapidsMeta[_, _, _]],
+    r: DataFromReplacementRule) extends
+    GpuSubqueryBroadcastMetaBase(s, conf, p, r) {
+  override def convertToGpu(): GpuExec = {
+    GpuSubqueryBroadcastExec(s.name, s.indices, s.buildKeys, broadcastBuilder())(
+      getBroadcastModeKeyExprs)
+  }
+}
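Annotation: both shims above funnel into the same Seq[Int]-based GpuSubqueryBroadcastExec. Pre-4.0 Spark exposes a single index on SubqueryBroadcastExec, which the spark311 shim wraps as Seq(s.index), while Spark 4.0's node already carries indices. A sketch of that adaptation with hypothetical stand-in types; LegacyNode and Spark400Node are not real plugin classes:

    case class LegacyNode(index: Int)           // stands in for pre-4.0 SubqueryBroadcastExec
    case class Spark400Node(indices: Seq[Int])  // stands in for the Spark 4.0 node

    def indicesFor(plan: Any): Seq[Int] = plan match {
      case LegacyNode(i)    => Seq(i) // the spark311 shim's Seq(s.index)
      case Spark400Node(is) => is     // the spark400 shim passes indices through
    }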
diff --git a/tests/src/test/spark321/scala/com/nvidia/spark/rapids/DynamicPruningSuite.scala b/tests/src/test/spark321/scala/com/nvidia/spark/rapids/DynamicPruningSuite.scala
index 2d4156d1b3b..722e5bb215b 100644
--- a/tests/src/test/spark321/scala/com/nvidia/spark/rapids/DynamicPruningSuite.scala
+++ b/tests/src/test/spark321/scala/com/nvidia/spark/rapids/DynamicPruningSuite.scala
@@ -66,7 +66,7 @@ class DynamicPruningSuite
       // NOTE: We remove the AdaptiveSparkPlanExec since we can't re-run the new plan
       // under AQE because that fundamentally requires some rewrite and stage
       // ordering which we can't do for this test.
-      case GpuSubqueryBroadcastExec(name, index, buildKeys, child) =>
+      case GpuSubqueryBroadcastExec(name, Seq(index), buildKeys, child) =>
        val newChild = child match {
          case a @ AdaptiveSparkPlanExec(_, _, _, _, _) =>
            (new GpuTransitionOverrides()).apply(ColumnarToRowExec(a.executedPlan))
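Annotation: the test pattern above destructures the now-Seq-valued field, and the Seq(index) pattern only matches single-element sequences, which is the shape the pre-4.0 shims produce. Note also that a case class with a curried constructor, like GpuSubqueryBroadcastExec(...)(...), generates an unapply over its first parameter list only, so modeKeys never appears in the match. A small sketch of both behaviors; Node is hypothetical:

    case class Node(name: String, indices: Seq[Int])(val extra: Option[String])

    Node("dpp", Seq(5))(None) match {
      // Seq(i) matches only a single-element sequence; the second parameter
      // list is not part of the generated unapply, so it is never matched on.
      case Node(n, Seq(i)) => println(s"$n -> single index $i")
      case _               => println("multi-index")
    }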