Skip to content

Commit

Permalink
edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
dorx committed Jun 13, 2014
1 parent 3de882b commit 444e750
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,15 @@ abstract class RDD[T: ClassTag](
num: Int,
seed: Long = Utils.random.nextLong): Array[T] = {
val numStDev = 10.0
-val initialCount = this.count()

if (num < 0) {
throw new IllegalArgumentException("Negative number of elements requested")
+} else if (num == 0) {
+return new Array[T](0)
}

-if (initialCount == 0 || num == 0) {
+val initialCount = this.count()
+if (initialCount == 0) {
return new Array[T](0)
}

Expand All @@ -407,7 +409,7 @@ abstract class RDD[T: ClassTag](
}

val rand = new Random(seed)
-if (!withReplacement && num > initialCount) {
+if (!withReplacement && num >= initialCount) {
return Utils.randomizeInPlace(this.collect(), rand)
}

Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def takeSample(self, withReplacement, num, seed=None):
>>> len(rdd.takeSample(False, 15, 3))
10
"""
+numStDev = 10.0

if num < 0:
raise ValueError("Sample size cannot be negative.")
Expand All @@ -388,7 +389,6 @@ def takeSample(self, withReplacement, num, seed=None):
rand.shuffle(samples)
return samples

-numStDev = 10.0
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
if num > maxSampleSize:
raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
Expand Down

0 comments on commit 444e750

Please sign in to comment.