Skip to content

Commit

Permalink
Refactor and optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
changgyoopark-db committed Nov 7, 2024
1 parent 9858ab6 commit 03048c4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.connect.execution

import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
Expand All @@ -41,7 +41,8 @@ import org.apache.spark.util.Utils
private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging {

/** The thread state. */
private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted)
private val state: AtomicReference[ThreadStateInfo] = new AtomicReference(
ThreadState.notStarted)

// The newly created thread will inherit all InheritableThreadLocals used by Spark,
// e.g. SparkContext.localProperties. If considering implementing a thread-pool,
Expand Down Expand Up @@ -349,17 +350,20 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
private object ThreadState {

/** The thread has not started: transition to interrupted or started. */
val notStarted: Int = 0
val notStarted: ThreadStateInfo = ThreadStateInfo(0)

/** Execution was interrupted: terminal state. */
val interrupted: Int = 1
val interrupted: ThreadStateInfo = ThreadStateInfo(1)

/** The thread has started: transition to startedInterrupted or completed. */
val started: Int = 2
val started: ThreadStateInfo = ThreadStateInfo(2)

/** The thread has started and execution was interrupted: transition to completed. */
val startedInterrupted: Int = 3
/** The thread was started and execution has been interrupted: transition to completed. */
val startedInterrupted: ThreadStateInfo = ThreadStateInfo(3)

/** Execution was completed: terminal state. */
val completed: Int = 4
/** Execution has been completed: terminal state. */
val completed: ThreadStateInfo = ThreadStateInfo(4)
}

/** Represents the state of an execution thread. */
case class ThreadStateInfo(val transitionState: Int)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}

import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
Expand Down Expand Up @@ -160,19 +159,14 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
}

private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = {
var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]()
executions.forEach((_, executeHolder) => {
if (executeHolder.sessionHolder.key == key) {
sessionExecutionHolders += executeHolder
val info = executeHolder.getExecuteInfo
logInfo(
log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.")
removeExecuteHolder(executeHolder.key, abandoned = true)
}
})

sessionExecutionHolders.foreach { executeHolder =>
val info = executeHolder.getExecuteInfo
logInfo(
log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.")
removeExecuteHolder(executeHolder.key, abandoned = true)
}
}

/** Get info about abandoned execution, if there is one. */
Expand Down Expand Up @@ -252,30 +246,24 @@ private[connect] class SparkConnectExecutionManager() extends Logging {

// Visible for testing.
private[connect] def periodicMaintenance(timeout: Long): Unit = {
// Find any detached executions that expired and should be removed.
logInfo("Started periodic run of SparkConnectExecutionManager maintenance.")

// Find any detached executions that expired and should be removed.
val toRemove = new mutable.ArrayBuffer[ExecuteHolder]()
val nowMs = System.currentTimeMillis()

executions.forEach((_, executeHolder) => {
executeHolder.lastAttachedRpcTimeMs match {
case Some(detached) =>
if (detached + timeout <= nowMs) {
toRemove += executeHolder
val info = executeHolder.getExecuteInfo
logInfo(
log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " +
log"and expired and will be removed.")
removeExecuteHolder(executeHolder.key, abandoned = true)
}
case _ => // execution is active
}
})

// .. and remove them.
toRemove.foreach { executeHolder =>
val info = executeHolder.getExecuteInfo
logInfo(
log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " +
log"and expired and will be removed.")
removeExecuteHolder(executeHolder.key, abandoned = true)
}
logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
Expand Down Expand Up @@ -226,9 +225,8 @@ class SparkConnectSessionManager extends Logging {
private def periodicMaintenance(
defaultInactiveTimeoutMs: Long,
ignoreCustomTimeout: Boolean): Unit = {
logInfo("Started periodic run of SparkConnectSessionManager maintenance.")
// Find any sessions that expired and should be removed.
val toRemove = new mutable.ArrayBuffer[SessionHolder]()
logInfo("Started periodic run of SparkConnectSessionManager maintenance.")

def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = {
val timeoutMs = if (info.customInactiveTimeoutMs.isDefined && !ignoreCustomTimeout) {
Expand All @@ -242,15 +240,8 @@ class SparkConnectSessionManager extends Logging {

val nowMs = System.currentTimeMillis()
sessionStore.forEach((_, sessionHolder) => {
if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
toRemove += sessionHolder
}
})

// .. and remove them.
toRemove.foreach { sessionHolder =>
val info = sessionHolder.getSessionHolderInfo
if (shouldExpire(info, System.currentTimeMillis())) {
if (shouldExpire(info, nowMs)) {
logInfo(
log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
log"and will be closed.")
Expand All @@ -261,7 +252,8 @@ class SparkConnectSessionManager extends Logging {
case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
}
}
}
})

logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
}

Expand Down

0 comments on commit 03048c4

Please sign in to comment.