Fix AnsiCastOpSuite failures with Spark 3.2 (#3377)
andygrove authored Sep 14, 2021
1 parent f5e3d0d commit 62854cc
Showing 8 changed files with 120 additions and 33 deletions.
@@ -19,6 +19,8 @@ package com.nvidia.spark.rapids.shims.spark320
import java.net.URI
import java.nio.ByteBuffer

import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.v2._
import com.nvidia.spark.rapids.spark320.RapidsShuffleManager
@@ -35,11 +37,12 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, ElementAt, Expression, ExprId, GetArrayItem, GetMapValue, Lag, Lead, NamedExpression, NullOrdering, PlanExpression, PythonUDF, RegExpReplace, ScalaUDF, SortDirection, SortOrder}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.CommandResult
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.{CommandResultExec, FileSourceScanExec, PartitionedFileUtil, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
@@ -803,4 +806,29 @@ class Spark320Shims extends Spark32XShims {
override def hasAliasQuoteFix: Boolean = true

override def hasCastFloatTimestampUpcast: Boolean = true

override def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = {
def recurse(
plan: SparkPlan,
predicate: SparkPlan => Boolean,
accum: ListBuffer[SparkPlan]): Seq[SparkPlan] = {
if (predicate(plan)) {
accum += plan
}
plan match {
case a: AdaptiveSparkPlanExec => recurse(a.executedPlan, predicate, accum)
case qs: BroadcastQueryStageExec => recurse(qs.broadcast, predicate, accum)
case qs: ShuffleQueryStageExec => recurse(qs.shuffle, predicate, accum)
case c: CommandResultExec => recurse(c.commandPhysicalPlan, predicate, accum)
case other => other.children.flatMap(p => recurse(p, predicate, accum)).headOption
}
accum
}
recurse(plan, predicate, new ListBuffer[SparkPlan]())
}

override def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean = plan match {
case _: CommandResultExec => true
case _ => false
}
}
@@ -16,6 +16,8 @@

package com.nvidia.spark.rapids.shims.v2

import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
import com.nvidia.spark.rapids.GpuOverrides.exec
import org.apache.hadoop.fs.FileStatus
@@ -25,7 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, CustomShuffleReaderExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec}
@@ -92,5 +94,26 @@ trait Spark30XShims extends SparkShims {
ss.sparkContext.defaultParallelism
}

override def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = {
def recurse(
plan: SparkPlan,
predicate: SparkPlan => Boolean,
accum: ListBuffer[SparkPlan]): Seq[SparkPlan] = {
if (predicate(plan)) {
accum += plan
}
plan match {
case a: AdaptiveSparkPlanExec => recurse(a.executedPlan, predicate, accum)
case qs: BroadcastQueryStageExec => recurse(qs.broadcast, predicate, accum)
case qs: ShuffleQueryStageExec => recurse(qs.shuffle, predicate, accum)
case other => other.children.flatMap(p => recurse(p, predicate, accum)).headOption
}
accum
}
recurse(plan, predicate, new ListBuffer[SparkPlan]())
}

override def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean = false

override def shouldFailDivOverflow(): Boolean = false
}
@@ -16,6 +16,8 @@

package com.nvidia.spark.rapids.shims.v2

import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
import com.nvidia.spark.rapids.GpuOverrides.exec
import org.apache.hadoop.fs.FileStatus
@@ -25,7 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, CustomShuffleReaderExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec}
@@ -97,6 +99,27 @@ trait Spark30XShims extends SparkShims {
TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all),
(exec, conf, p, r) => new GpuCustomShuffleReaderMeta(exec, conf, p, r))

override def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = {
def recurse(
plan: SparkPlan,
predicate: SparkPlan => Boolean,
accum: ListBuffer[SparkPlan]): Seq[SparkPlan] = {
if (predicate(plan)) {
accum += plan
}
plan match {
case a: AdaptiveSparkPlanExec => recurse(a.executedPlan, predicate, accum)
case qs: BroadcastQueryStageExec => recurse(qs.broadcast, predicate, accum)
case qs: ShuffleQueryStageExec => recurse(qs.shuffle, predicate, accum)
case other => other.children.flatMap(p => recurse(p, predicate, accum)).headOption
}
accum
}
recurse(plan, predicate, new ListBuffer[SparkPlan]())
}

override def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean = false

override def shouldFailDivOverflow(): Boolean = false

override def leafNodeDefaultParallelism(ss: SparkSession): Int = {
@@ -460,6 +460,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
}
case _: ExecutedCommandExec => () // Ignored
case _: RDDScanExec => () // Ignored
case p if ShimLoader.getSparkShims.skipAssertIsOnTheGpu(p) => () // Ignored
case _ =>
if (!plan.supportsColumnar &&
// There are some python execs that are not columnar because of a little
12 changes: 12 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -253,6 +253,18 @@ trait SparkShims {

def aqeShuffleReaderExec: ExecRule[_ <: SparkPlan]

/**
* Walk the plan recursively and return a list of operators that match the predicate
*/
def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan]

/**
* Our tests, by default, will check that all operators are running on the GPU, but
* there are some operators that we do not translate to GPU plans, so we need a way
* to bypass the check for those.
*/
def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean

def leafNodeDefaultParallelism(ss: SparkSession): Int
}
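
For orientation (not part of this diff), a minimal sketch of how a test might use the two new shim hooks; the DataFrame df here is hypothetical, assumed to come from a query run with the RAPIDS plugin enabled:

    // Illustrative sketch only -- `df` is a hypothetical test DataFrame.
    // findOperators walks through AQE wrappers (AdaptiveSparkPlanExec and
    // query stages) and, on the Spark 3.2 shim, through CommandResultExec.
    val shims = ShimLoader.getSparkShims
    val gpuProjects = shims.findOperators(
      df.queryExecution.executedPlan,
      _.isInstanceOf[GpuProjectExec])
    assert(gpuProjects.nonEmpty, "expected at least one GpuProjectExec")

    // Operators the plugin deliberately leaves on the CPU can be located
    // with the other hook, e.g. CommandResultExec on the Spark 3.2 shim.
    val skippedOnCpu = shims.findOperators(
      df.queryExecution.executedPlan, shims.skipAssertIsOnTheGpu)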

37 changes: 21 additions & 16 deletions tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
@@ -713,23 +713,28 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {


private def assertContainsAnsiCast(df: DataFrame, expected: Int = 1): DataFrame = {
var count = 0
df.queryExecution.sparkPlan.foreach {
case p: ProjectExec => count += p.projectList.count {
// ansiEnabled is protected so we rely on CastBase.toString
case c: CastBase => c.toString().startsWith("ansi_cast")
case Alias(c: CastBase, _) => c.toString().startsWith("ansi_cast")
case _ => false
}
case p: GpuProjectExec => count += p.projectList.count {
case c: GpuCast => c.ansiMode
case _ => false
}
case _ =>
}
val projections = ShimLoader.getSparkShims.findOperators(df.queryExecution.executedPlan, {
case _: ProjectExec | _: GpuProjectExec => true
case _ => false
})
val count = projections.map {
case p: ProjectExec => p.projectList.count {
// ansiEnabled is protected so we rely on CastBase.toString
case c: CastBase => c.toString().startsWith("ansi_cast")
case Alias(c: CastBase, _) => c.toString().startsWith("ansi_cast")
case _ => false
}
case p: GpuProjectExec => p.projectList.count {
case c: GpuCast => c.ansiMode
case GpuAlias(c: GpuCast, _) => c.ansiMode
case _ => false
}
case _ => 0
}.sum

if (count != expected) {
throw new IllegalStateException("Plan does not contain the expected number of " +
"ansi_cast expressions")
throw new IllegalStateException(s"Expected $expected " +
s"ansi_cast expressions, found $count")
}
df
}
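
As a usage note (illustrative, not from the diff), a test in this suite that expects its plan to contain two ANSI casts would call the helper as:

    assertContainsAnsiCast(df, expected = 2)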
@@ -147,12 +147,13 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite {
try {
spark.sql("CREATE TABLE t(id STRING) USING PARQUET")
val df = spark.sql("INSERT INTO TABLE t SELECT 'abc'")
val insert = df.queryExecution.executedPlan.find(_.isInstanceOf[GpuDataWritingCommandExec])
assert(insert.isDefined)
assert(insert.get.metrics.contains(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME))
assert(insert.get.metrics.contains(BasicColumnarWriteJobStatsTracker.TASK_COMMIT_TIME))
assert(insert.get.metrics(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME).value > 0)
assert(insert.get.metrics(BasicColumnarWriteJobStatsTracker.TASK_COMMIT_TIME).value > 0)
val insert = ShimLoader.getSparkShims.findOperators(df.queryExecution.executedPlan,
_.isInstanceOf[GpuDataWritingCommandExec]).head
.asInstanceOf[GpuDataWritingCommandExec]
assert(insert.metrics.contains(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME))
assert(insert.metrics.contains(BasicColumnarWriteJobStatsTracker.TASK_COMMIT_TIME))
assert(insert.metrics(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME).value > 0)
assert(insert.metrics(BasicColumnarWriteJobStatsTracker.TASK_COMMIT_TIME).value > 0)
} finally {
spark.sql("DROP TABLE IF EXISTS tempmetricstable")
}
8 changes: 1 addition & 7 deletions tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala
@@ -64,13 +64,7 @@ object TestUtils extends Assertions with Arm {

/** Recursively check if the predicate matches in the given plan */
def findOperator(plan: SparkPlan, predicate: SparkPlan => Boolean): Option[SparkPlan] = {
plan match {
case _ if predicate(plan) => Some(plan)
case a: AdaptiveSparkPlanExec => findOperator(a.executedPlan, predicate)
case qs: BroadcastQueryStageExec => findOperator(qs.broadcast, predicate)
case qs: ShuffleQueryStageExec => findOperator(qs.shuffle, predicate)
case other => other.children.flatMap(p => findOperator(p, predicate)).headOption
}
ShimLoader.getSparkShims.findOperators(plan, predicate).headOption
}

/** Return final executed plan */
