Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Support transformation of hijacked classes. #10

Merged
merged 4 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cli/src/main/scala/TestSuites.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ object TestSuites {
case class TestSuite(className: String, methodName: String)
val suites = List(
TestSuite("testsuite.core.simple.Simple", "simple"),
TestSuite("testsuite.core.add.Add", "add")
TestSuite("testsuite.core.add.Add", "add"),
TestSuite("testsuite.core.asinstanceof.AsInstanceOfTest", "asInstanceOf"),
TestSuite("testsuite.core.hijackedclassesmono.HijackedClassesMonoTest", "hijackedClassesMono")
)
}
36 changes: 36 additions & 0 deletions test-suite/src/main/scala/testsuite/core/AsInstanceOfTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package testsuite.core.asinstanceof

import scala.scalajs.js.annotation._

object AsInstanceOfTest {
def main(): Unit = { val _ = test() }

@JSExportTopLevel("asInstanceOf")
def test(): Boolean = {
testInt(5) &&
testClasses(new Child()) &&
testString("foo", true)
}

def testClasses(c: Child): Boolean = {
val c1 = c.asInstanceOf[Child]
val c2 = c.asInstanceOf[Parent]
c1.foo() == 5 && c2.foo() == 5
}

def testInt(x: Int): Boolean = {
val x1 = x.asInstanceOf[Int]
x1 == 5
}

def testString(s: String, b: Boolean): Boolean = {
val s1 = s.asInstanceOf[String]
val s2 = ("" + b).asInstanceOf[String]
s1.length() == 3 && s2.length() == 4
}

class Parent {
def foo(): Int = 5
}
class Child extends Parent
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package testsuite.core.hijackedclassesmono

import scala.scalajs.js.annotation._

object HijackedClassesMonoTest {
def main(): Unit = { val _ = test() }

@JSExportTopLevel("hijackedClassesMono")
def test(): Boolean = {
testInteger(5) &&
testString("foo")
}

def testInteger(x: Int): Boolean = {
x.hashCode() == 5
}

def testString(foo: String): Boolean = {
foo.length() == 3 &&
foo.hashCode() == 101574
}
}
19 changes: 5 additions & 14 deletions wasm/src/main/scala/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@ object Compiler {
} yield {
val onlyModule = moduleSet.modules.head

val filteredClasses = onlyModule.classDefs.filter { c =>
!ExcludedClasses.contains(c.className)
}
// Sort for stability
val sortedClasses = onlyModule.classDefs.sortBy(_.className)

filteredClasses.sortBy(_.className).foreach(showLinkedClass(_))
sortedClasses.foreach(showLinkedClass(_))

Preprocessor.preprocess(filteredClasses)(context)
Preprocessor.preprocess(sortedClasses)(context)
println("preprocessed")
filteredClasses.foreach { clazz =>
sortedClasses.foreach { clazz =>
builder.transformClassDef(clazz)
}
onlyModule.topLevelExports.foreach { tle =>
Expand All @@ -71,14 +70,6 @@ object Compiler {
}
}

private val ExcludedClasses: Set[ir.Names.ClassName] = {
import ir.Names._
HijackedClasses ++ // hijacked classes
Set(
ClassClass // java.lang.Class
)
}

private def showLinkedClass(clazz: LinkedClass): Unit = {
val writer = new java.io.PrintWriter(System.out)
val printer = new LinkedClassPrinter(writer)
Expand Down
9 changes: 9 additions & 0 deletions wasm/src/main/scala/ir2wasm/LibraryPatches.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ object LibraryPatches {

private val MethodPatches: Map[ClassName, List[MethodDef]] = {
Map(
ObjectClass -> List(
// TODO Remove this patch when we support getClass() and full string concatenation
MethodDef(
EMF, m("toString", Nil, T), NON,
Nil, ClassType(BoxedStringClass),
Some(StringLiteral("[object]"))
)(EOH, NOV)
),

BoxedCharacterClass.withSuffix("$") -> List(
MethodDef(
EMF, m("toString", List(C), T), NON,
Expand Down
22 changes: 9 additions & 13 deletions wasm/src/main/scala/ir2wasm/Preprocessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,24 @@ object Preprocessor {
for (clazz <- classes)
preprocess(clazz)

for (clazz <- classes) {
if (clazz.className != IRNames.ObjectClass)
collectAbstractMethodCalls(clazz)
}
for (clazz <- classes)
collectAbstractMethodCalls(clazz)
}

private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
clazz.kind match {
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface =>
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface | ClassKind.HijackedClass =>
collectMethods(clazz)
case ClassKind.JSClass | ClassKind.JSModuleClass | ClassKind.NativeJSModuleClass |
ClassKind.AbstractJSType | ClassKind.NativeJSClass | ClassKind.HijackedClass =>
ClassKind.AbstractJSType | ClassKind.NativeJSClass =>
???
}
}

private def collectMethods(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
val infos =
if (clazz.name.name == IRNames.ObjectClass) Nil
else
clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method =>
makeWasmFunctionInfo(clazz, method)
}
val infos = clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method =>
makeWasmFunctionInfo(clazz, method)
}
ctx.putClassInfo(
clazz.name.name,
new WasmClassInfo(
Expand All @@ -48,7 +43,8 @@ object Preprocessor {
infos,
clazz.fields.collect { case f: IRTrees.FieldDef => Names.WasmFieldName(f.name.name) },
clazz.superClass.map(_.name),
clazz.interfaces.map(_.name)
clazz.interfaces.map(_.name),
clazz.ancestors
)
)
}
Expand Down
32 changes: 18 additions & 14 deletions wasm/src/main/scala/ir2wasm/WasmBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ class WasmBuilder {

def transformClassDef(clazz: LinkedClass)(implicit ctx: WasmContext) = {
clazz.kind match {
case ClassKind.ModuleClass => transformModuleClass(clazz)
case ClassKind.Class => transformClass(clazz)
case ClassKind.Interface => transformInterface(clazz)
case _ =>
case ClassKind.ModuleClass => transformModuleClass(clazz)
case ClassKind.Class => transformClass(clazz)
case ClassKind.HijackedClass => transformHijackedClass(clazz)
case ClassKind.Interface => transformInterface(clazz)
case _ => ???
}
}

Expand Down Expand Up @@ -68,15 +69,10 @@ class WasmBuilder {
)
ctx.addGCType(structType)

// Do not generate methods in Object for now
if (clazz.name.name == IRNames.ObjectClass)
clazz.methods.filter(_.name.name == IRNames.NoArgConstructorName).foreach { method =>
genFunction(clazz, method)
}
else
clazz.methods.foreach { method =>
genFunction(clazz, method)
}
// implementation of methods
clazz.methods.foreach { method =>
genFunction(clazz, method)
}

structType
}
Expand Down Expand Up @@ -242,6 +238,12 @@ class WasmBuilder {
transformClassCommon(clazz)
}

private def transformHijackedClass(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
clazz.methods.foreach { method =>
genFunction(clazz, method)
}
}

private def transformInterface(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
assert(clazz.kind == ClassKind.Interface)
// gen itable type
Expand Down Expand Up @@ -389,7 +391,9 @@ class WasmBuilder {
// Receiver type for non-constructor methods needs to be Object type because params are invariant
// Otherwise, vtable can't be a subtype of the supertype's subtype
// Constructor can use the exact type because it won't be registered to vtables.
if (method.flags.namespace.isConstructor)
if (clazz.kind == ClassKind.HijackedClass)
transformType(IRTypes.BoxedClassToPrimType(clazz.name.name))
else if (method.flags.namespace.isConstructor)
WasmRefNullType(WasmHeapType.Type(WasmTypeName.WasmStructTypeName(clazz.name.name)))
else
WasmRefNullType(WasmHeapType.ObjectType),
Expand Down
50 changes: 45 additions & 5 deletions wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
transformApplyStatically(t)
case t: IRTrees.Apply => transformApply(t)
case t: IRTrees.ApplyDynamicImport => ???
case t: IRTrees.AsInstanceOf => transformAsInstanceOf(t)
case t: IRTrees.Block => transformBlock(t)
case t: IRTrees.Labeled => transformLabeled(t)
case t: IRTrees.Return => transformReturn(t)
Expand Down Expand Up @@ -180,8 +181,9 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
val wasmArgs = t.args.flatMap(transformTree)

val receiverClassName = t.receiver.tpe match {
case ClassType(className) => className
case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}")
case ClassType(className) => className
case prim: IRTypes.PrimType => IRTypes.PrimTypeToBoxedClass(prim)
case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}")
}
val receiverClassInfo = ctx.getClassInfo(receiverClassName)

Expand Down Expand Up @@ -228,6 +230,18 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
TypeIdx(method.toWasmFunctionType()(ctx).name)
)
)
} else if (receiverClassInfo.kind == ClassKind.HijackedClass) {
// statically resolved call
Comment on lines +233 to +234
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

val info = receiverClassInfo.getMethodInfo(t.method.name)
val castIfNeeded =
if (receiverClassName == IRNames.BoxedStringClass && t.receiver.tpe == ClassType(IRNames.BoxedStringClass))
List(REF_CAST(HeapType(Types.WasmHeapType.Type(WasmStructTypeName.string))))
else
Nil
pushReceiver ++ castIfNeeded ++ wasmArgs ++
List(
CALL(FuncIdx(info.name))
)
} else { // virtual dispatch
val (methodIdx, info) = ctx
.calculateVtable(receiverClassName)
Expand Down Expand Up @@ -401,6 +415,17 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
case BinaryOp.Long_>>> => longShiftOp(I64_SHR_U)
case BinaryOp.Long_>> => longShiftOp(I64_SHR_S)

// New in 1.11
case BinaryOp.String_charAt =>
transformTree(binary.lhs) ++ // push the string
List(
STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)), // get the array
) ++
transformTree(binary.rhs) ++ // push the index
List(
ARRAY_GET_U(TypeIdx(WasmArrayTypeName.stringData)) // access the element of the array
)

case _ => transformElementaryBinaryOp(binary)
}
}
Expand Down Expand Up @@ -479,9 +504,6 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
case BinaryOp.Double_<= => F64_LE
case BinaryOp.Double_> => F64_GT
case BinaryOp.Double_>= => F64_GE

// // New in 1.11
case BinaryOp.String_charAt => ??? // TODO
}
lhsInstrs ++ rhsInstrs :+ operation
}
Expand Down Expand Up @@ -539,6 +561,24 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
}
}

