From 4577cce934ca9bb097f651dec121a81513c806b1 Mon Sep 17 00:00:00 2001
From: Henry Davidge
Date: Fri, 16 Feb 2024 15:05:39 +0700
Subject: [PATCH] Left overlap join function (#578)

* left range join implementation and tests

Signed-off-by: Henry Davidge

* add tests and docs

Signed-off-by: Henry Davidge

* add headers

Signed-off-by: Henry Davidge

* Rename to overlap join

Signed-off-by: Henry Davidge

* fix doc example

Signed-off-by: Henry Davidge

* add left semi join

Signed-off-by: Henry Davidge

* fix test

Signed-off-by: Henry Davidge

* ignore genome asia

Signed-off-by: Henry Davidge

* Setting version to 2.0.0

* Setting stable version to 2.0.0

* fill docs a bit more

Signed-off-by: Henry Davidge

---------

Signed-off-by: Henry Davidge
Co-authored-by: Henry Davidge
---
 .../io/projectglow/sql/LeftOverlapJoin.scala  | 104 +++++++
 .../sql/LeftOverlapJoinSuite.scala            | 286 ++++++++++++++++++
 docs/source/conf.py                           |   1 +
 python/glow/__init__.py                       |   5 +
 python/glow/sql/__init__.py                   |  18 ++
 python/glow/sql/functions.py                  | 154 ++++++++++
 python/glow/sql/tests/__init__.py             |   0
 python/glow/sql/tests/test_overlap_join.py    | 101 +++++++
 stable-version.txt                            |   2 +-
 version.sbt                                   |   2 +-
 10 files changed, 671 insertions(+), 2 deletions(-)
 create mode 100644 core/src/main/scala/io/projectglow/sql/LeftOverlapJoin.scala
 create mode 100644 core/src/test/scala/io/projectglow/sql/LeftOverlapJoinSuite.scala
 create mode 100644 python/glow/sql/__init__.py
 create mode 100644 python/glow/sql/functions.py
 create mode 100644 python/glow/sql/tests/__init__.py
 create mode 100644 python/glow/sql/tests/test_overlap_join.py

diff --git a/core/src/main/scala/io/projectglow/sql/LeftOverlapJoin.scala b/core/src/main/scala/io/projectglow/sql/LeftOverlapJoin.scala
new file mode 100644
index 000000000..8aa930ced
--- /dev/null
+++ b/core/src/main/scala/io/projectglow/sql/LeftOverlapJoin.scala
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2019 The Glow Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.projectglow.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{Column, DataFrame}
+
+object LeftOverlapJoin {
+
+  private def maybePrefixRightColumns(
+      table: DataFrame,
+      leftColumns: Seq[Column],
+      rightColumns: Seq[Column],
+      prefixOpt: Option[String]): DataFrame = prefixOpt match {
+    case Some(prefix) =>
+      val renamedRightCols = rightColumns.map(c => c.alias(s"${prefix}$c"))
+      table.select((leftColumns ++ renamedRightCols): _*)
+    case None => table
+  }
+
+  /**
+   * Executes a left outer join with an interval overlap condition accelerated
+   * by Databricks' range join optimization.
+   * This function assumes half open intervals i.e., (0, 2) and (1, 2) overlap but (0, 2) and (2, 3) do not.
+   *
+   * @param extraJoinExpr If provided, this expression will be included in the join criteria
+   * @param rightPrefix If provided, all columns from the right table will have their names prefixed with
+   *                    this value in the joined table
+   * @param binSize The bin size for the range join optimization.
+   *                Consult the Databricks documentation for more info.
+   */
+  def leftJoin(
+      left: DataFrame,
+      right: DataFrame,
+      leftStart: Column,
+      rightStart: Column,
+      leftEnd: Column,
+      rightEnd: Column,
+      extraJoinExpr: Column = lit(true),
+      rightPrefix: Option[String] = None,
+      binSize: Int = 5000): DataFrame = {
+    val rightPrepared = right.hint("range_join", binSize)
+    val rangeExpr = leftStart < rightEnd && rightStart < leftEnd
+    val leftPoints = left.where(leftEnd - leftStart === 1)
+    val leftIntervals = left.where(leftEnd - leftStart =!= 1)
+    val pointsJoined = leftPoints.join(
+      rightPrepared,
+      leftStart >= rightStart && leftStart < rightEnd && extraJoinExpr,
+      joinType = "left")
+    val longVarsInner = leftIntervals.join(rightPrepared, rangeExpr && extraJoinExpr)
+    val result = leftIntervals
+      .join(longVarsInner, leftIntervals.columns, joinType = "left")
+      .union(pointsJoined)
+    maybePrefixRightColumns(
+      result,
+      left.columns.map(left.apply),
+      rightPrepared.columns.map(rightPrepared.apply),
+      rightPrefix)
+  }
+
+  /**
+   * Executes a left semi join with an interval overlap condition accelerated
+   * by Databricks' range join optimization.
+   * This function assumes half open intervals i.e., (0, 2) and (1, 2) overlap but (0, 2) and (2, 3) do not.
+   *
+   * @param extraJoinExpr If provided, this expression will be included in the join criteria
+   * @param binSize The bin size for the range join optimization. Consult the Databricks documentation for more info.
+   */
+  def leftSemiJoin(
+      left: DataFrame,
+      right: DataFrame,
+      leftStart: Column,
+      rightStart: Column,
+      leftEnd: Column,
+      rightEnd: Column,
+      extraJoinExpr: Column = lit(true),
+      binSize: Int = 5000): DataFrame = {
+    val rightPrepared = right.hint("range_join", binSize)
+    val rangeExpr = leftStart < rightEnd && rightStart < leftEnd
+    val leftPoints = left.where(leftEnd - leftStart === 1)
+    val leftIntervals = left.where(leftEnd - leftStart =!= 1)
+    val pointsJoined = leftPoints.join(
+      rightPrepared,
+      leftStart >= rightStart && leftStart < rightEnd && extraJoinExpr,
+      joinType = "left_semi")
+    val longVarsInner = leftIntervals.join(rightPrepared, rangeExpr && extraJoinExpr)
+    val longVarsLeftSemi =
+      longVarsInner.select(leftIntervals.columns.map(c => leftIntervals(c)): _*).dropDuplicates()
+    longVarsLeftSemi.union(pointsJoined)
+  }
+}
diff --git a/core/src/test/scala/io/projectglow/sql/LeftOverlapJoinSuite.scala b/core/src/test/scala/io/projectglow/sql/LeftOverlapJoinSuite.scala
new file mode 100644
index 000000000..55365aca8
--- /dev/null
+++ b/core/src/test/scala/io/projectglow/sql/LeftOverlapJoinSuite.scala
@@ -0,0 +1,286 @@
+/*
+ * Copyright 2019 The Glow Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.projectglow.sql
+
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.functions._
+
+abstract class OverlapJoinSuite extends GlowBaseTest {
+  case class Interval(start: Long, end: Long)
+  case class StringWithInterval(name: String, start: Long, end: Long)
+  protected lazy val sess = spark
+  protected def isSemi: Boolean
+  protected def naiveLeftJoin(
+      left: DataFrame,
+      right: DataFrame,
+      leftStart: Column,
+      rightStart: Column,
+      leftEnd: Column,
+      rightEnd: Column,
+      extraJoinExpr: Column = lit(true)): DataFrame = {
+    left.join(
+      right,
+      leftStart < rightEnd && rightStart < leftEnd && extraJoinExpr,
+      joinType = if (isSemi) "left_semi" else "left")
+  }
+
+  private def compareToNaive(
+      left: DataFrame,
+      right: DataFrame,
+      extraJoinExpr: Column = lit(true)): Unit = {
+    val (leftStart, rightStart) = (left("start"), right("start"))
+    val (leftEnd, rightEnd) = (left("end"), right("end"))
+    val glow = if (isSemi) {
+      LeftOverlapJoin.leftSemiJoin(
+        left,
+        right,
+        leftStart,
+        rightStart,
+        leftEnd,
+        rightEnd,
+        extraJoinExpr)
+    } else {
+      LeftOverlapJoin.leftJoin(left, right, leftStart, rightStart, leftEnd, rightEnd, extraJoinExpr)
+    }
+    val naive = naiveLeftJoin(left, right, leftStart, rightStart, leftEnd, rightEnd, extraJoinExpr)
+    assert(glow.count() == naive.count() && glow.except(naive).count() == 0)
+  }
+
+  test("Simple long intervals") {
+    val left = spark.createDataFrame(
+      Seq(
+        Interval(1, 7),
+        Interval(3, 10)
+      ))
+    val right = spark.createDataFrame(
+      Seq(
+        Interval(1, 4),
+        Interval(2, 5)
+      ))
+    compareToNaive(left, right)
+  }
+
+  test("Unmatched left intervals") {
+    val left = spark.createDataFrame(
+      Seq(
+        Interval(1, 3),
+        Interval(7, 10)
+      ))
+    val right = spark.createDataFrame(
+      Seq(
+        Interval(2, 5)
+      ))
+    compareToNaive(left, right)
+  }
+
+  test("Unmatched right intervals") {
+    val left = spark.createDataFrame(
+      Seq(
+        Interval(2, 5)
+      ))
+    val right = spark.createDataFrame(
+      Seq(
+        Interval(1, 3),
+        Interval(7, 10)
+      ))
+    compareToNaive(left, right)
+  }
+
+  test("Points and intervals") {
+    val left = spark.createDataFrame(
+      Seq(
+        Interval(2, 3),
+        Interval(1, 3),
+        Interval(7, 10),
+        Interval(8, 9)
+      ))
+    val right = spark.createDataFrame(
+      Seq(
+        Interval(2, 5)
+      ))
+    compareToNaive(left, right)
+  }
+
+  test("extraJoinExpr") {
+    val left = spark.createDataFrame(
+      Seq(
+        StringWithInterval("a", 1, 7),
+        StringWithInterval("a", 3, 10),
+        StringWithInterval("c", 2, 3)
+      ))
+    val right = spark.createDataFrame(
+      Seq(
+        StringWithInterval("a", 1, 4),
+        StringWithInterval("b", 2, 5)
+      ))
+    compareToNaive(left, right, left("name") === right("name"))
+  }
+
+  test("table aliases") {
+    val left = spark.createDataFrame(
+      Seq(
+        StringWithInterval("a", 1, 7),
+        StringWithInterval("a", 3, 10),
+        StringWithInterval("c", 2, 3)
+      ))
+    val right = spark.createDataFrame(
+      Seq(
+        StringWithInterval("a", 1, 4),
+        StringWithInterval("b", 2, 5)
+      ))
+    compareToNaive(left.alias("left"), right.alias("right"))
+  }
+
+  test("handles negative intervals") {
+    val left = spark.createDataFrame(
+      Seq(
+        StringWithInterval("b", 2, 7),
+        StringWithInterval("b", 3, 2)
+      ))
+    val right = spark.createDataFrame(
+      Seq(
+        StringWithInterval("a", 3, 4)
+      ))
+    compareToNaive(left, right)
+  }
+
+  test("Ranges that touch at a point should not join") {
+    val left = spark.createDataFrame(Seq(Interval(0, 10)))
+    val right = spark.createDataFrame(Seq(Interval(10, 20)))
+    compareToNaive(left, right)
+  }
+
+  test("Ranges that touch at a point should not join (point)") {
+    val left = spark.createDataFrame(Seq(Interval(9, 10)))
+    val right = spark.createDataFrame(Seq(Interval(10, 20)))
+    compareToNaive(left, right)
+  }
+
+  test("Fully contained ranges") {
+    val left = spark.createDataFrame(Seq(Interval(0, 10)))
+    val right = spark.createDataFrame(Seq(Interval(2, 5)))
+
+    compareToNaive(left, right)
+  }
+
+  test("Identical start and end points") {
+    val left = spark.createDataFrame(Seq(Interval(0, 10)))
+    val right = spark.createDataFrame(Seq(Interval(0, 10)))
+    compareToNaive(left, right)
+  }
+
+  test("Identical start points (point)") {
+    val left = spark.createDataFrame(Seq(Interval(0, 1)))
+    val right = spark.createDataFrame(Seq(Interval(0, 10)))
+    compareToNaive(left, right)
+  }
+}
+
+class LeftOverlapJoinSuite extends OverlapJoinSuite {
+  override def isSemi: Boolean = false
+  test("naive implementation") {
+    import sess.implicits._
+    val left = Seq(
+      ("a", 1, 10), // matched
+      ("a", 2, 3), // matched
+      ("a", 5, 7), // unmatched
+      ("a", 2, 5), // matched
+      ("b", 1, 10) // unmatched
+    ).toDF("name", "start", "end")
+    val right = Seq(
+      ("a", 2, 5), // matched
+      ("c", 2, 5) // unmatched
+    ).toDF("name", "start", "end")
+    val joined = naiveLeftJoin(
+      left,
+      right,
+      left("start"),
+      right("start"),
+      left("end"),
+      right("end"),
+      left("name") === right("name"))
+    assert(joined.count() == 5) // All five left rows are present
+    assert(joined.where(right("start").isNull).count() == 2) // Unmatched left rows have no right fields
+    // Unmatched right rows are not present
+    assert(
+      joined
+        .where(right("name") === "c" && right("start") === 2 && right("end") === 5)
+        .count() == 0)
+  }
+
+  test("duplicate column names (prefix)") {
+    val left = spark
+      .createDataFrame(
+        Seq(
+          StringWithInterval("a", 1, 7),
+          StringWithInterval("a", 3, 10),
+          StringWithInterval("c", 2, 3)
+        ))
+      .alias("left")
+    val right = spark
+      .createDataFrame(
+        Seq(
+          StringWithInterval("a", 1, 4),
+          StringWithInterval("b", 2, 5)
+        ))
+      .alias("right")
+    val joined = LeftOverlapJoin.leftJoin(
+      left,
+      right,
+      left("start"),
+      right("start"),
+      left("end"),
+      right("end"),
+      rightPrefix = Some("right_"))
+    right.columns.foreach { c =>
+      assert(joined.columns.contains(s"right_$c"))
+    }
+    withTempDir { f =>
+      val tablePath = f.toPath.resolve("joined")
+      joined.write.parquet(tablePath.toString)
+      spark.read.parquet(tablePath.toString).collect() // Can read and write table
+    }
+  }
+}
+
+class LeftSemiOverlapJoinSuite extends OverlapJoinSuite {
+  override def isSemi: Boolean = true
+  test("naive implementation (semi)") {
+    import sess.implicits._
+    val left = Seq(
+      ("a", 1, 10), // matched
+      ("a", 2, 3), // matched
+      ("a", 5, 7), // unmatched
+      ("a", 2, 5), // matched
+      ("b", 1, 10) // unmatched
+    ).toDF("name", "start", "end")
+    val right = Seq(
+      ("a", 2, 5), // matched
+      ("c", 2, 5) // unmatched
+    ).toDF("name", "start", "end")
+    val joined = naiveLeftJoin(
+      left,
+      right,
+      left("start"),
+      right("start"),
+      left("end"),
+      right("end"),
+      left("name") === right("name"))
+    assert(
+      joined.as[(String, Int, Int)].collect().toSet == Set(("a", 1, 10), ("a", 2, 3), ("a", 2, 5)))
+  }
+}
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 0dca320c8..b6363ae14 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -236,6 +236,7 @@
     # Captcha required
     r'https://gatk.broadinstitute.org*',
     # Intermittently read timeouts
+    r'https://genomeasia100k.org*',
     r'http://ftp.1000genomes.ebi.ac.uk*',
 ]
 
diff --git a/python/glow/__init__.py b/python/glow/__init__.py
index dcfd99b4b..fa0848047 100644
--- a/python/glow/__init__.py
+++ b/python/glow/__init__.py
@@ -33,3 +33,8 @@ def extend_all(module):
 
 from . import wgr
 from . import gwas
+
+from .sql import *
+from . import sql
+
+extend_all(sql)
diff --git a/python/glow/sql/__init__.py b/python/glow/sql/__init__.py
new file mode 100644
index 000000000..821f31c13
--- /dev/null
+++ b/python/glow/sql/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2019 The Glow Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .functions import *
+from . import functions
+
+__all__ = functions.__all__
\ No newline at end of file
diff --git a/python/glow/sql/functions.py b/python/glow/sql/functions.py
new file mode 100644
index 000000000..e77df268b
--- /dev/null
+++ b/python/glow/sql/functions.py
@@ -0,0 +1,154 @@
+# Copyright 2019 The Glow Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pyspark import SparkContext
+from pyspark.sql import DataFrame
+from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.functions import lit
+from typeguard import typechecked
+
+__all__ = ['left_overlap_join', 'left_semi_overlap_join']
+
+
+def _prepare_sql_args(left: DataFrame,
+                      right: DataFrame,
+                      left_start: Column | None = None,
+                      right_start: Column | None = None,
+                      left_end: Column | None = None,
+                      right_end: Column | None = None,
+                      extra_join_expr: Column | None = None) -> list:
+    unexpected_columns_error = ValueError(
+        'Explicit start and end columns must be specified if the left and right ' +
+        'DataFrames do not contain columns named start and end.')
+    if left_start is None:
+        if not 'start' in left.columns: raise unexpected_columns_error
+        else: left_start = left.start
+    if right_start is None:
+        if not 'start' in right.columns: raise unexpected_columns_error
+        else: right_start = right.start
+    if left_end is None:
+        if not 'end' in left.columns: raise unexpected_columns_error
+        else: left_end = left.end
+    if right_end is None:
+        if not 'end' in right.columns: raise unexpected_columns_error
+        else: right_end = right.end
+
+    if extra_join_expr is None and 'contigName' in left.columns and 'contigName' in right.columns:
+        extra_join_expr = left.contigName == right.contigName
+    if extra_join_expr is None:
+        extra_join_expr = lit(True)
+
+    column_args = [
+        _to_java_column(c) for c in [left_start, right_start, left_end, right_end, extra_join_expr]
+    ]
+    return [left._jdf, right._jdf] + column_args
+
+
+@typechecked
+def left_overlap_join(left: DataFrame,
+                      right: DataFrame,
+                      left_start: Column | None = None,
+                      right_start: Column | None = None,
+                      left_end: Column | None = None,
+                      right_end: Column | None = None,
+                      extra_join_expr: Column | None = None,
+                      right_prefix: str | None = None,
+                      bin_size: int = 5000) -> DataFrame:
+    """
+    Executes a left outer join with an interval overlap condition accelerated
+    by `Databricks' range join optimization `__.
+    This function assumes half open intervals i.e., (0, 2) and (1, 2) overlap but (0, 2) and (2, 3) do not.
+
+    Args:
+        left: The first DataFrame to join. This DataFrame is expected to be larger and contain
+            a mixture of SNPs (intervals with length 1) and longer intervals.
+        right: The second DataFrame to join. It is expected to contain primarily longer intervals.
+        left_start: The interval start column in the left DataFrame. It must be specified if there is not a
+            column named ``start``.
+        left_end: The interval end column in the left DataFrame. It must be specified if there is not a
+            column named ``end``.
+        right_start: The interval start column in the right DataFrame. It must be specified if there is not a
+            column named ``start``.
+        right_end: The interval end column in the right DataFrame. It must be specified if there is not a
+            column named ``end``.
+        extra_join_expr: An expression containing additional join criteria. If a column named ``contigName``
+            exists in both the left and right DataFrames, the default value is ``left.contigName == right.contigName``
+        right_prefix: If provided, all columns in the joined DataFrame that originated from the right DataFrame will
+            have their names prefixed with this string. Can be useful if some column names are duplicated between the
+            left and right DataFrames.
+        bin_size: The bin size to use for the range join optimization
+
+    Example:
+        >>> left = spark.createDataFrame([(1, 10)], ["start", "end"])
+        >>> right = spark.createDataFrame([(2, 3)], ["start", "end"])
+        >>> df = glow.left_overlap_join(left, right)
+
+    Returns:
+        The joined DataFrame
+
+    """
+    join_fn = SparkContext._jvm.io.projectglow.sql.LeftOverlapJoin.leftJoin
+    fn_args = _prepare_sql_args(left, right, left_start, right_start, left_end, right_end,
+                                extra_join_expr)
+    prefix_arg = SparkContext._jvm.scala.Option.apply(right_prefix)
+    fn_args = fn_args + [prefix_arg, bin_size]
+    output_jdf = join_fn(*fn_args)
+    return DataFrame(output_jdf, left.sparkSession)
+
+
+@typechecked
+def left_semi_overlap_join(left: DataFrame,
+                           right: DataFrame,
+                           left_start: Column | None = None,
+                           right_start: Column | None = None,
+                           left_end: Column | None = None,
+                           right_end: Column | None = None,
+                           extra_join_expr: Column | None = None,
+                           bin_size: int = 5000) -> DataFrame:
+    """
+    Executes a left semi join with an interval overlap condition accelerated
+    by `Databricks' range join optimization `__.
+    This function assumes half open intervals i.e., (0, 2) and (1, 2) overlap but (0, 2) and (2, 3) do not.
+
+    Args:
+        left: The first DataFrame to join. This DataFrame is expected to be larger and contain
+            a mixture of SNPs (intervals with length 1) and longer intervals.
+        right: The second DataFrame to join. It is expected to contain primarily longer intervals.
+        left_start: The interval start column in the left DataFrame. It must be specified if there is not a
+            column named ``start``.
+        left_end: The interval end column in the left DataFrame. It must be specified if there is not a
+            column named ``end``.
+        right_start: The interval start column in the right DataFrame. It must be specified if there is not a
+            column named ``start``.
+        right_end: The interval end column in the right DataFrame. It must be specified if there is not a
+            column named ``end``.
+        extra_join_expr: An expression containing additional join criteria. If a column named ``contigName``
+            exists in both the left and right DataFrames, the default value is ``left.contigName == right.contigName``
+        bin_size: The bin size to use for the range join optimization
+
+    Example:
+        >>> left = spark.createDataFrame([(1, 10)], ["start", "end"])
+        >>> right = spark.createDataFrame([(2, 3)], ["start", "end"])
+        >>> df = glow.left_semi_overlap_join(left, right)
+
+    Returns:
+        The joined DataFrame
+
+    """
+    join_fn = SparkContext._jvm.io.projectglow.sql.LeftOverlapJoin.leftSemiJoin
+    fn_args = _prepare_sql_args(left, right, left_start, right_start, left_end, right_end,
+                                extra_join_expr)
+    fn_args = fn_args + [bin_size]
+    output_jdf = join_fn(*fn_args)
+    return DataFrame(output_jdf, left.sparkSession)
diff --git a/python/glow/sql/tests/__init__.py b/python/glow/sql/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/glow/sql/tests/test_overlap_join.py b/python/glow/sql/tests/test_overlap_join.py
new file mode 100644
index 000000000..0e0b1aa88
--- /dev/null
+++ b/python/glow/sql/tests/test_overlap_join.py
@@ -0,0 +1,101 @@
+# Copyright 2019 The Glow Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glow
+import pytest
+
+
+def test_left_join(spark):
+    left = spark.createDataFrame([("a", 1, 10), ("a", 2, 3), ("a", 5, 7), ("a", 2, 5),
+                                  ("b", 1, 10)], ["name", "start", "end"])
+    right = spark.createDataFrame([("a", 2, 5), ("c", 2, 5)], ["name", "start", "end"])
+    joined = glow.left_overlap_join(left, right, left.start, right.start, left.end, right.end,
+                                    left.name == right.name)
+    assert joined.count() == 5
+    assert joined.where(right.start.isNull()).count() == 2
+    assert joined.where((right.name == "c") & (right.start == 2) & (right.end == 5)).count() == 0
+
+
+def test_left_semi_join(spark):
+    left = spark.createDataFrame([("a", 1, 10), ("a", 2, 3), ("a", 5, 7), ("a", 2, 5),
+                                  ("b", 1, 10)], ["name", "start", "end"])
+    right = spark.createDataFrame([("a", 2, 5), ("c", 1, 10)], ["name", "start", "end"])
+    joined = glow.left_semi_overlap_join(left, right, left.start, right.start, left.end, right.end,
+                                         left.name == right.name)
+    assert joined.count() == 3
+
+
+@pytest.mark.parametrize('join_fn', [glow.left_overlap_join, glow.left_semi_overlap_join])
+def test_no_extra_expr(spark, join_fn):
+    left = spark.createDataFrame([(1, 10)], ["start", "end"])
+    right = spark.createDataFrame([(1, 10)], ["start", "end"])
+    joined = join_fn(left, right, left.start, right.start, left.end, right.end)
+    assert joined.count() == 1
+
+
+@pytest.mark.parametrize('join_fn', [glow.left_overlap_join, glow.left_semi_overlap_join])
+def test_bin_size(spark, join_fn):
+    left = spark.createDataFrame([(1, 10)], ["start", "end"])
+    right = spark.createDataFrame([(1, 10)], ["start", "end"])
+    joined = join_fn(left, right, left.start, right.start, left.end, right.end, bin_size=1)
+    assert joined.count() == 1
+
+
+def test_default_arguments(spark):
+    left = spark.createDataFrame([("a", 1, 10), ("a", 2, 3), ("a", 5, 7), ("a", 2, 5),
+                                  ("b", 1, 10)], ["contigName", "start", "end"])
+    right = spark.createDataFrame([("a", 2, 5), ("c", 2, 5)], ["contigName", "start", "end"])
+    joined = glow.left_overlap_join(left, right)
+    assert joined.count() == 5
+    assert joined.where(right.start.isNull()).count() == 2
+    assert joined.where((right.contigName == "c") & (right.start == 2) &
+                        (right.end == 5)).count() == 0
+
+
+@pytest.mark.parametrize('join_fn', [glow.left_overlap_join, glow.left_semi_overlap_join])
+def test_default_arguments_no_contig(spark, join_fn):
+    left = spark.createDataFrame([(1, 10)], ["start", "end"])
+    right = spark.createDataFrame([(1, 10)], ["start", "end"])
+    assert join_fn(left, right).count() == 1
+
+
+@pytest.mark.parametrize('join_fn', [glow.left_overlap_join, glow.left_semi_overlap_join])
+def test_missing_columns(spark, join_fn):
+    left = spark.createDataFrame([(1, 10)], ["start", "end"])
+    right = spark.createDataFrame([(1, 10)], ["start", "end"])
+    args = {
+        'left_start': left.start,
+        'right_start': right.start,
+        'left_end': left.end,
+        'right_end': right.end
+    }
+    join_fn(left, right, **args)  # No error
+    for k in args.keys():
+        d = args.copy()
+        d.pop(k)
+        l = left.drop(args[k]) if 'left' in k else left
+        r = right.drop(args[k]) if 'right' in k else right
+        with pytest.raises(ValueError):
+            join_fn(l, r, **d)
+
+
+def test_right_prefix(spark):
+    left = spark.createDataFrame([("a", 1, 10), ("a", 2, 3), ("a", 5, 7), ("a", 2, 5),
+                                  ("b", 1, 10)], ["name", "start", "end"])
+    right = spark.createDataFrame([("a", 2, 5, 'dog'), ("c", 2, 5, 'cat')],
+                                  ["name", "start", "end", "animal"])
+    joined = glow.left_overlap_join(left, right, left.start, right.start, left.end, right.end,
+                                    left.name == right.name, 'ann_')
+    assert all([c in joined.columns for c in ['ann_name', 'ann_start', 'ann_end', 'ann_animal']])
+    assert joined.count() == 5
diff --git a/stable-version.txt b/stable-version.txt
index 6085e9465..227cea215 100644
--- a/stable-version.txt
+++ b/stable-version.txt
@@ -1 +1 @@
-1.2.1
+2.0.0
diff --git a/version.sbt b/version.sbt
index 106e75757..7bf954405 100644
--- a/version.sbt
+++ b/version.sbt
@@ -1 +1 @@
-ThisBuild / version := "2.0.0-SNAPSHOT"
+version in ThisBuild := "2.0.0"
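
For reviewers, a minimal usage sketch of the Python API introduced by this patch. It is illustrative only: the session setup, sample data, and column values (chr1, geneA, the ann_ prefix) are assumptions, not part of the change, and it presumes the Glow JAR is on the Spark classpath.

    import glow
    from pyspark.sql import SparkSession

    # Standard Glow setup; assumes the Glow package is available to this Spark session
    spark = SparkSession.builder.getOrCreate()
    spark = glow.register(spark)

    # Left side: variants, a mix of points (length-1 intervals) and longer intervals
    variants = spark.createDataFrame(
        [("chr1", 100, 101), ("chr1", 200, 350)],
        ["contigName", "start", "end"])

    # Right side: annotation intervals
    annotations = spark.createDataFrame(
        [("chr1", 90, 250, "geneA")],
        ["contigName", "start", "end", "gene"])

    # contigName exists on both sides, so it becomes the default extra join condition;
    # right_prefix disambiguates the duplicated contigName/start/end columns.
    joined = glow.left_overlap_join(variants, annotations, right_prefix="ann_")
    joined.show()

Because the intervals are half open, the point variant at (100, 101) joins the annotation (90, 250), while an interval ending exactly at 90 would not.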