Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #10 from sjrd/enable-hijacked-classes
Browse files Browse the repository at this point in the history
Support transformation of hijacked classes.
  • Loading branch information
tanishiking authored Mar 8, 2024
2 parents 4c6609b + 45a43a9 commit b3701b9
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 48 deletions.
4 changes: 3 additions & 1 deletion cli/src/main/scala/TestSuites.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ object TestSuites {
case class TestSuite(className: String, methodName: String)
val suites = List(
TestSuite("testsuite.core.simple.Simple", "simple"),
TestSuite("testsuite.core.add.Add", "add")
TestSuite("testsuite.core.add.Add", "add"),
TestSuite("testsuite.core.asinstanceof.AsInstanceOfTest", "asInstanceOf"),
TestSuite("testsuite.core.hijackedclassesmono.HijackedClassesMonoTest", "hijackedClassesMono")
)
}
36 changes: 36 additions & 0 deletions test-suite/src/main/scala/testsuite/core/AsInstanceOfTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package testsuite.core.asinstanceof

import scala.scalajs.js.annotation._

object AsInstanceOfTest {
def main(): Unit = { val _ = test() }

@JSExportTopLevel("asInstanceOf")
def test(): Boolean = {
testInt(5) &&
testClasses(new Child()) &&
testString("foo", true)
}

def testClasses(c: Child): Boolean = {
val c1 = c.asInstanceOf[Child]
val c2 = c.asInstanceOf[Parent]
c1.foo() == 5 && c2.foo() == 5
}

def testInt(x: Int): Boolean = {
val x1 = x.asInstanceOf[Int]
x1 == 5
}

def testString(s: String, b: Boolean): Boolean = {
val s1 = s.asInstanceOf[String]
val s2 = ("" + b).asInstanceOf[String]
s1.length() == 3 && s2.length() == 4
}

class Parent {
def foo(): Int = 5
}
class Child extends Parent
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package testsuite.core.hijackedclassesmono

import scala.scalajs.js.annotation._

object HijackedClassesMonoTest {
def main(): Unit = { val _ = test() }

@JSExportTopLevel("hijackedClassesMono")
def test(): Boolean = {
testInteger(5) &&
testString("foo")
}

def testInteger(x: Int): Boolean = {
x.hashCode() == 5
}

def testString(foo: String): Boolean = {
foo.length() == 3 &&
foo.hashCode() == 101574
}
}
19 changes: 5 additions & 14 deletions wasm/src/main/scala/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@ object Compiler {
} yield {
val onlyModule = moduleSet.modules.head

val filteredClasses = onlyModule.classDefs.filter { c =>
!ExcludedClasses.contains(c.className)
}
// Sort for stability
val sortedClasses = onlyModule.classDefs.sortBy(_.className)

filteredClasses.sortBy(_.className).foreach(showLinkedClass(_))
sortedClasses.foreach(showLinkedClass(_))

Preprocessor.preprocess(filteredClasses)(context)
Preprocessor.preprocess(sortedClasses)(context)
println("preprocessed")
filteredClasses.foreach { clazz =>
sortedClasses.foreach { clazz =>
builder.transformClassDef(clazz)
}
onlyModule.topLevelExports.foreach { tle =>
Expand All @@ -71,14 +70,6 @@ object Compiler {
}
}

private val ExcludedClasses: Set[ir.Names.ClassName] = {
import ir.Names._
HijackedClasses ++ // hijacked classes
Set(
ClassClass // java.lang.Class
)
}

private def showLinkedClass(clazz: LinkedClass): Unit = {
val writer = new java.io.PrintWriter(System.out)
val printer = new LinkedClassPrinter(writer)
Expand Down
9 changes: 9 additions & 0 deletions wasm/src/main/scala/ir2wasm/LibraryPatches.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ object LibraryPatches {

private val MethodPatches: Map[ClassName, List[MethodDef]] = {
Map(
ObjectClass -> List(
// TODO Remove this patch when we support getClass() and full string concatenation
MethodDef(
EMF, m("toString", Nil, T), NON,
Nil, ClassType(BoxedStringClass),
Some(StringLiteral("[object]"))
)(EOH, NOV)
),

BoxedCharacterClass.withSuffix("$") -> List(
MethodDef(
EMF, m("toString", List(C), T), NON,
Expand Down
22 changes: 9 additions & 13 deletions wasm/src/main/scala/ir2wasm/Preprocessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,24 @@ object Preprocessor {
for (clazz <- classes)
preprocess(clazz)

for (clazz <- classes) {
if (clazz.className != IRNames.ObjectClass)
collectAbstractMethodCalls(clazz)
}
for (clazz <- classes)
collectAbstractMethodCalls(clazz)
}

private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
clazz.kind match {
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface =>
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface | ClassKind.HijackedClass =>
collectMethods(clazz)
case ClassKind.JSClass | ClassKind.JSModuleClass | ClassKind.NativeJSModuleClass |
ClassKind.AbstractJSType | ClassKind.NativeJSClass | ClassKind.HijackedClass =>
ClassKind.AbstractJSType | ClassKind.NativeJSClass =>
???
}
}

private def collectMethods(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
val infos =
if (clazz.name.name == IRNames.ObjectClass) Nil
else
clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method =>
makeWasmFunctionInfo(clazz, method)
}
val infos = clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method =>
makeWasmFunctionInfo(clazz, method)
}
ctx.putClassInfo(
clazz.name.name,
new WasmClassInfo(
Expand All @@ -48,7 +43,8 @@ object Preprocessor {
infos,
clazz.fields.collect { case f: IRTrees.FieldDef => Names.WasmFieldName(f.name.name) },
clazz.superClass.map(_.name),
clazz.interfaces.map(_.name)
clazz.interfaces.map(_.name),
clazz.ancestors
)
)
}
Expand Down
32 changes: 18 additions & 14 deletions wasm/src/main/scala/ir2wasm/WasmBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ class WasmBuilder {

def transformClassDef(clazz: LinkedClass)(implicit ctx: WasmContext) = {
clazz.kind match {
case ClassKind.ModuleClass => transformModuleClass(clazz)
case ClassKind.Class => transformClass(clazz)
case ClassKind.Interface => transformInterface(clazz)
case _ =>
case ClassKind.ModuleClass => transformModuleClass(clazz)
case ClassKind.Class => transformClass(clazz)
case ClassKind.HijackedClass => transformHijackedClass(clazz)
case ClassKind.Interface => transformInterface(clazz)
case _ => ???
}
}

Expand Down Expand Up @@ -68,15 +69,10 @@ class WasmBuilder {
)
ctx.addGCType(structType)

// Do not generate methods in Object for now
if (clazz.name.name == IRNames.ObjectClass)
clazz.methods.filter(_.name.name == IRNames.NoArgConstructorName).foreach { method =>
genFunction(clazz, method)
}
else
clazz.methods.foreach { method =>
genFunction(clazz, method)
}
// implementation of methods
clazz.methods.foreach { method =>
genFunction(clazz, method)
}

structType
}
Expand Down Expand Up @@ -242,6 +238,12 @@ class WasmBuilder {
transformClassCommon(clazz)
}

private def transformHijackedClass(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
clazz.methods.foreach { method =>
genFunction(clazz, method)
}
}

private def transformInterface(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
assert(clazz.kind == ClassKind.Interface)
// gen itable type
Expand Down Expand Up @@ -389,7 +391,9 @@ class WasmBuilder {
// Receiver type for non-constructor methods needs to be Object type because params are invariant
// Otherwise, vtable can't be a subtype of the supertype's subtype
// Constructor can use the exact type because it won't be registered to vtables.
if (method.flags.namespace.isConstructor)
if (clazz.kind == ClassKind.HijackedClass)
transformType(IRTypes.BoxedClassToPrimType(clazz.name.name))
else if (method.flags.namespace.isConstructor)
WasmRefNullType(WasmHeapType.Type(WasmTypeName.WasmStructTypeName(clazz.name.name)))
else
WasmRefNullType(WasmHeapType.ObjectType),
Expand Down
50 changes: 45 additions & 5 deletions wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
transformApplyStatically(t)
case t: IRTrees.Apply => transformApply(t)
case t: IRTrees.ApplyDynamicImport => ???
case t: IRTrees.AsInstanceOf => transformAsInstanceOf(t)
case t: IRTrees.Block => transformBlock(t)
case t: IRTrees.Labeled => transformLabeled(t)
case t: IRTrees.Return => transformReturn(t)
Expand Down Expand Up @@ -180,8 +181,9 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
val wasmArgs = t.args.flatMap(transformTree)

val receiverClassName = t.receiver.tpe match {
case ClassType(className) => className
case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}")
case ClassType(className) => className
case prim: IRTypes.PrimType => IRTypes.PrimTypeToBoxedClass(prim)
case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}")
}
val receiverClassInfo = ctx.getClassInfo(receiverClassName)

Expand Down Expand Up @@ -228,6 +230,18 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
TypeIdx(method.toWasmFunctionType()(ctx).name)
)
)
} else if (receiverClassInfo.kind == ClassKind.HijackedClass) {
// statically resolved call
val info = receiverClassInfo.getMethodInfo(t.method.name)
val castIfNeeded =
if (receiverClassName == IRNames.BoxedStringClass && t.receiver.tpe == ClassType(IRNames.BoxedStringClass))
List(REF_CAST(HeapType(Types.WasmHeapType.Type(WasmStructTypeName.string))))
else
Nil
pushReceiver ++ castIfNeeded ++ wasmArgs ++
List(
CALL(FuncIdx(info.name))
)
} else { // virtual dispatch
val (methodIdx, info) = ctx
.calculateVtable(receiverClassName)
Expand Down Expand Up @@ -401,6 +415,17 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
case BinaryOp.Long_>>> => longShiftOp(I64_SHR_U)
case BinaryOp.Long_>> => longShiftOp(I64_SHR_S)

// New in 1.11
case BinaryOp.String_charAt =>
transformTree(binary.lhs) ++ // push the string
List(
STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)), // get the array
) ++
transformTree(binary.rhs) ++ // push the index
List(
ARRAY_GET_U(TypeIdx(WasmArrayTypeName.stringData)) // access the element of the array
)

case _ => transformElementaryBinaryOp(binary)
}
}
Expand Down Expand Up @@ -479,9 +504,6 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
case BinaryOp.Double_<= => F64_LE
case BinaryOp.Double_> => F64_GT
case BinaryOp.Double_>= => F64_GE

// // New in 1.11
case BinaryOp.String_charAt => ??? // TODO
}
lhsInstrs ++ rhsInstrs :+ operation
}
Expand Down Expand Up @@ -539,6 +561,24 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
}
}

