From 4d5965ecb48685faed63a751100433a273695e5b Mon Sep 17 00:00:00 2001
From: Nikita Gorbachevsky
Date: Tue, 13 Aug 2019 14:36:23 +0300
Subject: [PATCH] [SPARK-28709][DSTREAMS] - Fix StreamingContext leak through
 StreamingJobProgressListener on stop

---
 .../scala/org/apache/spark/ui/SparkUI.scala   |  3 ++
 .../scala/org/apache/spark/ui/WebUI.scala     |  1 +
 .../spark/streaming/StreamingContext.scala    | 29 +++++++++++++++----
 .../spark/streaming/ui/StreamingTab.scala     | 28 +++++-------------
 .../spark/streaming/InputStreamsSuite.scala   |  4 ---
 .../streaming/StreamingContextSuite.scala     | 24 +++++++++++++++
 .../spark/streaming/UISeleniumSuite.scala     |  4 +++
 7 files changed, 63 insertions(+), 30 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 1175bc25de454..6fb8e458a789c 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -138,6 +138,9 @@ private[spark] class SparkUI private (
     streamingJobProgressListener = Option(sparkListener)
   }
 
+  def clearStreamingJobProgressListener(): Unit = {
+    streamingJobProgressListener = None
+  }
 }
 
 private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String)
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 54ae258ba565f..1fe822a0e3b57 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -93,6 +93,7 @@ private[spark] abstract class WebUI(
     attachHandler(renderJsonHandler)
     val handlers = pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]())
     handlers += renderHandler
+    handlers += renderJsonHandler
   }
 
   /** Attaches a handler to this UI. */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 15ebef2b325c1..48913eaa4a8bf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -38,7 +38,6 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.input.FixedLengthBinaryInputFormat
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.UI._
 import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.SerializationDebugger
