Added in fallback tests #174

Merged · 4 commits · Jun 15, 2020

28 changes: 27 additions & 1 deletion integration_tests/src/main/python/asserts.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit
from conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit, spark_jvm
from datetime import date, datetime
import math
from pyspark.sql import Row
@@ -235,6 +235,32 @@ def assert_gpu_and_cpu_writes_are_equal_iterator(write_func, read_func, base_pat
"""
_assert_gpu_and_cpu_writes_are_equal(write_func, read_func, base_path, False, conf=conf)

def assert_gpu_fallback_collect(func,
cpu_fallback_class_name,
conf={}):
    (bring_back, collect_type) = _prep_func_for_compare(func, True)
    conf = _prep_incompat_conf(conf)

    print('### CPU RUN ###')
    cpu_start = time.time()
    from_cpu = with_cpu_session(bring_back, conf=conf)
    cpu_end = time.time()
    print('### GPU RUN ###')
    jvm = spark_jvm()
    jvm.ai.rapids.spark.ExecutionPlanCaptureCallback.startCapture()
    gpu_start = time.time()
    from_gpu = with_gpu_session(bring_back,
            conf=conf)
    gpu_end = time.time()
    jvm.ai.rapids.spark.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name, 2000)
    print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type,
        gpu_end - gpu_start, cpu_end - cpu_start))
    if should_sort_locally():
        from_cpu.sort(key=_RowCmp)
        from_gpu.sort(key=_RowCmp)

    assert_equal(from_cpu, from_gpu)

def _assert_gpu_and_cpu_are_equal(func,
        should_collect,
        conf={}):
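For reference, the call pattern for the new helper mirrors the csv/orc/parquet test changes further down. The sketch below is illustrative only; data_path, gen, read_func, schema and disable_conf stand in for the data generator, reader factory and RAPIDS config key that the real tests construct:

    # Sketch only: data_path, gen, read_func, schema and disable_conf are placeholders.
    reader = read_func(data_path, schema, False, ',')
    with_cpu_session(
            lambda spark : gen_df(spark, gen).write.csv(data_path))
    assert_gpu_fallback_collect(
            # query run on both the CPU and the GPU session
            lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
            # exec that must appear on the CPU in the captured GPU plan
            'FileSourceScanExec',
            # with the RAPIDS reader disabled, the scan is expected to fall back
            conf={disable_conf: 'false'})
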
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/conftest.py
@@ -176,6 +176,9 @@ def _get_jvm_session(spark):
def _get_jvm(spark):
    return spark.sparkContext._jvm

def spark_jvm():
    return _get_jvm(spark)

class TpchRunner:
    def __init__(self, tpch_format, tpch_path):
        self.tpch_format = tpch_format
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/csv_test.py
@@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect
from datetime import datetime, timezone
from data_gen import *
from marks import *
@@ -166,8 +166,9 @@ def test_csv_fallback(spark_tmp_path, read_func, disable_conf):
    reader = read_func(data_path, schema, False, ',')
    with_cpu_session(
            lambda spark : gen_df(spark, gen).write.csv(data_path))
    assert_gpu_and_cpu_are_equal_collect(
    assert_gpu_fallback_collect(
            lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
            'FileSourceScanExec',
            conf={disable_conf: 'false'})

csv_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM',
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/orc_test.py
@@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_collect
from datetime import date, datetime, timezone
from data_gen import *
from marks import *
@@ -52,8 +52,9 @@ def test_orc_fallback(spark_tmp_path, read_func, disable_conf):
    reader = read_func(data_path)
    with_cpu_session(
            lambda spark : gen_df(spark, gen).write.orc(data_path))
    assert_gpu_and_cpu_are_equal_collect(
    assert_gpu_fallback_collect(
            lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
            'FileSourceScanExec',
            conf={disable_conf: 'false'})

@pytest.mark.parametrize('orc_gens', orc_gens_list, ids=idfn)
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/parquet_test.py
@@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_collect
from datetime import date, datetime, timezone
from data_gen import *
from marks import *
@@ -56,8 +56,9 @@ def test_parquet_fallback(spark_tmp_path, read_func, disable_conf):
    reader = read_func(data_path)
    with_cpu_session(
            lambda spark : gen_df(spark, gen).write.parquet(data_path))
    assert_gpu_and_cpu_are_equal_collect(
    assert_gpu_fallback_collect(
            lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
            'FileSourceScanExec',
            conf={disable_conf: 'false'})

parquet_compress_options = ['none', 'uncompressed', 'snappy', 'gzip']
1 change: 1 addition & 0 deletions integration_tests/src/main/python/spark_init_internal.py
@@ -21,6 +21,7 @@ def _spark__init():
    # can be reset in the middle of a test if specific operations are done (some types of cast etc)
    _s = SparkSession.builder \
            .config('spark.plugins', 'ai.rapids.spark.SQLPlugin') \
            .config('spark.sql.queryExecutionListeners', 'ai.rapids.spark.ExecutionPlanCaptureCallback')\
            .appName('rapids spark plugin integration tests (python)').getOrCreate()
    #TODO catch the ClassNotFound error that happens if the classpath is not set up properly and
    # make it a better error message
@@ -60,10 +60,9 @@ class JoinsSuite extends SparkQueryCompareTestSuite {
    conf = new SparkConf()
      .set("spark.sql.autoBroadcastJoinThreshold", "-1")
      .set("spark.sql.join.preferSortMergeJoin", "true")
      .set("spark.sql.shuffle.partitions", "2"), // hack to try and work around bug in cudf
      .set("spark.sql.shuffle.partitions", "2"),
    incompat = true,
    sort = true,
    execsAllowedNonGpu = Seq("SortExec", "SortOrder")) {
    sort = true) {
    (A, B) => A.join(B, A("longs") === B("longs"))
  }

11 changes: 1 addition & 10 deletions shuffle-plugin/pom.xml
@@ -52,11 +52,6 @@
      <type>test-jar</type>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_${scala.binary.version}</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>ai.rapids</groupId>
      <artifactId>rapids-4-spark-sql_${scala.binary.version}</artifactId>
@@ -113,7 +108,7 @@
        </execution>
      </executions>
    </plugin>
    <!-- disable surefire as we are using scalatest only -->
    <!-- disable surefire as tests are some place else -->
    <plugin>
      <groupId>org.apache.maven.plugins</groupId>
      <artifactId>maven-surefire-plugin</artifactId>
@@ -125,10 +120,6 @@
      <groupId>net.alchim31.maven</groupId>
      <artifactId>scala-maven-plugin</artifactId>
    </plugin>
    <plugin>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest-maven-plugin</artifactId>
    </plugin>
    <plugin>
      <groupId>org.apache.rat</groupId>
      <artifactId>apache-rat-plugin</artifactId>
84 changes: 81 additions & 3 deletions sql-plugin/src/main/scala/ai/rapids/spark/Plugin.scala
@@ -17,27 +17,27 @@
package ai.rapids.spark

import java.util
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf._
import ai.rapids.spark.RapidsPluginImplicits._
import org.apache.commons.lang3.mutable.MutableLong

import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.TaskCompletionListener

trait GpuPartitioning extends Partitioning {

@@ -255,6 +255,84 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
}
}

object ExecutionPlanCaptureCallback {
  private[this] val shouldCapture: AtomicBoolean = new AtomicBoolean(false)
  private[this] val execPlan: AtomicReference[SparkPlan] = new AtomicReference[SparkPlan]()

  private def captureIfNeeded(qe: QueryExecution): Unit = {
    if (shouldCapture.get()) {
      execPlan.set(qe.executedPlan)
    }
  }

  def startCapture(): Unit = {
    execPlan.set(null)
    shouldCapture.set(true)
  }

  def getResultWithTimeout(timeoutMs: Long = 2000): Option[SparkPlan] = {
    try {
      val endTime = System.currentTimeMillis() + timeoutMs
      var plan = execPlan.getAndSet(null)
      while (plan == null) {
        if (System.currentTimeMillis() > endTime) {
          return None
        }
        Thread.sleep(10)
        plan = execPlan.getAndSet(null)
      }
      Some(plan)
    } finally {
      shouldCapture.set(false)
      execPlan.set(null)
    }
  }

  def assertCapturedAndGpuFellBack(fallbackCpuClass: String, timeoutMs: Long = 2000): Unit = {
    val gpuPlan = getResultWithTimeout(timeoutMs=timeoutMs)
    assert(gpuPlan.isDefined, "Did not capture a GPU plan")
    assertDidFallBack(gpuPlan.get, fallbackCpuClass)
  }

  def assertDidFallBack(gpuPlan: SparkPlan, fallbackCpuClass: String): Unit = {
    assert(gpuPlan.find(didFallBack(_, fallbackCpuClass)).isDefined,
      s"Could not find $fallbackCpuClass in the GPU plan\n$gpuPlan")
  }

  private def getBaseNameFromClass(planClassStr: String): String = {
    val firstDotIndex = planClassStr.lastIndexOf(".")
    if (firstDotIndex != -1) planClassStr.substring(firstDotIndex + 1) else planClassStr
  }

  private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = {
    if (!exp.isInstanceOf[GpuExpression] &&
        getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass) {
      true
    } else {
      exp.children.exists(didFallBack(_, fallbackCpuClass))
    }
  }

  private def didFallBack(plan: SparkPlan, fallbackCpuClass: String): Boolean = {
    if (!plan.isInstanceOf[GpuExec] &&
        getBaseNameFromClass(plan.getClass.getName) == fallbackCpuClass) {
      true
    } else {
      plan.expressions.exists(didFallBack(_, fallbackCpuClass))
    }
  }
}

class ExecutionPlanCaptureCallback extends QueryExecutionListener {
  import ExecutionPlanCaptureCallback._

  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    captureIfNeeded(qe)

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    captureIfNeeded(qe)
}

/**
 * The RAPIDS plugin for Spark.
 * To enable this plugin, set the config "spark.plugins" to ai.rapids.spark.SQLPlugin
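The Python helper added in asserts.py drives this listener over the py4j gateway. Roughly, as a sketch only (bring_back, conf and cpu_fallback_class_name come from the test being run):

    # Sketch of the capture handshake performed by assert_gpu_fallback_collect above.
    jvm = spark_jvm()                                   # helper added in conftest.py
    callback = jvm.ai.rapids.spark.ExecutionPlanCaptureCallback
    callback.startCapture()                             # clear any stale plan and arm the listener
    from_gpu = with_gpu_session(bring_back, conf=conf)  # run the query on the GPU session
    # onSuccess/onFailure publish the executed plan; getResultWithTimeout polls for
    # up to 2000 ms, then the named exec must be found on the CPU in that plan.
    callback.assertCapturedAndGpuFellBack(cpu_fallback_class_name, 2000)
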
18 changes: 0 additions & 18 deletions tests/src/test/scala/ai/rapids/spark/AnsiCastOpSuite.scala
@@ -605,24 +605,6 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
    spark.sql(s"SELECT c0 FROM $t3")
  }

  /** Test that a transformation is not supported on GPU */
  private def testNotSupportedOnGpu(testName: String, frame: SparkSession => DataFrame,
      sparkConf: SparkConf)(transformation: DataFrame => DataFrame): Unit = {

    test(testName) {
      try {
        withGpuSparkSession(spark => {
          val input = frame(spark).repartition(1)
          transformation(input).collect()
        }, sparkConf)
        fail("should not run on GPU")
      } catch {
        case e: IllegalArgumentException =>
          assert(e.getMessage.startsWith("Part of the plan is not columnar"))
      }
    }
  }

  /**
   * Perform a transformation that is expected to fail due to values not being valid for
   * an ansi_cast
@@ -896,9 +896,7 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
      count("*"))
  }

  testSparkResultsAreEqual("Agg expression with filter fall back", longsFromCSVDf,
    execsAllowedNonGpu = Seq("HashAggregateExec", "AggregateExpression", "AttributeReference",
      "Alias", "Literal", "Count", "GreaterThan")) {
  testSparkResultsAreEqual("Agg expression with filter", longsFromCSVDf) {
    frame => frame.selectExpr("count(1) filter (where longs > 20)")
  }