private def transformAsInstanceOf(tree: IRTrees.AsInstanceOf): List[WasmInstr] = {
val exprInstrs = transformTree(tree.expr)

val sourceTpe = tree.expr.tpe
val targetTpe = tree.tpe

if (IRTypes.isSubtype(sourceTpe, targetTpe)(isSubclass(_, _))) {
// Common case where no cast is necessary
exprInstrs
} else {
println(tree)
???
}
}

private def isSubclass(subClass: IRNames.ClassName, superClass: IRNames.ClassName): Boolean =
ctx.getClassInfo(subClass).ancestors.contains(superClass)

private def transformVarRef(r: IRTrees.VarRef): LOCAL_GET = {
val name = WasmLocalName.fromIR(r.ident.name)
LOCAL_GET(LocalIdx(name))
Expand Down
9 changes: 8 additions & 1 deletion wasm/src/main/scala/wasm4s/WasmContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ object WasmContext {
private var _methods: List[WasmFunctionInfo],
private val fields: List[WasmFieldName],
val superClass: Option[IRNames.ClassName],
val interfaces: List[IRNames.ClassName]
val interfaces: List[IRNames.ClassName],
val ancestors: List[IRNames.ClassName]
) {

def isInterface = kind == ClassKind.Interface
Expand All @@ -145,6 +146,12 @@ object WasmContext {
}
}

def getMethodInfo(methodName: IRNames.MethodName): WasmFunctionInfo = {
methods.find(_.name.methodName == methodName.nameString).getOrElse {
throw new IllegalArgumentException(s"Cannot find method ${methodName.nameString} in class ${name.nameString}")
}
}

def getFieldIdx(name: WasmFieldName): WasmImmediate.StructFieldIdx =
fields.indexWhere(_ == name) match {
case i if i < 0 => throw new Error(s"Field not found: $name")
Expand Down

0 comments on commit b3701b9

Please sign in to comment.