* left range join implementation and tests Signed-off-by: Henry Davidge <[email protected]>
* add tests and docs Signed-off-by: Henry Davidge <[email protected]>
* add headers Signed-off-by: Henry Davidge <[email protected]>
* Rename to overlap join Signed-off-by: Henry Davidge <[email protected]>
* fix doc example Signed-off-by: Henry Davidge <[email protected]>
* add left semi join Signed-off-by: Henry Davidge <[email protected]>
* fix test Signed-off-by: Henry Davidge <[email protected]>
* ignore genome asia Signed-off-by: Henry Davidge <[email protected]>
* Setting version to 2.0.0
* Setting stable version to 2.0.0
* fill docs a bit more Signed-off-by: Henry Davidge <[email protected]>
---------
Signed-off-by: Henry Davidge <[email protected]>
Co-authored-by: Henry Davidge <[email protected]>
1 parent 347edeb
commit 4577cce
Showing 10 changed files with 671 additions and 2 deletions.
104 changes: 104 additions & 0 deletions
core/src/main/scala/io/projectglow/sql/LeftOverlapJoin.scala
@@ -0,0 +1,104 @@
/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.projectglow.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Column, DataFrame}

object LeftOverlapJoin {

  private def maybePrefixRightColumns(
      table: DataFrame,
      leftColumns: Seq[Column],
      rightColumns: Seq[Column],
      prefixOpt: Option[String]): DataFrame = prefixOpt match {
    case Some(prefix) =>
      val renamedRightCols = rightColumns.map(c => c.alias(s"${prefix}$c"))
      table.select((leftColumns ++ renamedRightCols): _*)
    case None => table
  }
  /**
   * Executes a left outer join with an interval overlap condition accelerated
   * by Databricks' range join optimization <https://docs.databricks.com/en/optimizations/range-join.html>.
   * This function assumes half-open intervals, i.e., (0, 2) and (1, 2) overlap but (0, 2) and (2, 3) do not.
   *
   * @param extraJoinExpr If provided, this expression will be included in the join criteria
   * @param rightPrefix If provided, all columns from the right table will have their names prefixed with
   *                    this value in the joined table
   * @param binSize The bin size for the range join optimization. Consult the Databricks documentation for more info.
   */
  def leftJoin(
      left: DataFrame,
      right: DataFrame,
      leftStart: Column,
      rightStart: Column,
      leftEnd: Column,
      rightEnd: Column,
      extraJoinExpr: Column = lit(true),
      rightPrefix: Option[String] = None,
      binSize: Int = 5000): DataFrame = {
    // Hint for the range join optimization described in the scaladoc above
    val rightPrepared = right.hint("range_join", binSize)
    val rangeExpr = leftStart < rightEnd && rightStart < leftEnd
    // Left rows that cover a single position ([start, start + 1)) use a point-in-interval condition;
    // longer left intervals use the overlap condition
    val leftPoints = left.where(leftEnd - leftStart === 1)
    val leftIntervals = left.where(leftEnd - leftStart =!= 1)
    val pointsJoined = leftPoints.join(
      rightPrepared,
      leftStart >= rightStart && leftStart < rightEnd && extraJoinExpr,
      joinType = "left")
    val longVarsInner = leftIntervals.join(rightPrepared, rangeExpr && extraJoinExpr)
    // Left join the inner result back against the left intervals to restore unmatched left rows,
    // then add the point rows
    val result = leftIntervals
      .join(longVarsInner, leftIntervals.columns, joinType = "left")
      .union(pointsJoined)
    maybePrefixRightColumns(
      result,
      left.columns.map(left.apply),
      rightPrepared.columns.map(rightPrepared.apply),
      rightPrefix)
  }

  /**
   * Executes a left semi join with an interval overlap condition accelerated
   * by Databricks' range join optimization <https://docs.databricks.com/en/optimizations/range-join.html>.
   * This function assumes half-open intervals, i.e., (0, 2) and (1, 2) overlap but (0, 2) and (2, 3) do not.
   *
   * @param extraJoinExpr If provided, this expression will be included in the join criteria
   * @param binSize The bin size for the range join optimization. Consult the Databricks documentation for more info.
   */
  def leftSemiJoin(
      left: DataFrame,
      right: DataFrame,
      leftStart: Column,
      rightStart: Column,
      leftEnd: Column,
      rightEnd: Column,
      extraJoinExpr: Column = lit(true),
      binSize: Int = 5000): DataFrame = {
    val rightPrepared = right.hint("range_join", binSize)
    val rangeExpr = leftStart < rightEnd && rightStart < leftEnd
    val leftPoints = left.where(leftEnd - leftStart === 1)
    val leftIntervals = left.where(leftEnd - leftStart =!= 1)
    val pointsJoined = leftPoints.join(
      rightPrepared,
      leftStart >= rightStart && leftStart < rightEnd && extraJoinExpr,
      joinType = "left_semi")
    val longVarsInner = leftIntervals.join(rightPrepared, rangeExpr && extraJoinExpr)
    // Dedupe the inner join result down to distinct left rows to emulate a semi join for long intervals
    val longVarsLeftSemi =
      longVarsInner.select(leftIntervals.columns.map(c => leftIntervals(c)): _*).dropDuplicates()
    longVarsLeftSemi.union(pointsJoined)
  }
}
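For reference, a minimal usage sketch of the two entry points above (not part of the commit): it assumes an active SparkSession named `spark`, and the `variants`/`targets` DataFrames, their column names, and the `target_` prefix are purely illustrative. Per the scaladoc, the range_join hint is acted on by Databricks' range join optimization; elsewhere the join still runs, just without that acceleration.

// Usage sketch (illustrative only; assumes an active SparkSession `spark`)
import io.projectglow.sql.LeftOverlapJoin
import spark.implicits._

// Hypothetical inputs with half-open [start, end) intervals on a contig
val variants = Seq(("chr1", 100L, 101L, "rs1"), ("chr1", 150L, 200L, "rs2"))
  .toDF("contig", "start", "end", "id")
val targets = Seq(("chr1", 90L, 160L, "geneA"))
  .toDF("contig", "start", "end", "gene")

// Left outer overlap join; right-hand columns get a prefix to avoid name collisions
val annotated = LeftOverlapJoin.leftJoin(
  variants,
  targets,
  variants("start"),
  targets("start"),
  variants("end"),
  targets("end"),
  extraJoinExpr = variants("contig") === targets("contig"),
  rightPrefix = Some("target_"))

// Left semi overlap join: keep only variants that overlap at least one target
val overlapping = LeftOverlapJoin.leftSemiJoin(
  variants,
  targets,
  variants("start"),
  targets("start"),
  variants("end"),
  targets("end"),
  extraJoinExpr = variants("contig") === targets("contig"))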
286 changes: 286 additions & 0 deletions
core/src/test/scala/io/projectglow/sql/LeftOverlapJoinSuite.scala
@@ -0,0 +1,286 @@
/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.projectglow.sql

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions._

abstract class OverlapJoinSuite extends GlowBaseTest {
  case class Interval(start: Long, end: Long)
  case class StringWithInterval(name: String, start: Long, end: Long)
  protected lazy val sess = spark
  protected def isSemi: Boolean
  protected def naiveLeftJoin(
      left: DataFrame,
      right: DataFrame,
      leftStart: Column,
      rightStart: Column,
      leftEnd: Column,
      rightEnd: Column,
      extraJoinExpr: Column = lit(true)): DataFrame = {
    left.join(
      right,
      leftStart < rightEnd && rightStart < leftEnd && extraJoinExpr,
      joinType = if (isSemi) "left_semi" else "left")
  }

  private def compareToNaive(
      left: DataFrame,
      right: DataFrame,
      extraJoinExpr: Column = lit(true)): Unit = {
    val (leftStart, rightStart) = (left("start"), right("start"))
    val (leftEnd, rightEnd) = (left("end"), right("end"))
    val glow = if (isSemi) {
      LeftOverlapJoin.leftSemiJoin(
        left,
        right,
        leftStart,
        rightStart,
        leftEnd,
        rightEnd,
        extraJoinExpr)
    } else {
      LeftOverlapJoin.leftJoin(left, right, leftStart, rightStart, leftEnd, rightEnd, extraJoinExpr)
    }
    val naive = naiveLeftJoin(left, right, leftStart, rightStart, leftEnd, rightEnd, extraJoinExpr)
    assert(glow.count() == naive.count() && glow.except(naive).count() == 0)
  }
  test("Simple long intervals") {
    val left = spark.createDataFrame(
      Seq(
        Interval(1, 7),
        Interval(3, 10)
      ))
    val right = spark.createDataFrame(
      Seq(
        Interval(1, 4),
        Interval(2, 5)
      ))
    compareToNaive(left, right)
  }

  test("Unmatched left intervals") {
    val left = spark.createDataFrame(
      Seq(
        Interval(1, 3),
        Interval(7, 10)
      ))
    val right = spark.createDataFrame(
      Seq(
        Interval(2, 5)
      ))
    compareToNaive(left, right)
  }

  test("Unmatched right intervals") {
    val left = spark.createDataFrame(
      Seq(
        Interval(2, 5)
      ))
    val right = spark.createDataFrame(
      Seq(
        Interval(1, 3),
        Interval(7, 10)
      ))
    compareToNaive(left, right)
  }

  test("Points and intervals") {
    val left = spark.createDataFrame(
      Seq(
        Interval(2, 3),
        Interval(1, 3),
        Interval(7, 10),
        Interval(8, 9)
      ))
    val right = spark.createDataFrame(
      Seq(
        Interval(2, 5)
      ))
    compareToNaive(left, right)
  }

  test("extraJoinExpr") {
    val left = spark.createDataFrame(
      Seq(
        StringWithInterval("a", 1, 7),
        StringWithInterval("a", 3, 10),
        StringWithInterval("c", 2, 3)
      ))
    val right = spark.createDataFrame(
      Seq(
        StringWithInterval("a", 1, 4),
        StringWithInterval("b", 2, 5)
      ))
    compareToNaive(left, right, left("name") === right("name"))
  }

  test("table aliases") {
    val left = spark.createDataFrame(
      Seq(
        StringWithInterval("a", 1, 7),
        StringWithInterval("a", 3, 10),
        StringWithInterval("c", 2, 3)
      ))
    val right = spark.createDataFrame(
      Seq(
        StringWithInterval("a", 1, 4),
        StringWithInterval("b", 2, 5)
      ))
    compareToNaive(left.alias("left"), right.alias("right"))
  }

  test("handles negative intervals") {
    val left = spark.createDataFrame(
      Seq(
        StringWithInterval("b", 2, 7),
        StringWithInterval("b", 3, 2)
      ))
    val right = spark.createDataFrame(
      Seq(
        StringWithInterval("a", 3, 4)
      ))
    compareToNaive(left, right)
  }

  test("Ranges that touch at a point should not join") {
    val left = spark.createDataFrame(Seq(Interval(0, 10)))
    val right = spark.createDataFrame(Seq(Interval(10, 20)))
    compareToNaive(left, right)
  }

  test("Ranges that touch at a point should not join (point)") {
    val left = spark.createDataFrame(Seq(Interval(9, 10)))
    val right = spark.createDataFrame(Seq(Interval(10, 20)))
    compareToNaive(left, right)
  }

  test("Fully contained ranges") {
    val left = spark.createDataFrame(Seq(Interval(0, 10)))
    val right = spark.createDataFrame(Seq(Interval(2, 5)))

    compareToNaive(left, right)
  }

  test("Identical start and end points") {
    val left = spark.createDataFrame(Seq(Interval(0, 10)))
    val right = spark.createDataFrame(Seq(Interval(0, 10)))
    compareToNaive(left, right)
  }

  test("Identical start points (point)") {
    val left = spark.createDataFrame(Seq(Interval(0, 1)))
    val right = spark.createDataFrame(Seq(Interval(0, 10)))
    compareToNaive(left, right)
  }
}
class LeftOverlapJoinSuite extends OverlapJoinSuite {
  override def isSemi: Boolean = false
  test("naive implementation") {
    import sess.implicits._
    val left = Seq(
      ("a", 1, 10), // matched
      ("a", 2, 3), // matched
      ("a", 5, 7), // unmatched
      ("a", 2, 5), // matched
      ("b", 1, 10) // unmatched
    ).toDF("name", "start", "end")
    val right = Seq(
      ("a", 2, 5), // matched
      ("c", 2, 5) // unmatched
    ).toDF("name", "start", "end")
    val joined = naiveLeftJoin(
      left,
      right,
      left("start"),
      right("start"),
      left("end"),
      right("end"),
      left("name") === right("name"))
    assert(joined.count() == 5) // All five left rows are present
    assert(joined.where(right("start").isNull).count() == 2) // Unmatched left rows have no right fields
    // Unmatched right rows are not present
    assert(
      joined
        .where(right("name") === "c" && right("start") === 2 && right("end") === 5)
        .count() == 0)
  }

  test("duplicate column names (prefix)") {
    val left = spark
      .createDataFrame(
        Seq(
          StringWithInterval("a", 1, 7),
          StringWithInterval("a", 3, 10),
          StringWithInterval("c", 2, 3)
        ))
      .alias("left")
    val right = spark
      .createDataFrame(
        Seq(
          StringWithInterval("a", 1, 4),
          StringWithInterval("b", 2, 5)
        ))
      .alias("right")
    val joined = LeftOverlapJoin.leftJoin(
      left,
      right,
      left("start"),
      right("start"),
      left("end"),
      right("end"),
      rightPrefix = Some("right_"))
    right.columns.foreach { c =>
      assert(joined.columns.contains(s"right_$c"))
    }
    withTempDir { f =>
      val tablePath = f.toPath.resolve("joined")
      joined.write.parquet(tablePath.toString)
      spark.read.parquet(tablePath.toString).collect() // Can read and write table
    }
  }
}
class LeftSemiOverlapJoinSuite extends OverlapJoinSuite {
  override def isSemi: Boolean = true
  test("naive implementation (semi)") {
    import sess.implicits._
    val left = Seq(
      ("a", 1, 10), // matched
      ("a", 2, 3), // matched
      ("a", 5, 7), // unmatched
      ("a", 2, 5), // matched
      ("b", 1, 10) // unmatched
    ).toDF("name", "start", "end")
    val right = Seq(
      ("a", 2, 5), // matched
      ("c", 2, 5) // unmatched
    ).toDF("name", "start", "end")
    val joined = naiveLeftJoin(
      left,
      right,
      left("start"),
      right("start"),
      left("end"),
      right("end"),
      left("name") === right("name"))
    assert(
      joined.as[(String, Int, Int)].collect().toSet == Set(("a", 1, 10), ("a", 2, 3), ("a", 2, 5)))
  }
}