Skip to content

Commit

Permalink
Add Fixes for Post-Interrupt Hook Feature (#1234)
Browse files Browse the repository at this point in the history
* Add Fixes for Post-Interrupt Hook Feature

- This is a rework of #1186
  to make the post-interrupt hook feature work as intended

* Renaming afterInterruptHook -> postInterruptHook

* Remove setInstance / getInstanceOpt stuff

By passing the JupyterApi to Execute methods when needed

* Log from JupyterApiImpl

* Clean-up log messages

* NIT Format

* fixup

---------

Co-authored-by: Peter Christensen <[email protected]>
Co-authored-by: Alexandre Archambault <[email protected]>
  • Loading branch information
3 people authored Aug 21, 2023
1 parent f210b50 commit ba66f15
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 23 deletions.
21 changes: 18 additions & 3 deletions docs/pages/api-jupyter.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,25 @@ lazy val spark = {
lazy val sc = {
spark.sparkContext
}
kernel.afterInterruptHooks += { _ =>
sc.cancelAllJobs()
}

// Add new hook with name "CancelAllSparkJobs"
kernel.addPostInterruptHook(
"CancelAllSparkJobs",
_ => sc.cancelAllJobs()
)

// Return a list with all registered hooks
kernel.postInterruptHooks

// Remove hook by name
kernel.removePostInterruptHook("CancelAllSparkJobs")

// Run after-interrupt hooks (called internally after a cell interrupt)
kernel.runPostInterruptHooks()
```
Since Scala anonymous functions don't print well after being compiled to bytecode
each hook is registered with a name.


### Hooks

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ abstract class JupyterApi { api =>
def addExecuteHook(hook: JupyterApi.ExecuteHook): Boolean
def removeExecuteHook(hook: JupyterApi.ExecuteHook): Boolean

def addPostInterruptHook(name: String, hook: Any => Any): Boolean
def removePostInterruptHook(name: String): Boolean
def postInterruptHooks(): Seq[(String, Any => Any)]
def runPostInterruptHooks(): Unit

def consoleOut: PrintStream
def consoleErr: PrintStream
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ final class Execute(
case Some(p) => capture0(p.stdout, p.stderr)(t)
}

private def interruptible[T](t: => T): T = {
private def interruptible[T](jupyterApi: JupyterApi)(t: => T): T = {
interruptedStackTraceOpt0 = None
currentThreadOpt0 = Some(Thread.currentThread())
try
Expand All @@ -266,6 +266,9 @@ final class Execute(
log.debug(s"Calling 'Thread.stop'")
t.stop()
}

// Run post-interrupt hooks
jupyterApi.runPostInterruptHooks()
}
}.apply {
t
Expand All @@ -274,7 +277,7 @@ final class Execute(
currentThreadOpt0 = None
}

def interrupt(): Unit =
def interrupt(jupyterApi: JupyterApi): Unit =
currentThreadOpt0 match {
case None =>
log.warn("Interrupt asked, but no execution is running")
Expand All @@ -290,6 +293,9 @@ final class Execute(
log.debug(s"Calling 'Thread.stop'")
t.stop()
}

// Run post-interrupt hooks
jupyterApi.runPostInterruptHooks()
}

private var lastExceptionOpt0 = Option.empty[Throwable]
Expand All @@ -314,7 +320,8 @@ final class Execute(
code: String,
inputManager: Option[InputManager],
outputHandler: Option[OutputHandler],
storeHistory: Boolean
storeHistory: Boolean,
jupyterApi: JupyterApi
) =
withOutputHandler(outputHandler) {
val code0 = {
Expand All @@ -337,7 +344,7 @@ final class Execute(
Res.Failure(err)
}
_ = log.debug(s"splitted '$code0'")
ev <- interruptible {
ev <- interruptible(jupyterApi) {
withInputManager(inputManager) {
withClientStdin {
capturingOutput {
Expand Down Expand Up @@ -415,7 +422,8 @@ final class Execute(
outputHandler: Option[OutputHandler],
colors0: Ref[Colors],
storeHistory: Boolean,
executeHooks: Seq[JupyterApi.ExecuteHook]
executeHooks: Seq[JupyterApi.ExecuteHook],
jupyterApi: JupyterApi
): ExecuteResult = {

if (enableExitHack && code.endsWith("// ALMOND FORCE EXIT")) {
Expand All @@ -430,7 +438,7 @@ final class Execute(

val finalCodeOrResult =
withOutputHandler(outputHandler) {
interruptible {
interruptible(jupyterApi) {
withInputManager(inputManager, done = false) {
withClientStdin {
capturingOutput {
Expand Down Expand Up @@ -522,7 +530,8 @@ final class Execute(
finalCode,
inputManager,
outputHandler,
storeHistory
storeHistory,
jupyterApi
) match {
case Res.Success((_, data)) =>
ExecuteResult.Success(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ import java.nio.charset.StandardCharsets
import almond.api.{FullJupyterApi, JupyterApi}
import almond.internals.HtmlAnsiOutputStream
import almond.interpreter.api.CommHandler
import almond.logger.LoggerContext
import ammonite.util.Ref
import pprint.{TPrint, TPrintColors}

import scala.collection.mutable
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal

/** Actual [[almond.api.JupyterApi]] instance */
final class JupyterApiImpl(
Expand All @@ -23,9 +25,12 @@ final class JupyterApiImpl(
protected val allowVariableInspector: Option[Boolean],
val kernelClassLoader: ClassLoader,
val consoleOut: PrintStream,
val consoleErr: PrintStream
val consoleErr: PrintStream,
logCtx: LoggerContext
) extends FullJupyterApi with VariableInspectorApiImpl {

private val log = logCtx(getClass)

protected def variableInspectorImplPPrinter() = replApi.pprinter()

protected def printOnChange[T](
Expand Down Expand Up @@ -90,5 +95,26 @@ final class JupyterApiImpl(
}
}

val afterInterruptHooks = mutable.Buffer.empty[Any => Any]
private val postInterruptHooks0 = new mutable.ListBuffer[(String, Any => Any)]
def addPostInterruptHook(name: String, hook: Any => Any): Boolean = {
!postInterruptHooks0.map(_._1).contains((name)) && {
postInterruptHooks0.append((name, hook))
true
}
}
def removePostInterruptHook(name: String): Boolean = {
val idx = postInterruptHooks0.map(_._1).indexOf(name)
idx >= 0 && {
postInterruptHooks0.remove(idx)
true
}
}
def postInterruptHooks(): Seq[(String, Any => Any)] = postInterruptHooks0.toList
def runPostInterruptHooks(): Unit =
try Function.chain(postInterruptHooks0.map(_._2)).apply(())
catch {
case NonFatal(e) =>
log.warn("Caught exception while running post-interrupt hooks", e)
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package almond

import almond.amm.AmmInterpreter
import almond.api.JupyterApi
import almond.internals._
import almond.interpreter._
import almond.interpreter.api.{CommHandler, OutputHandler}
Expand Down Expand Up @@ -100,7 +101,8 @@ final class ScalaInterpreter(
params.allowVariableInspector,
kernelClassLoader = Thread.currentThread().getContextClassLoader,
consoleOut = System.out,
consoleErr = System.err
consoleErr = System.err,
logCtx = logCtx
)

if (params.toreeMagics) {
Expand Down Expand Up @@ -150,15 +152,8 @@ final class ScalaInterpreter(

override def interruptSupported: Boolean =
true
override def interrupt(): Unit = {
execute0.interrupt()

try Function.chain(jupyterApi.afterInterruptHooks).apply(())
catch {
case NonFatal(e) =>
log.warn("Caught exception while trying to run after Interrupt hooks", e)
}
}
override def interrupt(): Unit =
execute0.interrupt(jupyterApi)

override def supportComm: Boolean = true
override def setCommHandler(commHandler0: CommHandler): Unit =
Expand All @@ -177,7 +172,8 @@ final class ScalaInterpreter(
outputHandler,
colors0,
storeHistory,
jupyterApi.executeHooks
jupyterApi.executeHooks,
jupyterApi
)

def currentLine(): Int =
Expand Down

0 comments on commit ba66f15

Please sign in to comment.