Additional tests for broadcast hash join #313

Merged (5 commits) on Jul 15, 2020
@@ -17,6 +17,8 @@
package com.nvidia.spark.rapids

import org.apache.spark.SparkConf
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.rapids.execution.GpuBroadcastHashJoinExec

class JoinsSuite extends SparkQueryCompareTestSuite {

@@ -97,4 +99,45 @@ class JoinsSuite extends SparkQueryCompareTestSuite {
mixedDfWithNulls, mixedDfWithNulls, sortBeforeRepart = true) {
(A, B) => A.join(B, A("longs") === B("longs"), "LeftAnti")
}

test("broadcast hint isn't propagated after a join") {
revans2 (Collaborator):
How are these integration tests? They don't fit with that model and look much more like unit tests. Could you please move them to the tests directory instead of integration_tests?

Collaborator Author:
Thanks @revans2 for reviewing. My bad: since all the join tests were in this file, I added these here as well, not realizing they do not qualify as integration tests. I have moved them to a new file under the tests directory. Please take a look and let me know whether it's okay to keep them there or they should go into an existing file.

val conf = new SparkConf()
.set("spark.sql.autoBroadcastJoinThreshold", "-1")

withGpuSparkSession(spark => {
val df1 = longsDf(spark)
val df2 = nonZeroLongsDf(spark)

val df3 = df1.join(broadcast(df2), Seq("longs"), "inner").drop(df2("longs"))
val df4 = longsDf(spark)
val df5 = df4.join(df3, Seq("longs"), "inner")

val plan = df5.queryExecution.executedPlan

assert(plan.collect { case p: GpuBroadcastHashJoinExec => p }.size === 1)
assert(plan.collect { case p: GpuShuffledHashJoinExec => p }.size === 1)
}, conf)
}

test("broadcast hint in SQL") {
withGpuSparkSession(spark => {
longsDf(spark).createOrReplaceTempView("t")
longsDf(spark).createOrReplaceTempView("u")

for (name <- Seq("BROADCAST", "BROADCASTJOIN", "MAPJOIN")) {
val plan1 = spark.sql(s"SELECT /*+ $name(t) */ * FROM t JOIN u ON t.longs = u.longs")
.queryExecution.executedPlan
val plan2 = spark.sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.longs = u.longs")
.queryExecution.executedPlan

val res1 = plan1.find(_.isInstanceOf[GpuBroadcastHashJoinExec])
val res2 = plan2.find(_.isInstanceOf[GpuBroadcastHashJoinExec])

assert(res1.get.asInstanceOf[GpuBroadcastHashJoinExec].buildSide.toString
.equals("BuildLeft"))
assert(res2.get.asInstanceOf[GpuBroadcastHashJoinExec].buildSide.toString
.equals("BuildRight"))
}
})
}
}
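For context, the behavior the first test asserts can be reproduced outside this suite. Below is a minimal sketch, assuming a plain local Spark session; since `GpuBroadcastHashJoinExec` is specific to the RAPIDS plugin, vanilla Spark's `BroadcastHashJoinExec` is used here as a stand-in for the GPU operator. The object name `BroadcastHintSketch` and the sample data are illustrative, not from the PR.

```scala
// Hedged sketch: show that a broadcast() hint applies only to the join it
// annotates and is not propagated to a downstream join consuming its result.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

object BroadcastHintSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // Disable size-based auto-broadcast so only the explicit hint matters,
      // mirroring the conf used in the test above.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()
    import spark.implicits._

    val a = Seq(1L, 2L, 3L).toDF("longs")
    val b = Seq(2L, 3L, 4L).toDF("longs")

    // The hint covers this join only...
    val hinted = a.join(broadcast(b), Seq("longs"), "inner")
    // ...and should not carry over to a join built on top of its result.
    val outer = Seq(1L, 2L).toDF("longs").join(hinted, Seq("longs"), "inner")

    val broadcastJoins = outer.queryExecution.executedPlan.collect {
      case p: BroadcastHashJoinExec => p
    }
    // Expectation (per the test's assertion pattern): exactly one broadcast
    // hash join appears in the physical plan.
    println(s"broadcast hash joins in plan: ${broadcastJoins.size}")

    spark.stop()
  }
}
```

The same plan-inspection pattern (`executedPlan.collect { case p: ... => p }`) is what the suite uses to count `GpuBroadcastHashJoinExec` and `GpuShuffledHashJoinExec` nodes.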