private def transformAsInstanceOf(tree: IRTrees.AsInstanceOf): List[WasmInstr] = {
val exprInstrs = transformTree(tree.expr)

val sourceTpe = tree.expr.tpe
val targetTpe = tree.tpe

if (IRTypes.isSubtype(sourceTpe, targetTpe)(isSubclass(_, _))) {
// Common case where no cast is necessary
exprInstrs
} else {
println(tree)
???
}
}

private def isSubclass(subClass: IRNames.ClassName, superClass: IRNames.ClassName): Boolean =
ctx.getClassInfo(subClass).ancestors.contains(superClass)

private def transformVarRef(r: IRTrees.VarRef): LOCAL_GET = {
val name = WasmLocalName.fromIR(r.ident.name)
LOCAL_GET(LocalIdx(name))
Expand Down
9 changes: 8 additions & 1 deletion wasm/src/main/scala/wasm4s/WasmContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ object WasmContext {
private var _methods: List[WasmFunctionInfo],
private val fields: List[WasmFieldName],
val superClass: Option[IRNames.ClassName],
val interfaces: List[IRNames.ClassName]
val interfaces: List[IRNames.ClassName],
val ancestors: List[IRNames.ClassName]
) {

def isInterface = kind == ClassKind.Interface
Expand All @@ -145,6 +146,12 @@ object WasmContext {
}
}

def getMethodInfo(methodName: IRNames.MethodName): WasmFunctionInfo = {
methods.find(_.name.methodName == methodName.nameString).getOrElse {
throw new IllegalArgumentException(s"Cannot find method ${methodName.nameString} in class ${name.nameString}")
}
}

def getFieldIdx(name: WasmFieldName): WasmImmediate.StructFieldIdx =
fields.indexWhere(_ == name) match {
case i if i < 0 => throw new Error(s"Field not found: $name")
Expand Down