Skip to content

Commit

Permalink
Add test for flatMapWith()
Browse files Browse the repository at this point in the history
  • Loading branch information
tedyu committed May 7, 2015
1 parent 6c124a9 commit 6846e40
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
9 changes: 6 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -742,8 +742,9 @@ abstract class RDD[T: ClassTag](
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = withScope {
val cleanF = sc.clean(f)
val cleanA = sc.clean(constructA)
mapPartitionsWithIndex((index, iter) => {
val a = constructA(index)
val a = cleanA(index)
iter.map(t => cleanF(t, a))
}, preservesPartitioning)
}
Expand All @@ -758,8 +759,9 @@ abstract class RDD[T: ClassTag](
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = withScope {
val cleanF = sc.clean(f)
val cleanA = sc.clean(constructA)
mapPartitionsWithIndex((index, iter) => {
val a = constructA(index)
val a = cleanA(index)
iter.flatMap(t => cleanF(t, a))
}, preservesPartitioning)
}
Expand All @@ -772,8 +774,9 @@ abstract class RDD[T: ClassTag](
@deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope {
val cleanF = sc.clean(f)
val cleanA = sc.clean(constructA)
mapPartitionsWithIndex { (index, iter) =>
val a = constructA(index)
val a = cleanA(index)
iter.map(t => {cleanF(t, a); t})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class ClosureCleanerSuite extends FunSuite {
expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
Expand Down Expand Up @@ -260,6 +261,16 @@ private object TestUserClosuresActuallyCleaned {
def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
}
def testFlatMapWith(rdd: RDD[Int]): Unit = {
import java.util.Random
val randoms = rdd.flatMapWith(
(index: Int) => new Random(index + 42))
{(t: Int, prng: Random) =>
val random = prng.nextDouble()
Seq(random * t, random * t * 10)}.
count()
rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
}
def testZipPartitions2(rdd: RDD[Int]): Unit = {
rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
}
Expand Down

0 comments on commit 6846e40

Please sign in to comment.