diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 61be2bc4eb994..d27f390a23f95 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connect.execution -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicReference import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -41,7 +41,8 @@ import org.apache.spark.util.Utils private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { /** The thread state. */ - private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted) + private val state: AtomicReference[ThreadStateInfo] = new AtomicReference( + ThreadState.notStarted) // The newly created thread will inherit all InheritableThreadLocals used by Spark, // e.g. SparkContext.localProperties. If considering implementing a thread-pool, @@ -349,17 +350,20 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends private object ThreadState { /** The thread has not started: transition to interrupted or started. */ - val notStarted: Int = 0 + val notStarted: ThreadStateInfo = ThreadStateInfo(0) /** Execution was interrupted: terminal state. */ - val interrupted: Int = 1 + val interrupted: ThreadStateInfo = ThreadStateInfo(1) /** The thread has started: transition to startedInterrupted or completed. */ - val started: Int = 2 + val started: ThreadStateInfo = ThreadStateInfo(2) - /** The thread has started and execution was interrupted: transition to completed. */ - val startedInterrupted: Int = 3 + /** The thread was started and execution has been interrupted: transition to completed. */ + val startedInterrupted: ThreadStateInfo = ThreadStateInfo(3) - /** Execution was completed: terminal state. */ - val completed: Int = 4 + /** Execution has been completed: terminal state. */ + val completed: ThreadStateInfo = ThreadStateInfo(4) } + +/** Represents the state of an execution thread. */ +case class ThreadStateInfo(val transitionState: Int) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index d9eb5438c3886..f750ca6db67a8 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -21,7 +21,6 @@ import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.{AtomicLong, AtomicReference} -import scala.collection.mutable import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -160,19 +159,14 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = { - var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]() executions.forEach((_, executeHolder) => { if (executeHolder.sessionHolder.key == key) { - sessionExecutionHolders += executeHolder + val info = executeHolder.getExecuteInfo + logInfo( + log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.") + removeExecuteHolder(executeHolder.key, abandoned = true) } }) - - sessionExecutionHolders.foreach { executeHolder => - val info = executeHolder.getExecuteInfo - logInfo( - log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.") - removeExecuteHolder(executeHolder.key, abandoned = true) - } } /** Get info about abandoned execution, if there is one. */ @@ -252,30 +246,24 @@ private[connect] class SparkConnectExecutionManager() extends Logging { // Visible for testing. private[connect] def periodicMaintenance(timeout: Long): Unit = { + // Find any detached executions that expired and should be removed. logInfo("Started periodic run of SparkConnectExecutionManager maintenance.") - // Find any detached executions that expired and should be removed. - val toRemove = new mutable.ArrayBuffer[ExecuteHolder]() val nowMs = System.currentTimeMillis() - executions.forEach((_, executeHolder) => { executeHolder.lastAttachedRpcTimeMs match { case Some(detached) => if (detached + timeout <= nowMs) { - toRemove += executeHolder + val info = executeHolder.getExecuteInfo + logInfo( + log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " + + log"and expired and will be removed.") + removeExecuteHolder(executeHolder.key, abandoned = true) } case _ => // execution is active } }) - // .. and remove them. - toRemove.foreach { executeHolder => - val info = executeHolder.getExecuteInfo - logInfo( - log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " + - log"and expired and will be removed.") - removeExecuteHolder(executeHolder.key, abandoned = true) - } logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.") } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala index 4ca3a80bfb985..a306856efa33c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -21,7 +21,6 @@ import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.AtomicReference -import scala.collection.mutable import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -226,9 +225,8 @@ class SparkConnectSessionManager extends Logging { private def periodicMaintenance( defaultInactiveTimeoutMs: Long, ignoreCustomTimeout: Boolean): Unit = { - logInfo("Started periodic run of SparkConnectSessionManager maintenance.") // Find any sessions that expired and should be removed. - val toRemove = new mutable.ArrayBuffer[SessionHolder]() + logInfo("Started periodic run of SparkConnectSessionManager maintenance.") def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = { val timeoutMs = if (info.customInactiveTimeoutMs.isDefined && !ignoreCustomTimeout) { @@ -242,15 +240,8 @@ class SparkConnectSessionManager extends Logging { val nowMs = System.currentTimeMillis() sessionStore.forEach((_, sessionHolder) => { - if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) { - toRemove += sessionHolder - } - }) - - // .. and remove them. - toRemove.foreach { sessionHolder => val info = sessionHolder.getSessionHolderInfo - if (shouldExpire(info, System.currentTimeMillis())) { + if (shouldExpire(info, nowMs)) { logInfo( log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " + log"and will be closed.") @@ -261,7 +252,8 @@ class SparkConnectSessionManager extends Logging { case NonFatal(ex) => logWarning("Unexpected exception closing session", ex) } } - } + }) + logInfo("Finished periodic run of SparkConnectSessionManager maintenance.") }