Commit

Merge 465996a into 9b96f28

Victor Delépine authored Aug 31, 2021
2 parents 9b96f28 + 465996a commit 080fab0
Showing 5 changed files with 47 additions and 24 deletions.
LargeHashSCollectionFunctions.scala
@@ -28,8 +28,8 @@ class LargeHashSCollectionFunctions[T](private val self: SCollection[T]) {
    *
    * @group transform
    */
-  def hashFilter(sideInput: SideInput[SparkeySet[T]]): SCollection[T] = {
+  def largeHashFilter(sideInput: SideInput[SparkeySet[T]]): SCollection[T] = {
     implicit val coder = self.coder
-    self.map((_, ())).hashIntersectByKey(sideInput).keys
+    self.map((_, ())).largeHashIntersectByKey(sideInput).keys
   }
 }
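The rename above is mechanical, but worth seeing in context. A minimal sketch of the new entry point, assuming the implicit syntax from com.spotify.scio.extra.sparkey as used in the tests further down; the pipeline scaffolding and output path are illustrative, not part of this commit:

import com.spotify.scio.ContextAndArgs
import com.spotify.scio.extra.sparkey._

object LargeHashFilterExample {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, _) = ContextAndArgs(cmdlineArgs)
    val ids = sc.parallelize(Seq("a", "b", "c", "b"))
    // Sparkey-backed set side input, for filter sets too large to keep
    // in an in-memory side input.
    val allowed = sc.parallelize(Seq("a", "b")).asLargeSetSideInput
    // Formerly hashFilter; renamed so it no longer clashes with the
    // in-memory hashFilter on SCollection.
    ids.largeHashFilter(allowed).saveAsTextFile("filtered-output")
    sc.run()
  }
}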
PairLargeHashSCollectionFunctions.scala
@@ -52,7 +52,7 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
     compressionBlockSize: Int = DefaultCompressionBlockSize
   ): SCollection[(K, (V, W))] = {
     implicit val wCoder: Coder[W] = rhs.valueCoder
-    hashJoin(rhs.asLargeMultiMapSideInput(numShards, compressionType, compressionBlockSize))
+    largeHashJoin(rhs.asLargeMultiMapSideInput(numShards, compressionType, compressionBlockSize))
   }

   /**
@@ -69,7 +69,7 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
    *
    * @group join
    */
-  def hashJoin[W: Coder](
+  def largeHashJoin[W: Coder](
     sideInput: SideInput[SparkeyMap[K, Iterable[W]]]
   ): SCollection[(K, (V, W))] =
     self.transform { in =>
@@ -102,7 +102,7 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
     compressionBlockSize: Int = DefaultCompressionBlockSize
   ): SCollection[(K, (V, Option[W]))] = {
     implicit val wCoder: Coder[W] = rhs.valueCoder
-    hashLeftOuterJoin(
+    largeHashLeftOuterJoin(
       rhs.asLargeMultiMapSideInput(numShards, compressionType, compressionBlockSize)
     )
   }
@@ -118,7 +118,7 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
    * }}}
    * @group join
    */
-  def hashLeftOuterJoin[W: Coder](
+  def largeHashLeftOuterJoin[W: Coder](
     sideInput: SideInput[SparkeyMap[K, Iterable[W]]]
   ): SCollection[(K, (V, Option[W]))] = {
     self.transform { in =>
Expand Down Expand Up @@ -146,7 +146,7 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
compressionBlockSize: Int = DefaultCompressionBlockSize
): SCollection[(K, (Option[V], Option[W]))] = {
implicit val wCoder = rhs.valueCoder
hashFullOuterJoin(
largeHashFullOuterJoin(
rhs.asLargeMultiMapSideInput(numShards, compressionType, compressionBlockSize)
)
}
@@ -163,7 +163,7 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
    *
    * @group join
    */
-  def hashFullOuterJoin[W: Coder](
+  def largeHashFullOuterJoin[W: Coder](
     sideInput: SideInput[SparkeyMap[K, Iterable[W]]]
   ): SCollection[(K, (Option[V], Option[W]))] =
     self.transform { in =>
@@ -210,7 +210,9 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
     compressionType: CompressionType = DefaultCompressionType,
     compressionBlockSize: Int = DefaultCompressionBlockSize
   ): SCollection[(K, V)] =
-    hashIntersectByKey(rhs.asLargeSetSideInput(numShards, compressionType, compressionBlockSize))
+    largeHashIntersectByKey(
+      rhs.asLargeSetSideInput(numShards, compressionType, compressionBlockSize)
+    )

   /**
    * Return an SCollection with the pairs from `this` whose keys are in the SideSet `rhs`.
@@ -220,7 +222,7 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
    * @group per
    *   key
    */
-  def hashIntersectByKey(sideInput: SideInput[SparkeySet[K]]): SCollection[(K, V)] =
+  def largeHashIntersectByKey(sideInput: SideInput[SparkeySet[K]]): SCollection[(K, V)] =
     self
       .withSideInputs(sideInput)
       .filter { case ((k, _), sideInputCtx) => sideInputCtx(sideInput).contains(k) }
@@ -240,15 +242,17 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
     compressionType: CompressionType = DefaultCompressionType,
     compressionBlockSize: Int = DefaultCompressionBlockSize
   ): SCollection[(K, V)] =
-    hashSubtractByKey(rhs.asLargeSetSideInput(numShards, compressionType, compressionBlockSize))
+    largeHashSubtractByKey(
+      rhs.asLargeSetSideInput(numShards, compressionType, compressionBlockSize)
+    )

   /**
    * Return an SCollection with the pairs from `this` whose keys are not in SideInput[Set] `rhs`.
    *
    * @group per
    *   key
    */
-  def hashSubtractByKey(sideInput: SideInput[SparkeySet[K]]): SCollection[(K, V)] =
+  def largeHashSubtractByKey(sideInput: SideInput[SparkeySet[K]]): SCollection[(K, V)] =
     self
       .withSideInputs(sideInput)
       .filter { case ((k, _), sideInputCtx) => !sideInputCtx(sideInput).contains(k) }
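Taken together, these renames give the Sparkey-backed per-key helpers a distinct largeHash* prefix. A hedged sketch of the subtract variant in use; sc, events, and blocklist are illustrative names, and the side input is built with the default numShards, compressionType, and compressionBlockSize seen above:

import com.spotify.scio.extra.sparkey._

// Assumes a ScioContext sc is in scope; values are illustrative.
val events    = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
val blocklist = sc.parallelize(Seq("a", "b")).asLargeSetSideInput

// Formerly hashSubtractByKey: keep only the pairs whose keys are absent
// from the Sparkey-backed set, i.e. ("c", 3) here. numShards,
// compressionType, and compressionBlockSize can be passed to
// asLargeSetSideInput to tune the underlying Sparkey files.
val kept = events.largeHashSubtractByKey(blocklist)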
LargeHashSCollectionFunctionsTest.scala
@@ -24,7 +24,7 @@ class LargeHashSCollectionFunctionsTest extends PipelineSpec {
     runWithContext { sc =>
       val p1 = sc.parallelize(Seq("a", "b", "c", "b"))
       val p2 = sc.parallelize(Seq[String]("a", "a", "b", "e")).asLargeSetSideInput
-      val p = p1.hashFilter(p2)
+      val p = p1.largeHashFilter(p2)
       p should containInAnyOrder(Seq("a", "b", "b"))
     }
   }
PairLargeHashSCollectionFunctionsTest.scala
@@ -67,7 +67,7 @@ class PairLargeHashSCollectionFunctionsTest extends PipelineSpec {
     runWithContext { sc =>
       val p1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
       val p2 = sc.parallelize(Seq(("a", 11), ("b", 12), ("b", 13))).asLargeMultiMapSideInput
-      val p = p1.hashJoin(p2)
+      val p = p1.largeHashJoin(p2)
       p should
         containInAnyOrder(Seq(("a", (1, 11)), ("a", (2, 11)), ("b", (3, 12)), ("b", (3, 13))))
     }
@@ -77,7 +77,7 @@ class PairLargeHashSCollectionFunctionsTest extends PipelineSpec {
     runWithContext { sc =>
       val p1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
       val p2 = sc.parallelize[(String, Int)](Map.empty).asLargeMultiMapSideInput
-      val p = p1.hashJoin(p2)
+      val p = p1.largeHashJoin(p2)
       p should
         containInAnyOrder(Seq.empty[(String, (Int, Int))])
     }
@@ -123,7 +123,7 @@ class PairLargeHashSCollectionFunctionsTest extends PipelineSpec {
     runWithContext { sc =>
       val p1 = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
       val p2 = sc.parallelize(Seq(("a", 11), ("b", 12), ("d", 14))).asLargeMultiMapSideInput
-      val p = p1.hashLeftOuterJoin(p2)
+      val p = p1.largeHashLeftOuterJoin(p2)
       p should containInAnyOrder(Seq(("a", (1, Some(11))), ("b", (2, Some(12))), ("c", (3, None))))
     }
   }
@@ -181,7 +181,7 @@ class PairLargeHashSCollectionFunctionsTest extends PipelineSpec {
     runWithContext { sc =>
       val p1 = sc.parallelize(Seq(("a", 1), ("b", 2)))
       val p2 = sc.parallelize(Seq(("a", 11), ("c", 13))).asLargeMultiMapSideInput
-      val p = p1.hashFullOuterJoin(p2)
+      val p = p1.largeHashFullOuterJoin(p2)
       p should containInAnyOrder(
         Seq(("a", (Some(1), Some(11))), ("b", (Some(2), None)), ("c", (None, Some(13))))
       )
@@ -228,7 +228,7 @@ class PairLargeHashSCollectionFunctionsTest extends PipelineSpec {
     runWithContext { sc =>
       val p1 = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3), ("b", 4)))
       val p2 = sc.parallelize(Seq[String]("a", "b", "d")).asLargeSetSideInput
-      val p = p1.hashIntersectByKey(p2)
+      val p = p1.largeHashIntersectByKey(p2)
       p should containInAnyOrder(Seq(("a", 1), ("b", 2), ("b", 4)))
     }
   }
@@ -276,7 +276,7 @@ class PairLargeHashSCollectionFunctionsTest extends PipelineSpec {
     runWithContext { sc =>
       val p1 = sc.parallelize(Seq(("a", 1), ("b", 2), ("b", 3), ("c", 4)))
       val p2 = sc.parallelize(Seq[String]("a", "b")).asLargeSetSideInput
-      val output = p1.hashSubtractByKey(p2)
+      val output = p1.largeHashSubtractByKey(p2)
       output should haveSize(1)
       output should containInAnyOrder(Seq(("c", 4)))
     }
@@ -286,7 +286,7 @@ class PairLargeHashSCollectionFunctionsTest extends PipelineSpec {
     runWithContext { sc =>
       val p1 = sc.parallelize(Seq("a", "b", "c", "b"))
       val p2 = sc.parallelize(Seq[String]("a", "a", "b", "e")).asLargeSetSideInput
-      val p = p1.hashFilter(p2)
+      val p = p1.largeHashFilter(p2)
       p should containInAnyOrder(Seq("a", "b", "b"))
     }
   }
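The test changes above mirror the renames one-for-one; the new SparkeyTest case below then checks the motivation for the prefix: the regular hashJoin from scio-core must remain reachable when the sparkey syntax is imported. A short sketch of the two joins coexisting, with illustrative values and a ScioContext sc assumed in scope:

import com.spotify.scio._
import com.spotify.scio.extra.sparkey._

val lhs = sc.parallelize(Seq((1, "a"), (2, "c")))
val rhs = sc.parallelize(Seq((1, "b"), (2, "d")))

// scio-core's in-memory hash join, unaffected by this commit.
val small = lhs.hashJoin(rhs)

// The Sparkey-backed variant, for right-hand sides too large for memory.
val large = lhs.largeHashJoin(rhs.asLargeMultiMapSideInput)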
SparkeyTest.scala
@@ -17,10 +17,6 @@

 package com.spotify.scio.extra.sparkey

-import java.io.File
-import java.nio.file.Files
-import java.util.Arrays
-
 import com.github.benmanes.caffeine.cache.{Cache => CCache, Caffeine}
 import com.spotify.scio._
 import com.spotify.scio.testing._
@@ -29,6 +25,9 @@ import com.spotify.sparkey._
 import org.apache.beam.sdk.io.FileSystems
 import org.apache.commons.io.FileUtils

+import java.io.File
+import java.nio.file.Files
+import java.util.Arrays
 import scala.jdk.CollectionConverters._

 final case class TestCache[K, V](testId: String) extends CacheT[K, V, CCache[K, V]] {
@@ -664,4 +663,24 @@ class SparkeyTest extends PipelineSpec {
       .basePath
     FileUtils.deleteDirectory(new File(basePath))
   }
+
+  it should "not override the regular hashJoin method" in {
+
+    val sc = ScioContext()
+
+    val lhsInput = Seq((1, "a"), (2, "c"), (3, "e"), (4, "g"))
+    val rhsInput = Seq((1, "b"), (2, "d"), (3, "f"))
+
+    val rhs = sc.parallelize(rhsInput)
+    val lhs = sc.parallelize(lhsInput)
+
+    val result = lhs
+      .hashJoin(rhs)
+      .materialize
+
+    val scioResult = sc.run().waitUntilFinish()
+    val expectedOutput = List((1, ("a", "b")), (2, ("c", "d")), (3, ("e", "f")))
+
+    scioResult.tap(result).value.toList should contain theSameElementsAs expectedOutput
+  }
 }