Commit

Update test to check for existence of command, add a getPipeEnvVars function to HadoopRDD
tgravescs committed Mar 7, 2014
1 parent e3401dc commit cc97a6a
Showing 3 changed files with 126 additions and 78 deletions.
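
In short: each pipe test now first checks that the external command it shells out to (cat or printenv) actually exists on the machine and otherwise passes trivially, and the map_input_file / mapreduce_map_input_file setup moves out of PipedRDD into a new HadoopPartition.getPipeEnvVars method. A minimal sketch of the availability guard, using the same scala.sys.process approach as the suite's new testCommandAvailable helper (the name commandExists below is illustrative, not part of the commit):

import scala.sys.process._
import scala.util.Try

// True only if the command can be launched and exits with status 0:
// a missing binary or a non-zero exit makes !! throw, which Try
// records as a Failure.
def commandExists(command: String): Boolean =
  Try(Process(command).!!).isSuccess

// Illustrative usage: only run the pipe assertions when "cat" is present,
// otherwise let the test pass trivially.
if (commandExists("cat")) {
  // ... build the RDD, pipe it through the command, assert on the output ...
} else {
  assert(true)
}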
19 changes: 19 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -18,8 +18,10 @@
package org.apache.spark.rdd

import java.io.EOFException
import scala.collection.immutable.Map

import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.InputSplit
import org.apache.hadoop.mapred.JobConf
@@ -43,6 +45,23 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
override def hashCode(): Int = 41 * (41 + rddId) + idx

override val index: Int = idx

  /**
   * Get any environment variables that should be added to the user's environment when running pipes
   * @return a Map with the environment variables and corresponding values; it may be empty
   */
  def getPipeEnvVars(): Map[String, String] = {
    val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
      val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
      // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
      // since it's not removed yet
      Map("map_input_file" -> is.getPath().toString(),
        "mapreduce_map_input_file" -> is.getPath().toString())
    } else {
      Map()
    }
    envVars
  }
}

/**
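What getPipeEnvVars buys downstream: a command launched through RDD.pipe() on a Hadoop-backed RDD can read the name of the input file its partition came from out of its environment. A rough, hypothetical illustration (this program is not part of the commit) of a pipe command written in Scala that uses the variable:

// Hypothetical standalone pipe command: prefixes every line arriving on
// stdin with the file the partition was read from.
object TagLines {
  def main(args: Array[String]): Unit = {
    // Set by HadoopPartition.getPipeEnvVars when the split is a FileSplit;
    // absent otherwise, so fall back to a placeholder.
    val inputFile = sys.env.getOrElse("map_input_file", "<unknown>")
    for (line <- scala.io.Source.stdin.getLines()) {
      println(inputFile + "\t" + line)
    }
  }
}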
10 changes: 1 addition & 9 deletions core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.reflect.ClassTag

import org.apache.hadoop.mapred.FileSplit
import org.apache.spark.{Partition, SparkEnv, TaskContext}


@@ -65,14 +64,7 @@ class PipedRDD[T: ClassTag](
// so the user code can access the input filename
if (split.isInstanceOf[HadoopPartition]) {
val hadoopSplit = split.asInstanceOf[HadoopPartition]

if (hadoopSplit.inputSplit.value.isInstanceOf[FileSplit]) {
val is: FileSplit = hadoopSplit.inputSplit.value.asInstanceOf[FileSplit]
// map.input.file is deprecated in favor of mapreduce.map.input.file but set both
// since its not removed yet
currentEnvVars.put("map_input_file", is.getPath().toString())
currentEnvVars.put("mapreduce_map_input_file", is.getPath().toString())
}
currentEnvVars.putAll(hadoopSplit.getPipeEnvVars())
}

val proc = pb.start()
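For reference, the branch above runs whenever the piped RDD's partitions are HadoopPartitions. A driver-side sketch of the kind of call that ends up injecting the variables (the path and command are placeholders, and an existing SparkContext named sc is assumed):

// sc.textFile reads its data through a HadoopRDD, so the partitions reaching
// PipedRDD.compute are HadoopPartitions backed by FileSplits, and the command
// below is launched with map_input_file / mapreduce_map_input_file set.
val lines = sc.textFile("hdfs:///tmp/example.txt")           // placeholder path
val tagged = lines.pipe(Seq("/usr/local/bin/tag_lines.sh"))  // placeholder command
tagged.take(5).foreach(println)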
175 changes: 106 additions & 69 deletions core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -25,75 +25,101 @@ import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit}
import org.apache.hadoop.fs.Path

import scala.collection.Map
import scala.sys.process._
import scala.util.Try
import org.apache.hadoop.io.{Text, LongWritable}

class PipedRDDSuite extends FunSuite with SharedSparkContext {

test("basic pipe") {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)

val piped = nums.pipe(Seq("cat"))

val c = piped.collect()
assert(c.size === 4)
assert(c(0) === "1")
assert(c(1) === "2")
assert(c(2) === "3")
assert(c(3) === "4")
if (testCommandAvailable("cat")) {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)

val piped = nums.pipe(Seq("cat"))

val c = piped.collect()
assert(c.size === 4)
assert(c(0) === "1")
assert(c(1) === "2")
assert(c(2) === "3")
assert(c(3) === "4")
} else {
assert(true)
}
}

test("advanced pipe") {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val bl = sc.broadcast(List("0"))

val piped = nums.pipe(Seq("cat"),
Map[String, String](),
(f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
(i:Int, f: String=> Unit) => f(i + "_"))

val c = piped.collect()

assert(c.size === 8)
assert(c(0) === "0")
assert(c(1) === "\u0001")
assert(c(2) === "1_")
assert(c(3) === "2_")
assert(c(4) === "0")
assert(c(5) === "\u0001")
assert(c(6) === "3_")
assert(c(7) === "4_")

val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
val d = nums1.groupBy(str=>str.split("\t")(0)).
pipe(Seq("cat"),
Map[String, String](),
(f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
(i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect()
assert(d.size === 8)
assert(d(0) === "0")
assert(d(1) === "\u0001")
assert(d(2) === "b\t2_")
assert(d(3) === "b\t4_")
assert(d(4) === "0")
assert(d(5) === "\u0001")
assert(d(6) === "a\t1_")
assert(d(7) === "a\t3_")
if (testCommandAvailable("cat")) {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val bl = sc.broadcast(List("0"))

val piped = nums.pipe(Seq("cat"),
Map[String, String](),
(f: String => Unit) => {
bl.value.map(f(_)); f("\u0001")
},
(i: Int, f: String => Unit) => f(i + "_"))

val c = piped.collect()

assert(c.size === 8)
assert(c(0) === "0")
assert(c(1) === "\u0001")
assert(c(2) === "1_")
assert(c(3) === "2_")
assert(c(4) === "0")
assert(c(5) === "\u0001")
assert(c(6) === "3_")
assert(c(7) === "4_")

val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
val d = nums1.groupBy(str => str.split("\t")(0)).
pipe(Seq("cat"),
Map[String, String](),
(f: String => Unit) => {
bl.value.map(f(_)); f("\u0001")
},
(i: Tuple2[String, Seq[String]], f: String => Unit) => {
for (e <- i._2) {
f(e + "_")
}
}).collect()
assert(d.size === 8)
assert(d(0) === "0")
assert(d(1) === "\u0001")
assert(d(2) === "b\t2_")
assert(d(3) === "b\t4_")
assert(d(4) === "0")
assert(d(5) === "\u0001")
assert(d(6) === "a\t1_")
assert(d(7) === "a\t3_")
} else {
assert(true)
}
}

test("pipe with env variable") {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
assert(c.size === 2)
assert(c(0) === "LALALA")
assert(c(1) === "LALALA")
if (testCommandAvailable("printenv")) {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
assert(c.size === 2)
assert(c(0) === "LALALA")
assert(c(1) === "LALALA")
} else {
assert(true)
}
}

test("pipe with non-zero exit status") {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
intercept[SparkException] {
piped.collect()
if (testCommandAvailable("cat")) {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
intercept[SparkException] {
piped.collect()
}
} else {
assert(true)
}
}

@@ -105,23 +131,34 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
testExportInputFile("mapreduce_map_input_file")
}

def testExportInputFile(varName:String) {
val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable],
def testCommandAvailable(command: String): Boolean = {
Try(Process(command) !!).isSuccess
}

def testExportInputFile(varName: String) {
if (testCommandAvailable("printenv")) {
val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable],
classOf[Text], 2) {
override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition())
override val getDependencies = List[Dependency[_]]()
override def compute(theSplit: Partition, context: TaskContext) = {
new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1),
new Text("b"))))
override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition())

override val getDependencies = List[Dependency[_]]()

override def compute(theSplit: Partition, context: TaskContext) = {
new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1),
new Text("b"))))
}
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
taskMetrics = null)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
} else {
// printenv isn't available so just pass the test
assert(true)
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
taskMetrics = null)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
}

def generateFakeHadoopPartition(): HadoopPartition = {