@@ -189,10 +188,9 @@ class StreamingContext private[streaming] (
   private[streaming] val progressListener = new StreamingJobProgressListener(this)
 
   private[streaming] val uiTab: Option[StreamingTab] =
-    if (conf.get(UI_ENABLED)) {
-      Some(new StreamingTab(this))
-    } else {
-      None
+    sparkContext.ui match {
+      case Some(ui) => Some(new StreamingTab(this, ui))
+      case None => None
     }
 
   /* Initializing a streamingSource to register metrics */
@@ -511,6 +509,10 @@ class StreamingContext private[streaming] (
     scheduler.listenerBus.addListener(streamingListener)
   }
 
+  def removeStreamingListener(streamingListener: StreamingListener): Unit = {
+    scheduler.listenerBus.removeListener(streamingListener)
+  }
+
   private def validate() {
     assert(graph != null, "Graph is null")
     graph.validate()
@@ -575,6 +577,8 @@ class StreamingContext private[streaming] (
           try {
             validate()
 
+            registerProgressListener()
+
             // Start the streaming scheduler in a new thread, so that thread local properties
             // like call sites and job groups can be reset without affecting those of the
             // current thread.
@@ -690,6 +694,9 @@ class StreamingContext private[streaming] (
           Utils.tryLogNonFatalError {
             uiTab.foreach(_.detach())
           }
+          Utils.tryLogNonFatalError {
+            unregisterProgressListener()
+          }
           StreamingContext.setActiveContext(null)
           Utils.tryLogNonFatalError {
             waiter.notifyStop()
           }
@@ -716,6 +723,18 @@ class StreamingContext private[streaming] (
     // Do not stop SparkContext, let its own shutdown hook stop it
     stop(stopSparkContext = false, stopGracefully = stopGracefully)
   }
+
+  private def registerProgressListener(): Unit = {
+    addStreamingListener(progressListener)
+    sc.addSparkListener(progressListener)
+    sc.ui.foreach(_.setStreamingJobProgressListener(progressListener))
+  }
+
+  private def unregisterProgressListener(): Unit = {
+    removeStreamingListener(progressListener)
+    sc.removeSparkListener(progressListener)
+    sc.ui.foreach(_.clearStreamingJobProgressListener())
+  }
 }
 
 /**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala
index 25e71258b9369..13357db728701 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.streaming.ui
 
-import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.ui.{SparkUI, SparkUITab}
@@ -26,37 +25,24 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
  * Spark Web UI tab that shows statistics of a streaming job.
  * This assumes the given SparkContext has enabled its SparkUI.
  */
-private[spark] class StreamingTab(val ssc: StreamingContext)
-  extends SparkUITab(StreamingTab.getSparkUI(ssc), "streaming") with Logging {
-
-  import StreamingTab._
+private[spark] class StreamingTab(val ssc: StreamingContext, sparkUI: SparkUI)
+  extends SparkUITab(sparkUI, "streaming") with Logging {
 
   private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static"
 
-  val parent = getSparkUI(ssc)
+  val parent = sparkUI
   val listener = ssc.progressListener
 
-  ssc.addStreamingListener(listener)
-  ssc.sc.addSparkListener(listener)
-  parent.setStreamingJobProgressListener(listener)
   attachPage(new StreamingPage(this))
   attachPage(new BatchPage(this))
 
   def attach() {
-    getSparkUI(ssc).attachTab(this)
-    getSparkUI(ssc).addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming")
+    parent.attachTab(this)
+    parent.addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming")
   }
 
   def detach() {
-    getSparkUI(ssc).detachTab(this)
-    getSparkUI(ssc).detachHandler("/static/streaming")
-  }
-}
-
-private object StreamingTab {
-  def getSparkUI(ssc: StreamingContext): SparkUI = {
-    ssc.sc.ui.getOrElse {
-      throw new SparkException("Parent SparkUI to attach this tab to not found!")
-    }
+    parent.detachTab(this)
+    parent.detachHandler("/static/streaming")
   }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 035ed4aa51bb7..0792770442055 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -52,8 +52,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
     // Set up the streaming context and input streams
     withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
-      ssc.addStreamingListener(ssc.progressListener)
-
       val input = Seq(1, 2, 3, 4, 5)
       // Use "batchCount" to make sure we check the result after all batches finish
       val batchCounter = new BatchCounter(ssc)
@@ -106,8 +104,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     testServer.start()
 
     withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
-      ssc.addStreamingListener(ssc.progressListener)
-
       val batchCounter = new BatchCounter(ssc)
       val networkStream = ssc.socketTextStream(
         "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 5cda6f9925455..56c7cbf0e7bb8 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -34,6 +34,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
 import org.apache.spark.storage.StorageLevel
@@ -392,6 +393,29 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL
     assert(!sourcesAfterStop.contains(streamingSourceAfterStop))
   }
 
+  test("SPARK-28709 registering and de-registering of progressListener") {
+    val conf = new SparkConf().setMaster(master).setAppName(appName)
+    conf.set(UI_ENABLED, true)
+
+    ssc = new StreamingContext(conf, batchDuration)
+
+    assert(ssc.sc.ui.isDefined, "Spark UI is not started!")
not started!") + val sparkUI = ssc.sc.ui.get + + addInputStream(ssc).register() + ssc.start() + + assert(ssc.scheduler.listenerBus.listeners.contains(ssc.progressListener)) + assert(ssc.sc.listenerBus.listeners.contains(ssc.progressListener)) + assert(sparkUI.getStreamingJobProgressListener.get == ssc.progressListener) + + ssc.stop() + + assert(!ssc.scheduler.listenerBus.listeners.contains(ssc.progressListener)) + assert(!ssc.sc.listenerBus.listeners.contains(ssc.progressListener)) + assert(sparkUI.getStreamingJobProgressListener.isEmpty) + } + test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 483a7519873e9..1d34221fde4f4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -97,6 +97,8 @@ class UISeleniumSuite val sparkUI = ssc.sparkContext.ui.get + sparkUI.getHandlers.count(_.getContextPath.contains("/streaming")) should be (5) + eventually(timeout(10.seconds), interval(50.milliseconds)) { go to (sparkUI.webUrl.stripSuffix("/")) find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None) @@ -196,6 +198,8 @@ class UISeleniumSuite ssc.stop(false) + sparkUI.getHandlers.count(_.getContextPath.contains("/streaming")) should be (0) + eventually(timeout(10.seconds), interval(50.milliseconds)) { go to (sparkUI.webUrl.stripSuffix("/")) find(cssSelector( """ul li a[href*="streaming"]""")) should be(None)