Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for rich display of Python objects in ScalaPy (rebased) #854

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,15 @@ class AlmondSpark(val crossScalaVersion: String) extends AlmondModule with Mima
// sources.in(Compile, doc) := Nil
}

class AlmondScalaPy(val crossScalaVersion: String) extends AlmondModule with Mima {
def ivyDeps = Agg(
Deps.jvmRepr
)
def compileIvyDeps = Agg(
Deps.scalapy
)
}

class AlmondRx(val crossScalaVersion: String) extends AlmondModule with Mima {
def compileModuleDeps = Seq(
scala.`scala-kernel-api`()
Expand Down Expand Up @@ -290,6 +299,7 @@ object scala extends Module {
object `scala-interpreter` extends Cross[ScalaInterpreter](ScalaVersions.all: _*)
object `scala-kernel` extends Cross[ScalaKernel] (ScalaVersions.all: _*)
object `scala-kernel-helper` extends Cross[ScalaKernelHelper](ScalaVersions.all.filter(_.startsWith("3.")): _*)
object `almond-scalapy` extends Cross[AlmondScalaPy] (ScalaVersions.binaries: _*)
object `almond-spark` extends Cross[AlmondSpark] (ScalaVersions.scala212)
object `almond-rx` extends Cross[AlmondRx] (ScalaVersions.scala212)
}
Expand Down Expand Up @@ -480,12 +490,21 @@ def validateExamples(matcher: String = "") = {
Some(m)
}

val sv0 = {
val prefix = sv.split('.').take(2).map(_ + ".").mkString
ScalaVersions.binaries.find(_.startsWith(prefix)).getOrElse {
sys.error(s"Can't find a Scala version in ${ScalaVersions.binaries} with the same binary version as $sv (prefix: $prefix)")
}
}

T.command {
val launcher = scala.`scala-kernel`(sv).launcher().path
val jupyterPath = T.dest / "jupyter"
val outputDir = T.dest / "output"
os.makeDir.all(outputDir)

scala.`almond-scalapy`(sv0).publishLocalNoFluff((baseRepoRoot / "{VERSION}").toString)()

val version = scala.`scala-kernel`(sv).publishVersion()
val repoRoot = baseRepoRoot / version

Expand Down
248 changes: 248 additions & 0 deletions examples/scalapy-displays.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {

},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \n",
"\u001b[39m\n",
"\u001b[32mimport \u001b[39m\u001b[36mai.kien.python.Python\n",
"\n",
"\u001b[39m"
]
},
"execution_count": 1,
"metadata": {

},
"output_type": "execute_result"
}
],
"source": [
"import $ivy.`ai.kien::python-native-libs:0.2.3`\n",
"import ai.kien.python.Python\n",
"\n",
"Python().scalapyProperties.fold(\n",
" ex => throw new Exception(ex),\n",
" props => props.map { kv => println(kv); kv }.foreach(Function.tupled(System.setProperty _))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {

},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \n",
"\u001b[39m\n",
"\u001b[32mimport \u001b[39m\u001b[36mme.shadaj.scalapy.py\n",
"\u001b[39m\n",
"\u001b[32mimport \u001b[39m\u001b[36mme.shadaj.scalapy.py.PyQuote\n",
"\u001b[39m\n",
"\u001b[32mimport \u001b[39m\u001b[36mme.shadaj.scalapy.py.SeqConverters\u001b[39m"
]
},
"execution_count": 2,
"metadata": {

},
"output_type": "execute_result"
}
],
"source": [
"import $ivy.`me.shadaj::scalapy-core:0.5.2`\n",
"import me.shadaj.scalapy.py\n",
"import me.shadaj.scalapy.py.PyQuote\n",
"import me.shadaj.scalapy.py.SeqConverters"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {

},
"outputs": [

],
"source": [
"almond.scalapy.initDisplay"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {

},
"outputs": [

],
"source": [
"// disable pprint so that the next line won't show any output\n",
"repl.pprinter() = repl.pprinter().copy(defaultHeight = 0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {

},
"outputs": [
{
"data": {
"text/plain": [
"......"
]
},
"execution_count": 5,
"metadata": {

},
"output_type": "execute_result"
}
],
"source": [
"val display = py.module(\"IPython.display\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {

},
"outputs": [
{
"data": {
"text/html": [
"<b>hello</b>"
]
},
"metadata": {

},
"output_type": "display_data"
}
],
"source": [
"display.HTML(\"<b>hello</b>\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {

},
"outputs": [
{
"data": {
"text/latex": [
"\\begin{eqnarray}\n",
"\\nabla \\times \\vec{\\mathbf{B}} -\\, \\frac1c\\, \\frac{\\partial\\vec{\\mathbf{E}}}{\\partial t} & = \\frac{4\\pi}{c}\\vec{\\mathbf{j}} \\\\\n",
"\\nabla \\cdot \\vec{\\mathbf{E}} & = 4 \\pi \\rho \\\\\n",
"\\nabla \\times \\vec{\\mathbf{E}}\\, +\\, \\frac1c\\, \\frac{\\partial\\vec{\\mathbf{B}}}{\\partial t} & = \\vec{\\mathbf{0}} \\\\\n",
"\\nabla \\cdot \\vec{\\mathbf{B}} & = 0 \n",
"\\end{eqnarray}"
]
},
"metadata": {

},
"output_type": "display_data"
}
],
"source": [
"display.Latex(\"\"\"\\begin{eqnarray}\n",
"\\nabla \\times \\vec{\\mathbf{B}} -\\, \\frac1c\\, \\frac{\\partial\\vec{\\mathbf{E}}}{\\partial t} & = \\frac{4\\pi}{c}\\vec{\\mathbf{j}} \\\\\n",
"\\nabla \\cdot \\vec{\\mathbf{E}} & = 4 \\pi \\rho \\\\\n",
"\\nabla \\times \\vec{\\mathbf{E}}\\, +\\, \\frac1c\\, \\frac{\\partial\\vec{\\mathbf{B}}}{\\partial t} & = \\vec{\\mathbf{0}} \\\\\n",
"\\nabla \\cdot \\vec{\\mathbf{B}} & = 0 \n",
"\\end{eqnarray}\"\"\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {

},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle a = b + c$"
]
},
"metadata": {

},
"output_type": "display_data"
}
],
"source": [
"display.Math(\"a = b + c\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {

},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"# title\n",
"## subsec\n",
"foo\n"
]
},
"metadata": {

},
"output_type": "display_data"
}
],
"source": [
"display.Markdown(\"\"\"\n",
"# title\n",
"## subsec\n",
"foo\n",
"\"\"\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Scala (sources)",
"language": "scala",
"name": "scala-debug"
},
"language_info": {
"codemirror_mode": "text/x-scala",
"file_extension": ".sc",
"mimetype": "text/x-scala",
"name": "scala",
"nbconvert_exporter": "script",
"version": "2.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import json as __almond_scalapy_json


def __almond_scalapy_format_display_data(obj, include):
repr_methods = ((t, m) for m, t in include if m in set(dir(obj)))
representations = ((t, getattr(obj, m)()) for t, m in repr_methods)

display_data = (
(t, (r[0], r[1]) if isinstance(r, tuple) and len(r) == 2 else (r, None))
for t, r in representations if r is not None
)
display_data = [(t, m, md) for t, (m, md) in display_data if m is not None]

data = [
(t, d if isinstance(d, str) else __almond_scalapy_json.dumps(d))
for t, d, _ in display_data
]
metadata = [
(t, md if isinstance(md, str) else __almond_scalapy_json.dumps(md))
for t, _, md in display_data if md is not None
]

return data, metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package almond

import java.{util => ju}
import jupyter.{Displayer, Displayers}
import me.shadaj.scalapy.interpreter.CPythonInterpreter
import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.{PyQuote, SeqConverters}
import scala.io.Source
import scala.jdk.CollectionConverters._

package object scalapy {
CPythonInterpreter.execManyLines(Source.fromResource("format_display_data.py").mkString)

def initDisplay: Unit = {
Displayers.register(
classOf[py.Any],
new Displayer[py.Any] {
def display(obj: py.Any): ju.Map[String, String] = {
val (data, _) = formatDisplayData(obj)
if (data.isEmpty) null else data.asJava
}
}
)
}

private val pyFormatDisplayData = py.Dynamic.global.__almond_scalapy_format_display_data

private def formatDisplayData(obj: py.Any): (Map[String, String], Map[String, String]) = {
val displayData = pyFormatDisplayData(obj, allReprMethods.toPythonCopy)
val data = displayData.bracketAccess(0).as[List[(String, String)]].toMap
val metadata = displayData.bracketAccess(1).as[List[(String, String)]].toMap

(data, metadata)
}

private val mimetypes = Map(
"svg" -> "image/svg+xml",
"png" -> "image/png",
"jpeg" -> "image/jpeg",
"html" -> "text/html",
"javascript" -> "application/javascript",
"markdown" -> "text/markdown",
"latex" -> "text/latex"
)

private lazy val allReprMethods: Seq[(String, String)] =
mimetypes.map { case (k, v) => s"_repr_${k}_" -> v }.toSeq
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ final class ReplApiImpl(
.asInstanceOf[Displayer[T]]
.display(value)
.asScala
.toMap
p.display(DisplayData(m))
Some(Iterator())
if (m == null) None
else {
p.display(DisplayData(m.toMap))
Some(Iterator())
}
} else
for (updatableResults <- updatableResultsOpt if (onChange.nonEmpty && custom.isEmpty) || (onChangeOrError.nonEmpty && custom.nonEmpty)) yield {

Expand Down Expand Up @@ -204,4 +206,3 @@ final class ReplApiImpl(
object ReplApiImpl {
private class Foo
}

Loading