Skip to content

Commit

Permalink
Try to support some wasm script grammar, for testing (#61)
Browse files Browse the repository at this point in the history
* try support some webassembly script grammar

* a tiny example of wast

* fix test

* add wast test to ci

* cosmetic stuff

* remove TestSyntax.scala

---------

Co-authored-by: Guannan Wei <[email protected]>
Co-authored-by: ahuoguo <[email protected]>
  • Loading branch information
3 people authored Nov 4, 2024
1 parent 28b6ad0 commit b2afc58
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 29 deletions.
1 change: 1 addition & 0 deletions .github/workflows/scala.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,4 @@ jobs:
sbt 'testOnly gensym.TestImpCPSGS_Z3'
sbt 'testOnly gensym.TestLibrary'
sbt 'testOnly gensym.wasm.TestEval'
sbt 'testOnly gensym.wasm.TestScriptRun'
8 changes: 8 additions & 0 deletions benchmarks/wasm/script/script_basic.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
(module
(func $one (result i32)
i32.const 1)
(export "one" (func 0))
)

(assert_return (invoke "one") (i32.const 1))

14 changes: 14 additions & 0 deletions src/main/scala/wasm/AST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,17 @@ case class ExportFunc(i: Int) extends ExportDesc
case class ExportTable(i: Int) extends ExportDesc
case class ExportMemory(i: Int) extends ExportDesc
case class ExportGlobal(i: Int) extends ExportDesc

case class Script(cmds: List[Cmd]) extends WIR
abstract class Cmd extends WIR
// TODO: can we turn abstract class sealed?
case class CmdModule(module: Module) extends Cmd

abstract class Action extends WIR
case class Invoke(instName: Option[String], name: String, args: List[Value]) extends Action

abstract class Assertion extends Cmd
case class AssertReturn(action: Action, expect: List[Num] /* TODO: support multiple expect result type*/)
extends Assertion
case class AssertTrap(action: Action, message: String) extends Assertion

66 changes: 39 additions & 27 deletions src/main/scala/wasm/MiniWasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,44 @@ case class ModuleInstance(
exports: List[Export] = List()
)

object ModuleInstance {
def apply(module: Module): ModuleInstance = {
val types = List()
val funcs = module.definitions
.collect({
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
})
.toList

val globals = module.definitions
.collect({
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non-exhaustive
case _ => ???
}
})
.toList

// TODO: correct the behavior for memory
val memory = module.definitions
.collect({
case Memory(id, MemoryType(min, max_opt)) =>
RTMemory(min, max_opt)
})
.toList

val exports = module.definitions
.collect({
case e @ Export(_, ExportFunc(_)) => e
})
.toList

ModuleInstance(types, module.funcEnv, memory, globals, exports)
}
}

object Primtives {
def evalBinOp(op: BinOp, lhs: Value, rhs: Value): Value = op match {
case Add(_) =>
Expand Down Expand Up @@ -412,33 +450,7 @@ object Evaluator {

if (instrs.isEmpty) println("Warning: nothing is executed")

val types = List()
val funcs = module.definitions
.collect({
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
})
.toList

val globals = module.definitions
.collect({
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non-exhaustive
case _ => ???
}
})
.toList

// TODO: correct the behavior for memory
val memory = module.definitions
.collect({
case Memory(id, MemoryType(min, max_opt)) =>
RTMemory(min, max_opt)
})
.toList

val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals)
val moduleInst = ModuleInstance(module)

Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt))
}
Expand Down
49 changes: 49 additions & 0 deletions src/main/scala/wasm/MiniWasmScript.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package gensym.wasm.miniwasmscript

import gensym.wasm.miniwasm._
import gensym.wasm.ast._
import scala.collection.mutable.{ListBuffer, Map, ArrayBuffer}

