Skip to content

Commit

Permalink
Allow users to add custom Toree magic handlers (#1085)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarchambault authored Apr 19, 2023
1 parent 6773913 commit 3310b95
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,63 @@ object ScalaKernelTests extends TestSuite {
.mkString("Vector(" + "\n", "", "...")
)
}

test("toree custom cell magic") {

val predef =
"""almond.toree.CellMagicHook.addHandler("test") { (_, content) =>
| import almond.api.JupyterApi
| import almond.interpreter.api.DisplayData
|
| Left(JupyterApi.ExecuteHookResult.Success(DisplayData.text(content)))
|}
|
|almond.toree.CellMagicHook.addHandler("thing") { (_, content) =>
| import almond.api.JupyterApi
| import almond.interpreter.api.DisplayData
|
| val nl = System.lineSeparator()
| Right(s"val thing = {" + nl + content + nl + "}" + nl)
|}
|""".stripMargin

val interpreter = new ScalaInterpreter(
params = ScalaInterpreterParams(
initialColors = Colors.BlackWhite,
predefCode = predef,
toreeMagics = true
),
logCtx = logCtx
)

val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
.unsafeRunTimedOrThrow()

implicit val sessionId: SessionId = SessionId()

kernel.execute(
"""%%test
|foo
|a
|""".stripMargin,
"""foo
|a
|""".stripMargin
)

val nl = System.lineSeparator()

kernel.execute(
"""%%thing
|println("Hello")
|2
|""".stripMargin,
"thing: Int = 2",
stdout =
"Hello" + nl +
"thing: Int = 2"
)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package almond.toree

import almond.api.JupyterApi

@FunctionalInterface
trait CellMagicHandler {
def handle(name: String, content: String): Either[JupyterApi.ExecuteHookResult, String]
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,18 @@ import almond.interpreter.api.OutputHandler

import java.util.Locale

import scala.collection.mutable

object CellMagicHook {

private var userHandlers = new mutable.HashMap[String, CellMagicHandler]

def addHandler(name: String)(handler: CellMagicHandler): Unit =
userHandlers += name -> handler

def clearHandlers(): Unit =
userHandlers.clear()

def hook(publish: OutputHandler): JupyterApi.ExecuteHook = {
val handlers = CellMagicHandlers.handlers(publish)
code =>
Expand All @@ -16,7 +26,8 @@ object CellMagicHook {
}
nameOpt match {
case Some(name) =>
handlers.get(name.toLowerCase(Locale.ROOT)) match {
val name0 = name.toLowerCase(Locale.ROOT)
userHandlers.get(name0).orElse(handlers.get(name0)) match {
case Some(handler) =>
val content = code.linesWithSeparators.drop(1).mkString
handler.handle(name, content)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package almond.toree

import almond.api.JupyterApi

@FunctionalInterface
trait LineMagicHandler {
def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String]
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ import scala.collection.mutable

object LineMagicHook {

private var userHandlers = new mutable.HashMap[String, LineMagicHandler]

def addHandler(name: String)(handler: LineMagicHandler): Unit =
userHandlers += name -> handler

def clearHandlers(): Unit =
userHandlers.clear()

private val sep = Pattern.compile("\\s+")

def inspect(code: String): Iterator[Either[(Seq[String], String, String), String]] = {
Expand Down Expand Up @@ -47,7 +55,9 @@ object LineMagicHook {

assert(name.startsWith("%"))

handlers.get(name.toLowerCase(Locale.ROOT).stripPrefix("%")) match {
val name0 = name.toLowerCase(Locale.ROOT).stripPrefix("%")

userHandlers.get(name0).orElse(handlers.get(name0)) match {
case None =>
System.err.println(s"Warning: ignoring unrecognized Toree line magic $name")
case Some(handler) =>
Expand Down

0 comments on commit 3310b95

Please sign in to comment.