Skip to content

Commit

Permalink
Added in fallback tests (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
revans2 authored Jun 15, 2020
1 parent 2c2883d commit 472fdf8
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 96 deletions.
28 changes: 27 additions & 1 deletion integration_tests/src/main/python/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit
from conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit, spark_jvm
from datetime import date, datetime
import math
from pyspark.sql import Row
Expand Down Expand Up @@ -235,6 +235,32 @@ def assert_gpu_and_cpu_writes_are_equal_iterator(write_func, read_func, base_pat
"""
_assert_gpu_and_cpu_writes_are_equal(write_func, read_func, base_path, False, conf=conf)

def assert_gpu_fallback_collect(func,
cpu_fallback_class_name,
conf={}):
(bring_back, collect_type) = _prep_func_for_compare(func, True)
conf = _prep_incompat_conf(conf)

print('### CPU RUN ###')
cpu_start = time.time()
from_cpu = with_cpu_session(bring_back, conf=conf)
cpu_end = time.time()
print('### GPU RUN ###')
jvm = spark_jvm()
jvm.ai.rapids.spark.ExecutionPlanCaptureCallback.startCapture()
gpu_start = time.time()
from_gpu = with_gpu_session(bring_back,
conf=conf)
gpu_end = time.time()
jvm.ai.rapids.spark.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name, 2000)
print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type,
gpu_end - gpu_start, cpu_end - cpu_start))
if should_sort_locally():
from_cpu.sort(key=_RowCmp)
from_gpu.sort(key=_RowCmp)

assert_equal(from_cpu, from_gpu)

def _assert_gpu_and_cpu_are_equal(func,
should_collect,
conf={}):
Expand Down
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ def _get_jvm_session(spark):
def _get_jvm(spark):
return spark.sparkContext._jvm

def spark_jvm():
return _get_jvm(spark)

class TpchRunner:
def __init__(self, tpch_format, tpch_path):
self.tpch_format = tpch_format
Expand Down
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/csv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect
from datetime import datetime, timezone
from data_gen import *
from marks import *
Expand Down Expand Up @@ -166,8 +166,9 @@ def test_csv_fallback(spark_tmp_path, read_func, disable_conf):
reader = read_func(data_path, schema, False, ',')
with_cpu_session(
lambda spark : gen_df(spark, gen).write.csv(data_path))
assert_gpu_and_cpu_are_equal_collect(
assert_gpu_fallback_collect(
lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
'FileSourceScanExec',
conf={disable_conf: 'false'})

csv_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM',
Expand Down
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/orc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_collect
from datetime import date, datetime, timezone
from data_gen import *
from marks import *
Expand Down Expand Up @@ -52,8 +52,9 @@ def test_orc_fallback(spark_tmp_path, read_func, disable_conf):
reader = read_func(data_path)
with_cpu_session(
lambda spark : gen_df(spark, gen).write.orc(data_path))
assert_gpu_and_cpu_are_equal_collect(
assert_gpu_fallback_collect(
lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
'FileSourceScanExec',
conf={disable_conf: 'false'})

@pytest.mark.parametrize('orc_gens', orc_gens_list, ids=idfn)
Expand Down
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/parquet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_collect
from datetime import date, datetime, timezone
from data_gen import *
from marks import *
Expand Down Expand Up @@ -56,8 +56,9 @@ def test_parquet_fallback(spark_tmp_path, read_func, disable_conf):
reader = read_func(data_path)
with_cpu_session(
lambda spark : gen_df(spark, gen).write.parquet(data_path))
assert_gpu_and_cpu_are_equal_collect(
assert_gpu_fallback_collect(
lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
'FileSourceScanExec',
conf={disable_conf: 'false'})

parquet_compress_options = ['none', 'uncompressed', 'snappy', 'gzip']
Expand Down
1 change: 1 addition & 0 deletions integration_tests/src/main/python/spark_init_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def _spark__init():
# can be reset in the middle of a test if specific operations are done (some types of cast etc)
_s = SparkSession.builder \
.config('spark.plugins', 'ai.rapids.spark.SQLPlugin') \
.config('spark.sql.queryExecutionListeners', 'ai.rapids.spark.ExecutionPlanCaptureCallback')\
.appName('rapids spark plugin integration tests (python)').getOrCreate()
#TODO catch the ClassNotFound error that happens if the classpath is not set up properly and
# make it a better error message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ class JoinsSuite extends SparkQueryCompareTestSuite {
conf = new SparkConf()
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.sql.join.preferSortMergeJoin", "true")
.set("spark.sql.shuffle.partitions", "2"), // hack to try and work around bug in cudf
.set("spark.sql.shuffle.partitions", "2"),
incompat = true,
sort = true,
execsAllowedNonGpu = Seq("SortExec", "SortOrder")) {
sort = true) {
(A, B) => A.join(B, A("longs") === B("longs"))
}

Expand Down
11 changes: 1 addition & 10 deletions shuffle-plugin/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ai.rapids</groupId>
<artifactId>rapids-4-spark-sql_${scala.binary.version}</artifactId>
Expand Down Expand Up @@ -113,7 +108,7 @@
</execution>
</executions>
</plugin>
<!-- disable surefire as we are using scalatest only -->
<!-- disable surefire as tests are some place else -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
Expand All @@ -125,10 +120,6 @@
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.rat</groupId>
<artifactId>apache-rat-plugin</artifactId>
Expand Down
84 changes: 81 additions & 3 deletions sql-plugin/src/main/scala/ai/rapids/spark/Plugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,27 @@
package ai.rapids.spark

import java.util
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf._
import ai.rapids.spark.RapidsPluginImplicits._
import org.apache.commons.lang3.mutable.MutableLong

import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.TaskCompletionListener

trait GpuPartitioning extends Partitioning {

Expand Down Expand Up @@ -255,6 +255,84 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
}
}

object ExecutionPlanCaptureCallback {
private[this] val shouldCapture: AtomicBoolean = new AtomicBoolean(false)
private[this] val execPlan: AtomicReference[SparkPlan] = new AtomicReference[SparkPlan]()

private def captureIfNeeded(qe: QueryExecution): Unit = {
if (shouldCapture.get()) {
execPlan.set(qe.executedPlan)
}
}

def startCapture(): Unit = {
execPlan.set(null)
shouldCapture.set(true)
}

def getResultWithTimeout(timeoutMs: Long = 2000): Option[SparkPlan] = {
try {
val endTime = System.currentTimeMillis() + timeoutMs
var plan = execPlan.getAndSet(null)
while (plan == null) {
if (System.currentTimeMillis() > endTime) {
return None
}
Thread.sleep(10)
plan = execPlan.getAndSet(null)
}
Some(plan)
} finally {
shouldCapture.set(false)
execPlan.set(null)
}
}

def assertCapturedAndGpuFellBack(fallbackCpuClass: String, timeoutMs: Long = 2000): Unit = {
val gpuPlan = getResultWithTimeout(timeoutMs=timeoutMs)
assert(gpuPlan.isDefined, "Did not capture a GPU plan")
assertDidFallBack(gpuPlan.get, fallbackCpuClass)
}

def assertDidFallBack(gpuPlan: SparkPlan, fallbackCpuClass: String): Unit = {
assert(gpuPlan.find(didFallBack(_, fallbackCpuClass)).isDefined,
s"Could not find $fallbackCpuClass in the GPU plan\n$gpuPlan")
}

private def getBaseNameFromClass(planClassStr: String): String = {
val firstDotIndex = planClassStr.lastIndexOf(".")
if (firstDotIndex != -1) planClassStr.substring(firstDotIndex + 1) else planClassStr
}

private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = {
if (!exp.isInstanceOf[GpuExpression] &&
getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass) {
true
} else {
exp.children.exists(didFallBack(_, fallbackCpuClass))
}
}

private def didFallBack(plan: SparkPlan, fallbackCpuClass: String): Boolean = {
if (!plan.isInstanceOf[GpuExec] &&
getBaseNameFromClass(plan.getClass.getName) == fallbackCpuClass) {
true
} else {
plan.expressions.exists(didFallBack(_, fallbackCpuClass))
}
}
}

class ExecutionPlanCaptureCallback extends QueryExecutionListener {
import ExecutionPlanCaptureCallback._

override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
captureIfNeeded(qe)

override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
captureIfNeeded(qe)
}

/**
* The RAPIDS plugin for Spark.
* To enable this plugin, set the config "spark.plugins" to ai.rapids.spark.SQLPlugin
Expand Down
18 changes: 0 additions & 18 deletions tests/src/test/scala/ai/rapids/spark/AnsiCastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -605,24 +605,6 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
spark.sql(s"SELECT c0 FROM $t3")
}

/** Test that a transformation is not supported on GPU */
private def testNotSupportedOnGpu(testName: String, frame: SparkSession => DataFrame,
sparkConf: SparkConf)(transformation: DataFrame => DataFrame): Unit = {

test(testName) {
try {
withGpuSparkSession(spark => {
val input = frame(spark).repartition(1)
transformation(input).collect()
}, sparkConf)
fail("should not run on GPU")
} catch {
case e: IllegalArgumentException =>
assert(e.getMessage.startsWith("Part of the plan is not columnar"))
}
}
}

/**
* Perform a transformation that is expected to fail due to values not being valid for
* an ansi_cast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -896,9 +896,7 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
count("*"))
}

testSparkResultsAreEqual("Agg expression with filter fall back", longsFromCSVDf,
execsAllowedNonGpu = Seq("HashAggregateExec", "AggregateExpression", "AttributeReference",
"Alias", "Literal", "Count", "GreaterThan")) {
testSparkResultsAreEqual("Agg expression with filter", longsFromCSVDf) {
frame => frame.selectExpr("count(1) filter (where longs > 20)")
}

Expand Down
Loading

0 comments on commit 472fdf8

Please sign in to comment.