Fix AnsiCastOpSuite failures with Spark 3.2 #3377

Merged: 13 commits, Sep 14, 2021
@@ -19,6 +19,8 @@ package com.nvidia.spark.rapids.shims.spark320
import java.net.URI
import java.nio.ByteBuffer

+import scala.collection.mutable.ListBuffer
+
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.v2._
import com.nvidia.spark.rapids.spark320.RapidsShuffleManager
@@ -35,11 +37,12 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, ElementAt, Expression, ExprId, GetArrayItem, GetMapValue, Lag, Lead, NamedExpression, NullOrdering, PlanExpression, PythonUDF, RegExpReplace, ScalaUDF, SortDirection, SortOrder}
import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.CommandResult
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.{CommandResultExec, FileSourceScanExec, PartitionedFileUtil, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
@@ -807,4 +810,29 @@ class Spark320Shims extends Spark32XShims {
  override def hasAliasQuoteFix: Boolean = true

  override def hasCastFloatTimestampUpcast: Boolean = true
+
+  override def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = {
+    def recurse(
+        plan: SparkPlan,
+        predicate: SparkPlan => Boolean,
+        accum: ListBuffer[SparkPlan]): Seq[SparkPlan] = {
+      if (predicate(plan)) {
+        accum += plan
+      }
+      plan match {
+        case a: AdaptiveSparkPlanExec => recurse(a.executedPlan, predicate, accum)
+        case qs: BroadcastQueryStageExec => recurse(qs.broadcast, predicate, accum)
+        case qs: ShuffleQueryStageExec => recurse(qs.shuffle, predicate, accum)
+        case c: CommandResultExec => recurse(c.commandPhysicalPlan, predicate, accum)
+        case other => other.children.flatMap(p => recurse(p, predicate, accum)).headOption
+      }
+      accum
+    }
+    recurse(plan, predicate, new ListBuffer[SparkPlan]())
+  }
+
+  override def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean = plan match {
+    case _: CommandResultExec => true
+    case _ => false
+  }
}
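The new findOperators hook exists because several Spark operators are leaf nodes that hide the plan they wrap: AdaptiveSparkPlanExec keeps the live plan in executedPlan, the query-stage execs wrap a broadcast or shuffle plan, and on Spark 3.2 CommandResultExec wraps commandPhysicalPlan. A plain TreeNode traversal never descends into any of them. A minimal sketch of the difference, assuming df is any DataFrame whose query ran with AQE enabled:

import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.execution.ProjectExec

val plan = df.queryExecution.executedPlan
// TreeNode.collect stops at leaf nodes, so with AQE on it typically sees only
// the AdaptiveSparkPlanExec wrapper and finds no projects at all.
val viaCollect = plan.collect { case p: ProjectExec => p }
// The shim hook recurses through the executedPlan/broadcast/shuffle/command
// wrappers and returns the operators that actually ran.
val viaShim = ShimLoader.getSparkShims.findOperators(plan, _.isInstanceOf[ProjectExec])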
@@ -16,6 +16,8 @@

package com.nvidia.spark.rapids.shims.v2

+import scala.collection.mutable.ListBuffer
+
import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
import com.nvidia.spark.rapids.GpuOverrides.exec
import org.apache.hadoop.fs.FileStatus
@@ -25,7 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils}
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, CustomShuffleReaderExec, QueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec}
@@ -97,6 +99,27 @@ trait Spark30XShims extends SparkShims {
          TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all),
      (exec, conf, p, r) => new GpuCustomShuffleReaderMeta(exec, conf, p, r))

+  override def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = {
+    def recurse(
+        plan: SparkPlan,
+        predicate: SparkPlan => Boolean,
+        accum: ListBuffer[SparkPlan]): Seq[SparkPlan] = {
+      if (predicate(plan)) {
+        accum += plan
+      }
+      plan match {
+        case a: AdaptiveSparkPlanExec => recurse(a.executedPlan, predicate, accum)
+        case qs: BroadcastQueryStageExec => recurse(qs.broadcast, predicate, accum)
+        case qs: ShuffleQueryStageExec => recurse(qs.shuffle, predicate, accum)
+        case other => other.children.flatMap(p => recurse(p, predicate, accum)).headOption
+      }
+      accum
+    }
+    recurse(plan, predicate, new ListBuffer[SparkPlan]())
+  }
+
+  override def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean = false
+
  override def leafNodeDefaultParallelism(ss: SparkSession): Int = {
    ss.sparkContext.defaultParallelism
  }
@@ -460,6 +460,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
        }
      case _: ExecutedCommandExec => () // Ignored
      case _: RDDScanExec => () // Ignored
+      case p if ShimLoader.getSparkShims.skipAssertIsOnTheGpu(p) => () // Ignored
      case _ =>
        if (!plan.supportsColumnar &&
            // There are some python execs that are not columnar because of a little
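The escape hatch matters on Spark 3.2 because a SQL command is planned with a CPU-only CommandResultExec at the root holding the command's result rows, while the command itself runs below it, roughly CommandResultExec over GpuDataWritingCommandExec for an INSERT. Without the new case the assertion would flag that wrapper as an operator left on the CPU. A simplified sketch of the resulting check, where shouldFlag is a stand-in for the logic inside assertIsOnTheGpu rather than the actual method:

import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.execution.SparkPlan

// An operator is only reported if it is not columnar AND the active shim
// does not explicitly opt it out of the GPU-residency check.
def shouldFlag(plan: SparkPlan): Boolean =
  !plan.supportsColumnar && !ShimLoader.getSparkShims.skipAssertIsOnTheGpu(plan)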
12 changes: 12 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -253,6 +253,18 @@ trait SparkShims {

  def aqeShuffleReaderExec: ExecRule[_ <: SparkPlan]

+  /**
+   * Walk the plan recursively and return a list of operators that match the predicate.
+   */
+  def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan]
+
+  /**
+   * Our tests, by default, check that all operators run on the GPU, but some
+   * operators are never translated to GPU plans, so we need a way to bypass
+   * the check for those.
+   */
+  def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean
+
  def leafNodeDefaultParallelism(ss: SparkSession): Int
}

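A hedged usage sketch for the two new hooks, with df standing in for any test DataFrame; neither line is part of the patch itself:

import com.nvidia.spark.rapids.{GpuProjectExec, ShimLoader}

val shims = ShimLoader.getSparkShims
val plan = df.queryExecution.executedPlan

// Every GPU project anywhere in the plan, including inside AQE wrappers.
val gpuProjects = shims.findOperators(plan, _.isInstanceOf[GpuProjectExec])

// Operators this Spark version never puts on the GPU: CommandResultExec on
// the 3.2 shim above; always empty on the 30x shims, which return false.
val ignorable = shims.findOperators(plan, shims.skipAssertIsOnTheGpu)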
37 changes: 21 additions & 16 deletions tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
@@ -713,23 +713,28 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {


  private def assertContainsAnsiCast(df: DataFrame, expected: Int = 1): DataFrame = {
-    var count = 0
-    df.queryExecution.sparkPlan.foreach {
-      case p: ProjectExec => count += p.projectList.count {
-        // ansiEnabled is protected so we rely on CastBase.toString
-        case c: CastBase => c.toString().startsWith("ansi_cast")
-        case Alias(c: CastBase, _) => c.toString().startsWith("ansi_cast")
-        case _ => false
-      }
-      case p: GpuProjectExec => count += p.projectList.count {
-        case c: GpuCast => c.ansiMode
-        case _ => false
-      }
-      case _ =>
-    }
+    val projections = ShimLoader.getSparkShims.findOperators(df.queryExecution.executedPlan, {
+      case _: ProjectExec | _: GpuProjectExec => true
+      case _ => false
+    })
+    val count = projections.map {
+      case p: ProjectExec => p.projectList.count {
+        // ansiEnabled is protected so we rely on CastBase.toString
+        case c: CastBase => c.toString().startsWith("ansi_cast")
+        case Alias(c: CastBase, _) => c.toString().startsWith("ansi_cast")
+        case _ => false
+      }
+      case p: GpuProjectExec => p.projectList.count {
+        case c: GpuCast => c.ansiMode
+        case GpuAlias(c: GpuCast, _) => c.ansiMode
+        case _ => false
+      }
+      case _ => 0
+    }.sum

    if (count != expected) {
-      throw new IllegalStateException("Plan does not contain the expected number of " +
-        "ansi_cast expressions")
+      throw new IllegalStateException(s"Expected $expected " +
+        s"ansi_cast expressions, found $count")
    }
    df
  }
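Note that the rewrite also switches from sparkPlan to executedPlan. With AQE enabled, which is the default in Spark 3.2, the GPU overrides may only be visible on the executed side of the query. A rough sketch, with df illustrative:

val beforeAqe = df.queryExecution.sparkPlan   // pre-AQE plan; may still show CPU ProjectExec
val executed = df.queryExecution.executedPlan // AdaptiveSparkPlanExec wrapper; the
                                              // GpuProjectExec inside is only reachable
                                              // through the shim's findOperators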
@@ -147,12 +147,13 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite {
    try {
      spark.sql("CREATE TABLE t(id STRING) USING PARQUET")
      val df = spark.sql("INSERT INTO TABLE t SELECT 'abc'")
-      val insert = df.queryExecution.executedPlan.find(_.isInstanceOf[GpuDataWritingCommandExec])
-      assert(insert.isDefined)
-      assert(insert.get.metrics.contains(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME))
-      assert(insert.get.metrics.contains(BasicColumnarWriteJobStatsTracker.TASK_COMMIT_TIME))
-      assert(insert.get.metrics(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME).value > 0)
-      assert(insert.get.metrics(BasicColumnarWriteJobStatsTracker.TASK_COMMIT_TIME).value > 0)
+      val insert = ShimLoader.getSparkShims.findOperators(df.queryExecution.executedPlan,
+        _.isInstanceOf[GpuDataWritingCommandExec]).head
+        .asInstanceOf[GpuDataWritingCommandExec]
+      assert(insert.metrics.contains(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME))
+      assert(insert.metrics.contains(BasicColumnarWriteJobStatsTracker.TASK_COMMIT_TIME))
+      assert(insert.metrics(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME).value > 0)
+      assert(insert.metrics(BasicColumnarWriteJobStatsTracker.TASK_COMMIT_TIME).value > 0)
    } finally {
      spark.sql("DROP TABLE IF EXISTS tempmetricstable")
    }
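The bare .head above assumes the write exec is always found. Since findOperators returns a Seq, a defensive variant (a sketch, not part of the patch) fails with a clearer message when the lookup comes up empty:

val writes = ShimLoader.getSparkShims.findOperators(
  df.queryExecution.executedPlan, _.isInstanceOf[GpuDataWritingCommandExec])
assert(writes.nonEmpty, "plan contains no GpuDataWritingCommandExec")
val insert = writes.head.asInstanceOf[GpuDataWritingCommandExec]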
8 changes: 1 addition & 7 deletions tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala
@@ -64,13 +64,7 @@ object TestUtils extends Assertions with Arm {

  /** Recursively check if the predicate matches in the given plan */
  def findOperator(plan: SparkPlan, predicate: SparkPlan => Boolean): Option[SparkPlan] = {
-    plan match {
-      case _ if predicate(plan) => Some(plan)
-      case a: AdaptiveSparkPlanExec => findOperator(a.executedPlan, predicate)
-      case qs: BroadcastQueryStageExec => findOperator(qs.broadcast, predicate)
-      case qs: ShuffleQueryStageExec => findOperator(qs.shuffle, predicate)
-      case other => other.children.flatMap(p => findOperator(p, predicate)).headOption
-    }
+    ShimLoader.getSparkShims.findOperators(plan, predicate).headOption
  }

/** Return final executed plan */
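Existing call sites are unaffected by this change: the helper keeps its Option-returning shape while the shim supplies the version-specific traversal. A small sketch, with df illustrative:

val firstColumnar = TestUtils.findOperator(
  df.queryExecution.executedPlan, _.supportsColumnar)
assert(firstColumnar.isDefined)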