sealed class ScriptRunner {
val instances: ListBuffer[ModuleInstance] = ListBuffer()
val instanceMap: Map[String, ModuleInstance] = Map()

def getInstance(instName: Option[String]): ModuleInstance = {
instName match {
case Some(name) => instanceMap(name)
case None => instances.head
}
}

def assertReturn(action: Action, expect: List[Value]): Unit = {
action match {
case Invoke(instName, name, args) =>
val module = getInstance(instName)
val func = module.exports.collectFirst({
case Export(`name`, ExportFunc(index)) =>
module.funcs(index)
case _ => throw new RuntimeException("Not Supported")
}).get
val instrs = func match {
case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => body
}
val k = (retStack: List[Value]) => retStack
val actual = Evaluator.eval(instrs, List(), Frame(module, ArrayBuffer(args: _*)), k, List(k))
assert(actual == expect)
}
}

def runCmd(cmd: Cmd): Unit = {
cmd match {
case CmdModule(module) => instances += ModuleInstance(module)
case AssertReturn(action, expect) => assertReturn(action, expect)
case AssertTrap(action, message) => ???
}
}

def run(script: Script): Unit = {
for (cmd <- script.cmds) {
runCmd(cmd)
}
}
}
76 changes: 74 additions & 2 deletions src/main/scala/wasm/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,59 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] {
else if (ctx.MEMORY != null) ExportMemory(id)
else if (ctx.GLOBAL != null) ExportGlobal(id)
else error
}

override def visitScriptModule(ctx: ScriptModuleContext): Module = {
if (ctx.module_ != null) {
visitModule_(ctx.module_).asInstanceOf[Module]
} else {
throw new RuntimeException("Unsupported")
}
}

override def visitAction_(ctx: Action_Context): Action = {
if (ctx.INVOKE != null) {
val instName = if (ctx.VAR != null) Some(ctx.VAR().getText) else None
var name = ctx.name.getText.substring(1).dropRight(1)
var args = for (constCtx <- ctx.constList.wconst.asScala) yield {
val Array(ty, _) = constCtx.CONST.getText.split("\\.")
visitLiteralWithType(constCtx.literal, toNumType(ty))
}
Invoke(instName, name, args.toList)
} else {
throw new RuntimeException("Unsupported")
}
}

override def visitAssertion(ctx: AssertionContext): Assertion = {
if (ctx.ASSERT_RETURN != null) {
val action = visitAction_(ctx.action_)
val expect = for (constCtx <- ctx.constList.wconst.asScala) yield {
val Array(ty, _) = constCtx.CONST.getText.split("\\.")
visitLiteralWithType(constCtx.literal, toNumType(ty))
}
println(s"expect = $expect")
AssertReturn(action, expect.toList)
} else {
throw new RuntimeException("Unsupported")
}
}

override def visitCmd(ctx: CmdContext): Cmd = {
if (ctx.assertion != null) {
visitAssertion(ctx.assertion)
} else if (ctx.scriptModule != null) {
CmdModule(visitScriptModule(ctx.scriptModule))
} else {
throw new RuntimeException("Unsupported")
}
}

override def visitScript(ctx: ScriptContext): WIR = {
val cmds = for (cmd <- ctx.cmd.asScala) yield {
visitCmd(cmd)
}
Script(cmds.toList)
}

override def visitTag(ctx: TagContext): WIR = {
Expand All @@ -645,15 +697,35 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] {
}

object Parser {
def parse(input: String): Module = {
private def makeWatVisitor(input: String) = {
val charStream = new ANTLRInputStream(input)
val lexer = new WatLexer(charStream)
val tokens = new CommonTokenStream(lexer)
val parser = new WatParser(tokens)
new WatParser(tokens)
}

def parse(input: String): Module = {
val parser = makeWatVisitor(input)
val visitor = new GSWasmVisitor()
val res: Module = visitor.visit(parser.module).asInstanceOf[Module]
res
}

def parseFile(filepath: String): Module = parse(scala.io.Source.fromFile(filepath).mkString)

// parse extended webassembly script language
def parseScript(input: String): Option[Script] = {
val parser = makeWatVisitor(input)
val visitor = new GSWasmVisitor()
val tree = parser.script()
val errorNumer = parser.getNumberOfSyntaxErrors()
if (errorNumer != 0) None
else {
val res: Script = visitor.visitScript(tree).asInstanceOf[Script]
Some(res)
}
}

def parseScriptFile(filepath: String): Option[Script] =
parseScript(scala.io.Source.fromFile(filepath).mkString)
}
19 changes: 19 additions & 0 deletions src/test/scala/genwasym/TestScriptRun.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package gensym.wasm

import gensym.wasm.parser.Parser
import gensym.wasm.miniwasmscript.ScriptRunner

import org.scalatest.FunSuite


class TestScriptRun extends FunSuite {
def testFile(filename: String): Unit = {
val script = Parser.parseScriptFile(filename).get
val runner = new ScriptRunner()
runner.run(script)
}

test("simple script") {
testFile("./benchmarks/wasm/script/script_basic.wast")
}
}

0 comments on commit b2afc58

Please sign in to comment.