[SPARK-29007][STREAMING][MLLIB][TESTS] Enforce not leaking SparkContext in tests which create a new StreamingContext with a new SparkContext

### What changes were proposed in this pull request?

This patch enforces that tests do not leak the SparkContext that is newly created when a StreamingContext is initialized. A SparkContext leaked in one test tends to make most of the following tests fail as well, so this patch applies defensive programming and tries its best to ensure the SparkContext is cleaned up.
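
The cleanup itself is centralized in a `LocalStreamingContext` test trait that the touched suites mix in (pulled in via new `test-jar` dependencies on `spark-streaming`). The trait's definition is not part of the diff shown here, so the block below is only a minimal sketch of what such a trait roughly does, assuming a ScalaTest `BeforeAndAfterEach` flow: stop the `StreamingContext`, and with it the underlying `SparkContext`, after every test even when the test body fails.

```scala
import org.scalatest.{BeforeAndAfterEach, Suite}

import org.apache.spark.streaming.StreamingContext

// Approximate sketch of the LocalStreamingContext helper used below; the real
// trait lives in spark-streaming's test sources and may differ in detail.
trait LocalStreamingContext extends BeforeAndAfterEach { self: Suite =>

  // Each test stores the StreamingContext it creates in this field.
  @transient var ssc: StreamingContext = _

  override def afterEach(): Unit = {
    try {
      if (ssc != null) {
        // Stop the underlying SparkContext too, so the next test starts clean.
        ssc.stop(stopSparkContext = true)
        ssc = null
      }
    } finally {
      super.afterEach()
    }
  }
}
```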

### Why are the changes needed?

We have seen cases in CI builds where a SparkContext was leaked and other tests were affected by it. Ideally, the environment should be isolated between tests whenever possible.
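
As background, the failure mode being guarded against looks roughly like the sketch below (illustrative only, not code from this patch, and the exact error message differs across Spark versions): a context leaked by one suite makes later suites fail when they try to construct their own.

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkException}

// Illustration only (not part of this patch): a SparkContext leaked by one
// suite prevents the next suite from creating its own context, because Spark
// allows only one active SparkContext per JVM by default.
object LeakedSparkContextDemo {
  def main(args: Array[String]): Unit = {
    val leaked = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("suite-1"))
    // Imagine suite-1 finishing here without calling leaked.stop().

    try {
      // suite-2 now fails at construction with a SparkException.
      new SparkContext(new SparkConf().setMaster("local[2]").setAppName("suite-2"))
    } catch {
      case e: SparkException => println(s"suite-2 could not start: ${e.getMessage}")
    } finally {
      leaked.stop() // the cleanup this patch makes sure always happens
    }
  }
}
```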

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Modified existing unit tests.

Closes #25709 from HeartSaVioR/SPARK-29007.

Authored-by: Jungtaek Lim (HeartSaVioR) <[email protected]>
Signed-off-by: Marcelo Vanzin <[email protected]>
HeartSaVioR authored and Marcelo Vanzin committed Sep 11, 2019
1 parent 2736efa commit b62ef8f
Showing 21 changed files with 240 additions and 241 deletions.
7 changes: 7 additions & 0 deletions external/kafka-0-10/pom.xml
@@ -45,6 +45,13 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
@@ -18,8 +18,8 @@
package org.apache.spark.streaming.kafka010

import java.io.File
import java.lang.{ Long => JLong }
import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID }
import java.lang.{Long => JLong}
import java.util.{Arrays, HashMap => JHashMap, Map => JMap, UUID}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicLong
@@ -31,22 +31,20 @@ import scala.util.Random
import org.apache.kafka.clients.consumer._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.StringDeserializer
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
import org.apache.spark.streaming.{LocalStreamingContext, Milliseconds, StreamingContext, Time}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.scheduler._
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.Utils

class DirectKafkaStreamSuite
extends SparkFunSuite
with BeforeAndAfter
with BeforeAndAfterAll
with LocalStreamingContext
with Eventually
with Logging {
val sparkConf = new SparkConf()
@@ -56,7 +54,6 @@ class DirectKafkaStreamSuite
// Otherwise the poll timeout defaults to 2 minutes and causes test cases to run longer.
.set("spark.streaming.kafka.consumer.poll.ms", "10000")

private var ssc: StreamingContext = _
private var testDir: File = _

private var kafkaTestUtils: KafkaTestUtils = _
@@ -78,12 +75,13 @@
}
}

after {
if (ssc != null) {
ssc.stop(stopSparkContext = true)
}
if (testDir != null) {
Utils.deleteRecursively(testDir)
override def afterEach(): Unit = {
try {
if (testDir != null) {
Utils.deleteRecursively(testDir)
}
} finally {
super.afterEach()
}
}

@@ -30,7 +30,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
import org.apache.spark.streaming.{LocalStreamingContext, _}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest
import org.apache.spark.streaming.kinesis.KinesisReadConfigurations._
@@ -40,7 +40,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
import org.apache.spark.util.Utils

abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite
with Eventually with BeforeAndAfter with BeforeAndAfterAll {
with LocalStreamingContext with Eventually with BeforeAndAfter with BeforeAndAfterAll {

// This is the name that KCL will use to save metadata to DynamoDB
private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
@@ -53,15 +53,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
private val dummyAWSSecretKey = "dummySecretKey"

private var testUtils: KinesisTestUtils = null
private var ssc: StreamingContext = null
private var sc: SparkContext = null

override def beforeAll(): Unit = {
val conf = new SparkConf()
.setMaster("local[4]")
.setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name
sc = new SparkContext(conf)

runIfTestsEnabled("Prepare KinesisTestUtils") {
testUtils = new KPLBasedKinesisTestUtils()
testUtils.createStream()
@@ -70,12 +64,6 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun

override def afterAll(): Unit = {
try {
if (ssc != null) {
ssc.stop()
}
if (sc != null) {
sc.stop()
}
if (testUtils != null) {
// Delete the Kinesis stream as well as the DynamoDB table generated by
// Kinesis Client Library when consuming the stream
@@ -87,17 +75,22 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
}
}

before {
override def beforeEach(): Unit = {
super.beforeEach()
val conf = new SparkConf()
.setMaster("local[4]")
.setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name
sc = new SparkContext(conf)
ssc = new StreamingContext(sc, batchDuration)
}

after {
if (ssc != null) {
ssc.stop(stopSparkContext = false)
ssc = null
}
if (testUtils != null) {
testUtils.deleteDynamoDBTable(appName)
override def afterEach(): Unit = {
try {
if (testUtils != null) {
testUtils.deleteDynamoDBTable(appName)
}
} finally {
super.afterEach()
}
}

7 changes: 7 additions & 0 deletions mllib/pom.xml
@@ -55,6 +55,13 @@
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
@@ -23,23 +23,17 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream

class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
class StreamingLogisticRegressionSuite
extends SparkFunSuite
with LocalStreamingContext
with TestSuiteBase {

// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 30000

var ssc: StreamingContext = _

override def afterFunction() {
super.afterFunction()
if (ssc != null) {
ssc.stop()
}
}

// Test if we can accurately learn B for Y = logistic(BX) on streaming data
test("parameter accuracy") {

@@ -20,23 +20,14 @@ package org.apache.spark.mllib.clustering
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.random.XORShiftRandom

class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
class StreamingKMeansSuite extends SparkFunSuite with LocalStreamingContext with TestSuiteBase {

override def maxWaitTimeMillis: Int = 30000

var ssc: StreamingContext = _

override def afterFunction() {
super.afterFunction()
if (ssc != null) {
ssc.stop()
}
}

test("accuracy for single center and equivalence to grand average") {
// set parameters
val numBatches = 10
@@ -22,23 +22,17 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream

class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
class StreamingLinearRegressionSuite
extends SparkFunSuite
with LocalStreamingContext
with TestSuiteBase {

// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 20000

var ssc: StreamingContext = _

override def afterFunction() {
super.afterFunction()
if (ssc != null) {
ssc.stop()
}
}

// Assert that two values are equal within tolerance epsilon
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
def errorMessage = v1.toString + " did not equal " + v2.toString
@@ -39,8 +39,7 @@ import org.apache.spark.internal.config._
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSystemProperties,
Utils}
import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSystemProperties, Utils}

/**
* A input stream that records the times of restore() invoked
@@ -206,24 +205,21 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
class CheckpointSuite extends TestSuiteBase with LocalStreamingContext with DStreamCheckpointTester
with ResetSystemProperties {

var ssc: StreamingContext = null

override def batchDuration: Duration = Milliseconds(500)

override def beforeFunction() {
super.beforeFunction()
override def beforeEach(): Unit = {
super.beforeEach()
Utils.deleteRecursively(new File(checkpointDir))
}

override def afterFunction() {
override def afterEach(): Unit = {
try {
if (ssc != null) { ssc.stop() }
Utils.deleteRecursively(new File(checkpointDir))
} finally {
super.afterFunction()
super.afterEach()
}
}

@@ -255,7 +251,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
.checkpoint(stateStreamCheckpointInterval)
.map(t => (t._1, t._2))
}
var ssc = setupStreams(input, operation)
ssc = setupStreams(input, operation)
var stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head

def waitForCompletionOfBatch(numBatches: Long): Unit = {
@@ -29,24 +29,14 @@ import org.apache.spark.util.ReturnStatementInClosureException
/**
* Test that closures passed to DStream operations are actually cleaned.
*/
class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
private var ssc: StreamingContext = null
class DStreamClosureSuite extends SparkFunSuite with LocalStreamingContext with BeforeAndAfterAll {
override protected def beforeEach(): Unit = {
super.beforeEach()

override def beforeAll(): Unit = {
super.beforeAll()
val sc = new SparkContext("local", "test")
ssc = new StreamingContext(sc, Seconds(1))
}

override def afterAll(): Unit = {
try {
ssc.stop(stopSparkContext = true)
ssc = null
} finally {
super.afterAll()
}
}

test("user provided closures are actually cleaned") {
val dstream = new DummyInputDStream(ssc)
val pairDstream = dstream.map { i => (i, i) }
@@ -30,28 +30,29 @@ import org.apache.spark.util.ManualClock
/**
* Tests whether scope information is passed from DStream operations to RDDs correctly.
*/
class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
private var ssc: StreamingContext = null
private val batchDuration: Duration = Seconds(1)
class DStreamScopeSuite
extends SparkFunSuite
with LocalStreamingContext {

override def beforeEach(): Unit = {
super.beforeEach()

override def beforeAll(): Unit = {
super.beforeAll()
val conf = new SparkConf().setMaster("local").setAppName("test")
conf.set("spark.streaming.clock", classOf[ManualClock].getName())
val batchDuration: Duration = Seconds(1)
ssc = new StreamingContext(new SparkContext(conf), batchDuration)

assertPropertiesNotSet()
}

override def afterAll(): Unit = {
override def afterEach(): Unit = {
try {
ssc.stop(stopSparkContext = true)
assertPropertiesNotSet()
} finally {
super.afterAll()
super.afterEach()
}
}

before { assertPropertiesNotSet() }
after { assertPropertiesNotSet() }

test("dstream without scope") {
val dummyStream = new DummyDStream(ssc)
dummyStream.initialize(Time(0))