Skip to content

Commit

Permalink
Left overlap join function (#578)
Browse files Browse the repository at this point in the history
* left range join implementation and tests

Signed-off-by: Henry Davidge <[email protected]>

* add tests and docs

Signed-off-by: Henry Davidge <[email protected]>

* add headers

Signed-off-by: Henry Davidge <[email protected]>

* Rename to overlap join

Signed-off-by: Henry Davidge <[email protected]>

* fix doc example

Signed-off-by: Henry Davidge <[email protected]>

* add left semi join

Signed-off-by: Henry Davidge <[email protected]>

* fix test

Signed-off-by: Henry Davidge <[email protected]>

* ignore genome asia

Signed-off-by: Henry Davidge <[email protected]>

* Setting version to 2.0.0

* Setting stable version to 2.0.0

* fill docs a bit more

Signed-off-by: Henry Davidge <[email protected]>

---------

Signed-off-by: Henry Davidge <[email protected]>
Co-authored-by: Henry Davidge <[email protected]>
  • Loading branch information
henrydavidge and Henry Davidge authored Feb 16, 2024
1 parent 347edeb commit 4577cce
Show file tree
Hide file tree
Showing 10 changed files with 671 additions and 2 deletions.
104 changes: 104 additions & 0 deletions core/src/main/scala/io/projectglow/sql/LeftOverlapJoin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2019 The Glow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.projectglow.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Column, DataFrame}

object LeftOverlapJoin {

private def maybePrefixRightColumns(
table: DataFrame,
leftColumns: Seq[Column],
rightColumns: Seq[Column],
prefixOpt: Option[String]): DataFrame = prefixOpt match {
case Some(prefix) =>
val renamedRightCols = rightColumns.map(c => c.alias(s"${prefix}$c"))
table.select((leftColumns ++ renamedRightCols): _*)
case None => table
}

/**
* Executes a left outer join with an interval overlap condition accelerated
* by Databricks' range join optimization <https://docs.databricks.com/en/optimizations/range-join.html>.
* This function assumes half open intervals i.e., (0, 2) and (1, 2) overlap but (0, 2) and (2, 3) do not.
*
* @param extraJoinExpr If provided, this expression will be included in the join criteria
* @param rightPrefix If provided, all columns from the right table will begin have their names prefixed with
* this value in the joined table
* @param binSize The bin size for the range join optimization. Consult the Databricks documentation for more info.
*/
def leftJoin(
left: DataFrame,
right: DataFrame,
leftStart: Column,
rightStart: Column,
leftEnd: Column,
rightEnd: Column,
extraJoinExpr: Column = lit(true),
rightPrefix: Option[String] = None,
binSize: Int = 5000): DataFrame = {
val rightPrepared = right.hint("range_join", binSize)
val rangeExpr = leftStart < rightEnd && rightStart < leftEnd
val leftPoints = left.where(leftEnd - leftStart === 1)
val leftIntervals = left.where(leftEnd - leftStart =!= 1)
val pointsJoined = leftPoints.join(
rightPrepared,
leftStart >= rightStart && leftStart < rightEnd && extraJoinExpr,
joinType = "left")
val longVarsInner = leftIntervals.join(rightPrepared, rangeExpr && extraJoinExpr)
val result = leftIntervals
.join(longVarsInner, leftIntervals.columns, joinType = "left")
.union(pointsJoined)
maybePrefixRightColumns(
result,
left.columns.map(left.apply),
rightPrepared.columns.map(rightPrepared.apply),
rightPrefix)
}

/**
* Executes a left semi join with an interval overlap condition accelerated
* by Databricks' range join optimization <https://docs.databricks.com/en/optimizations/range-join.html>.
* This function assumes half open intervals i.e., (0, 2) and (1, 2) overlap but (0, 2) and (2, 3) do not.
*
* @param extraJoinExpr If provided, this expression will be included in the join criteria
* @param binSize The bin size for the range join optimization. Consult the Databricks documentation for more info.
*/
def leftSemiJoin(
left: DataFrame,
right: DataFrame,
leftStart: Column,
rightStart: Column,
leftEnd: Column,
rightEnd: Column,
extraJoinExpr: Column = lit(true),
binSize: Int = 5000): DataFrame = {
val rightPrepared = right.hint("range_join", binSize)
val rangeExpr = leftStart < rightEnd && rightStart < leftEnd
val leftPoints = left.where(leftEnd - leftStart === 1)
val leftIntervals = left.where(leftEnd - leftStart =!= 1)
val pointsJoined = leftPoints.join(
rightPrepared,
leftStart >= rightStart && leftStart < rightEnd && extraJoinExpr,
joinType = "left_semi")
val longVarsInner = leftIntervals.join(rightPrepared, rangeExpr && extraJoinExpr)
val longVarsLeftSemi =
longVarsInner.select(leftIntervals.columns.map(c => leftIntervals(c)): _*).dropDuplicates()
longVarsLeftSemi.union(pointsJoined)
}
}
286 changes: 286 additions & 0 deletions core/src/test/scala/io/projectglow/sql/LeftOverlapJoinSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
/*
* Copyright 2019 The Glow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.projectglow.sql

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions._

abstract class OverlapJoinSuite extends GlowBaseTest {
case class Interval(start: Long, end: Long)
case class StringWithInterval(name: String, start: Long, end: Long)
protected lazy val sess = spark
protected def isSemi: Boolean
protected def naiveLeftJoin(
left: DataFrame,
right: DataFrame,
leftStart: Column,
rightStart: Column,
leftEnd: Column,
rightEnd: Column,
extraJoinExpr: Column = lit(true)): DataFrame = {
left.join(
right,
leftStart < rightEnd && rightStart < leftEnd && extraJoinExpr,
joinType = if (isSemi) "left_semi" else "left")
}

private def compareToNaive(
left: DataFrame,
right: DataFrame,
extraJoinExpr: Column = lit(true)): Unit = {
val (leftStart, rightStart) = (left("start"), right("start"))
val (leftEnd, rightEnd) = (left("end"), right("end"))
val glow = if (isSemi) {
LeftOverlapJoin.leftSemiJoin(
left,
right,
leftStart,
rightStart,
leftEnd,
rightEnd,
extraJoinExpr)
} else {
LeftOverlapJoin.leftJoin(left, right, leftStart, rightStart, leftEnd, rightEnd, extraJoinExpr)
}
val naive = naiveLeftJoin(left, right, leftStart, rightStart, leftEnd, rightEnd, extraJoinExpr)
assert(glow.count() == naive.count() && glow.except(naive).count() == 0)
}

test("Simple long intervals") {
val left = spark.createDataFrame(
Seq(
Interval(1, 7),
Interval(3, 10)
))
val right = spark.createDataFrame(
Seq(
Interval(1, 4),
Interval(2, 5)
))
compareToNaive(left, right)
}

test("Unmatched left intervals") {
val left = spark.createDataFrame(
Seq(
Interval(1, 3),
Interval(7, 10)
))
val right = spark.createDataFrame(
Seq(
Interval(2, 5)
))
compareToNaive(left, right)
}

test("Unmatched right intervals") {
val left = spark.createDataFrame(
Seq(
Interval(2, 5)
))
val right = spark.createDataFrame(
Seq(
Interval(1, 3),
Interval(7, 10)
))
compareToNaive(left, right)
}

test("Points and intervals") {
val left = spark.createDataFrame(
Seq(
Interval(2, 3),
Interval(1, 3),
Interval(7, 10),
Interval(8, 9)
))
val right = spark.createDataFrame(
Seq(
Interval(2, 5)
))
compareToNaive(left, right)
}

test("extraJoinExpr") {
val left = spark.createDataFrame(
Seq(
StringWithInterval("a", 1, 7),
StringWithInterval("a", 3, 10),
StringWithInterval("c", 2, 3)
))
val right = spark.createDataFrame(
Seq(
StringWithInterval("a", 1, 4),
StringWithInterval("b", 2, 5)
))
compareToNaive(left, right, left("name") === right("name"))
}

test("table aliases") {
val left = spark.createDataFrame(
Seq(
StringWithInterval("a", 1, 7),
StringWithInterval("a", 3, 10),
StringWithInterval("c", 2, 3)
))
val right = spark.createDataFrame(
Seq(
StringWithInterval("a", 1, 4),
StringWithInterval("b", 2, 5)
))
compareToNaive(left.alias("left"), right.alias("right"))
}

test("handles negative intervals") {
val left = spark.createDataFrame(
Seq(
StringWithInterval("b", 2, 7),
StringWithInterval("b", 3, 2)
))
val right = spark.createDataFrame(
Seq(
StringWithInterval("a", 3, 4)
))
compareToNaive(left, right)
}

test("Ranges that touch at a point should not join") {
val left = spark.createDataFrame(Seq(Interval(0, 10)))
val right = spark.createDataFrame(Seq(Interval(10, 20)))
compareToNaive(left, right)
}

test("Ranges that touch at a point should not join (point)") {
val left = spark.createDataFrame(Seq(Interval(9, 10)))
val right = spark.createDataFrame(Seq(Interval(10, 20)))
compareToNaive(left, right)
}

test("Fully contained ranges") {
val left = spark.createDataFrame(Seq(Interval(0, 10)))
val right = spark.createDataFrame(Seq(Interval(2, 5)))

compareToNaive(left, right)
}

test("Identical start and end points") {
val left = spark.createDataFrame(Seq(Interval(0, 10)))
val right = spark.createDataFrame(Seq(Interval(0, 10)))
compareToNaive(left, right)
}

test("Identical start points (point)") {
val left = spark.createDataFrame(Seq(Interval(0, 1)))
val right = spark.createDataFrame(Seq(Interval(0, 10)))
compareToNaive(left, right)
}
}

class LeftOverlapJoinSuite extends OverlapJoinSuite {
override def isSemi: Boolean = false
test("naive implementation") {
import sess.implicits._
val left = Seq(
("a", 1, 10), // matched
("a", 2, 3), // matched
("a", 5, 7), // unmatched
("a", 2, 5), // matched
("b", 1, 10) // unmatched
).toDF("name", "start", "end")
val right = Seq(
("a", 2, 5), // matched
("c", 2, 5) // unmatched
).toDF("name", "start", "end")
val joined = naiveLeftJoin(
left,
right,
left("start"),
right("start"),
left("end"),
right("end"),
left("name") === right("name"))
assert(joined.count() == 5) // All five left rows are present
assert(joined.where(right("start").isNull).count() == 2) // Unmatched left rows have no right fields
// Unmatched right rows are not present
assert(
joined
.where(right("name") === "c" && right("start") === 2 && right("end") === 5)
.count() == 0)
}

test("duplicate column names (prefix)") {
val left = spark
.createDataFrame(
Seq(
StringWithInterval("a", 1, 7),
StringWithInterval("a", 3, 10),
StringWithInterval("c", 2, 3)
))
.alias("left")
val right = spark
.createDataFrame(
Seq(
StringWithInterval("a", 1, 4),
StringWithInterval("b", 2, 5)
))
.alias("right")
val joined = LeftOverlapJoin.leftJoin(
left,
right,
left("start"),
right("start"),
left("end"),
right("end"),
rightPrefix = Some("right_"))
right.columns.foreach { c =>
assert(joined.columns.contains(s"right_$c"))
}
withTempDir { f =>
val tablePath = f.toPath.resolve("joined")
joined.write.parquet(tablePath.toString)
spark.read.parquet(tablePath.toString).collect() // Can read and write table
}
}
}

class LeftSemiOverlapJoinSuite extends OverlapJoinSuite {
override def isSemi: Boolean = true
test("naive implementation (semi)") {
import sess.implicits._
val left = Seq(
("a", 1, 10), // matched
("a", 2, 3), // matched
("a", 5, 7), // unmatched
("a", 2, 5), // matched
("b", 1, 10) // unmatched
).toDF("name", "start", "end")
val right = Seq(
("a", 2, 5), // matched
("c", 2, 5) // unmatched
).toDF("name", "start", "end")
val joined = naiveLeftJoin(
left,
right,
left("start"),
right("start"),
left("end"),
right("end"),
left("name") === right("name"))
assert(
joined.as[(String, Int, Int)].collect().toSet == Set(("a", 1, 10), ("a", 2, 3), ("a", 2, 5)))
}
}
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@
# Captcha required
r'https://gatk.broadinstitute.org*',
# Intermittently read timeouts
r'https://genomeasia100k.org*',
r'http://ftp.1000genomes.ebi.ac.uk*',
]

Expand Down
Loading

0 comments on commit 4577cce

Please sign in to comment.