diff --git a/cli/src/main/scala/TestSuites.scala b/cli/src/main/scala/TestSuites.scala index d737e910..81d80194 100644 --- a/cli/src/main/scala/TestSuites.scala +++ b/cli/src/main/scala/TestSuites.scala @@ -4,6 +4,8 @@ object TestSuites { case class TestSuite(className: String, methodName: String) val suites = List( TestSuite("testsuite.core.simple.Simple", "simple"), - TestSuite("testsuite.core.add.Add", "add") + TestSuite("testsuite.core.add.Add", "add"), + TestSuite("testsuite.core.asinstanceof.AsInstanceOfTest", "asInstanceOf"), + TestSuite("testsuite.core.hijackedclassesmono.HijackedClassesMonoTest", "hijackedClassesMono") ) } diff --git a/test-suite/src/main/scala/testsuite/core/AsInstanceOfTest.scala b/test-suite/src/main/scala/testsuite/core/AsInstanceOfTest.scala new file mode 100644 index 00000000..a3f0c06e --- /dev/null +++ b/test-suite/src/main/scala/testsuite/core/AsInstanceOfTest.scala @@ -0,0 +1,36 @@ +package testsuite.core.asinstanceof + +import scala.scalajs.js.annotation._ + +object AsInstanceOfTest { + def main(): Unit = { val _ = test() } + + @JSExportTopLevel("asInstanceOf") + def test(): Boolean = { + testInt(5) && + testClasses(new Child()) && + testString("foo", true) + } + + def testClasses(c: Child): Boolean = { + val c1 = c.asInstanceOf[Child] + val c2 = c.asInstanceOf[Parent] + c1.foo() == 5 && c2.foo() == 5 + } + + def testInt(x: Int): Boolean = { + val x1 = x.asInstanceOf[Int] + x1 == 5 + } + + def testString(s: String, b: Boolean): Boolean = { + val s1 = s.asInstanceOf[String] + val s2 = ("" + b).asInstanceOf[String] + s1.length() == 3 && s2.length() == 4 + } + + class Parent { + def foo(): Int = 5 + } + class Child extends Parent +} diff --git a/test-suite/src/main/scala/testsuite/core/HijackedClassesMonoTest.scala b/test-suite/src/main/scala/testsuite/core/HijackedClassesMonoTest.scala new file mode 100644 index 00000000..45f2106b --- /dev/null +++ b/test-suite/src/main/scala/testsuite/core/HijackedClassesMonoTest.scala @@ -0,0 +1,22 @@ +package testsuite.core.hijackedclassesmono + +import scala.scalajs.js.annotation._ + +object HijackedClassesMonoTest { + def main(): Unit = { val _ = test() } + + @JSExportTopLevel("hijackedClassesMono") + def test(): Boolean = { + testInteger(5) && + testString("foo") + } + + def testInteger(x: Int): Boolean = { + x.hashCode() == 5 + } + + def testString(foo: String): Boolean = { + foo.length() == 3 && + foo.hashCode() == 101574 + } +} diff --git a/wasm/src/main/scala/Compiler.scala b/wasm/src/main/scala/Compiler.scala index 4793a8c8..cc31fb09 100644 --- a/wasm/src/main/scala/Compiler.scala +++ b/wasm/src/main/scala/Compiler.scala @@ -49,15 +49,14 @@ object Compiler { } yield { val onlyModule = moduleSet.modules.head - val filteredClasses = onlyModule.classDefs.filter { c => - !ExcludedClasses.contains(c.className) - } + // Sort for stability + val sortedClasses = onlyModule.classDefs.sortBy(_.className) - filteredClasses.sortBy(_.className).foreach(showLinkedClass(_)) + sortedClasses.foreach(showLinkedClass(_)) - Preprocessor.preprocess(filteredClasses)(context) + Preprocessor.preprocess(sortedClasses)(context) println("preprocessed") - filteredClasses.foreach { clazz => + sortedClasses.foreach { clazz => builder.transformClassDef(clazz) } onlyModule.topLevelExports.foreach { tle => @@ -71,14 +70,6 @@ object Compiler { } } - private val ExcludedClasses: Set[ir.Names.ClassName] = { - import ir.Names._ - HijackedClasses ++ // hijacked classes - Set( - ClassClass // java.lang.Class - ) - } - private def showLinkedClass(clazz: LinkedClass): Unit = { val writer = new java.io.PrintWriter(System.out) val printer = new LinkedClassPrinter(writer) diff --git a/wasm/src/main/scala/ir2wasm/LibraryPatches.scala b/wasm/src/main/scala/ir2wasm/LibraryPatches.scala index 3d576323..968dcc26 100644 --- a/wasm/src/main/scala/ir2wasm/LibraryPatches.scala +++ b/wasm/src/main/scala/ir2wasm/LibraryPatches.scala @@ -65,6 +65,15 @@ object LibraryPatches { private val MethodPatches: Map[ClassName, List[MethodDef]] = { Map( + ObjectClass -> List( + // TODO Remove this patch when we support getClass() and full string concatenation + MethodDef( + EMF, m("toString", Nil, T), NON, + Nil, ClassType(BoxedStringClass), + Some(StringLiteral("[object]")) + )(EOH, NOV) + ), + BoxedCharacterClass.withSuffix("$") -> List( MethodDef( EMF, m("toString", List(C), T), NON, diff --git a/wasm/src/main/scala/ir2wasm/Preprocessor.scala b/wasm/src/main/scala/ir2wasm/Preprocessor.scala index b39e8d26..824b54a6 100644 --- a/wasm/src/main/scala/ir2wasm/Preprocessor.scala +++ b/wasm/src/main/scala/ir2wasm/Preprocessor.scala @@ -17,29 +17,24 @@ object Preprocessor { for (clazz <- classes) preprocess(clazz) - for (clazz <- classes) { - if (clazz.className != IRNames.ObjectClass) - collectAbstractMethodCalls(clazz) - } + for (clazz <- classes) + collectAbstractMethodCalls(clazz) } private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { clazz.kind match { - case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface => + case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface | ClassKind.HijackedClass => collectMethods(clazz) case ClassKind.JSClass | ClassKind.JSModuleClass | ClassKind.NativeJSModuleClass | - ClassKind.AbstractJSType | ClassKind.NativeJSClass | ClassKind.HijackedClass => + ClassKind.AbstractJSType | ClassKind.NativeJSClass => ??? } } private def collectMethods(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { - val infos = - if (clazz.name.name == IRNames.ObjectClass) Nil - else - clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method => - makeWasmFunctionInfo(clazz, method) - } + val infos = clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method => + makeWasmFunctionInfo(clazz, method) + } ctx.putClassInfo( clazz.name.name, new WasmClassInfo( @@ -48,7 +43,8 @@ object Preprocessor { infos, clazz.fields.collect { case f: IRTrees.FieldDef => Names.WasmFieldName(f.name.name) }, clazz.superClass.map(_.name), - clazz.interfaces.map(_.name) + clazz.interfaces.map(_.name), + clazz.ancestors ) ) } diff --git a/wasm/src/main/scala/ir2wasm/WasmBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmBuilder.scala index 0e64dc66..ec8b46f8 100644 --- a/wasm/src/main/scala/ir2wasm/WasmBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmBuilder.scala @@ -24,10 +24,11 @@ class WasmBuilder { def transformClassDef(clazz: LinkedClass)(implicit ctx: WasmContext) = { clazz.kind match { - case ClassKind.ModuleClass => transformModuleClass(clazz) - case ClassKind.Class => transformClass(clazz) - case ClassKind.Interface => transformInterface(clazz) - case _ => + case ClassKind.ModuleClass => transformModuleClass(clazz) + case ClassKind.Class => transformClass(clazz) + case ClassKind.HijackedClass => transformHijackedClass(clazz) + case ClassKind.Interface => transformInterface(clazz) + case _ => ??? } } @@ -68,15 +69,10 @@ class WasmBuilder { ) ctx.addGCType(structType) - // Do not generate methods in Object for now - if (clazz.name.name == IRNames.ObjectClass) - clazz.methods.filter(_.name.name == IRNames.NoArgConstructorName).foreach { method => - genFunction(clazz, method) - } - else - clazz.methods.foreach { method => - genFunction(clazz, method) - } + // implementation of methods + clazz.methods.foreach { method => + genFunction(clazz, method) + } structType } @@ -242,6 +238,12 @@ class WasmBuilder { transformClassCommon(clazz) } + private def transformHijackedClass(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { + clazz.methods.foreach { method => + genFunction(clazz, method) + } + } + private def transformInterface(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { assert(clazz.kind == ClassKind.Interface) // gen itable type @@ -389,7 +391,9 @@ class WasmBuilder { // Receiver type for non-constructor methods needs to be Object type because params are invariant // Otherwise, vtable can't be a subtype of the supertype's subtype // Constructor can use the exact type because it won't be registered to vtables. - if (method.flags.namespace.isConstructor) + if (clazz.kind == ClassKind.HijackedClass) + transformType(IRTypes.BoxedClassToPrimType(clazz.name.name)) + else if (method.flags.namespace.isConstructor) WasmRefNullType(WasmHeapType.Type(WasmTypeName.WasmStructTypeName(clazz.name.name))) else WasmRefNullType(WasmHeapType.ObjectType), diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index dd621366..1a5d85ba 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -63,6 +63,7 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti transformApplyStatically(t) case t: IRTrees.Apply => transformApply(t) case t: IRTrees.ApplyDynamicImport => ??? + case t: IRTrees.AsInstanceOf => transformAsInstanceOf(t) case t: IRTrees.Block => transformBlock(t) case t: IRTrees.Labeled => transformLabeled(t) case t: IRTrees.Return => transformReturn(t) @@ -180,8 +181,9 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti val wasmArgs = t.args.flatMap(transformTree) val receiverClassName = t.receiver.tpe match { - case ClassType(className) => className - case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}") + case ClassType(className) => className + case prim: IRTypes.PrimType => IRTypes.PrimTypeToBoxedClass(prim) + case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}") } val receiverClassInfo = ctx.getClassInfo(receiverClassName) @@ -228,6 +230,18 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti TypeIdx(method.toWasmFunctionType()(ctx).name) ) ) + } else if (receiverClassInfo.kind == ClassKind.HijackedClass) { + // statically resolved call + val info = receiverClassInfo.getMethodInfo(t.method.name) + val castIfNeeded = + if (receiverClassName == IRNames.BoxedStringClass && t.receiver.tpe == ClassType(IRNames.BoxedStringClass)) + List(REF_CAST(HeapType(Types.WasmHeapType.Type(WasmStructTypeName.string)))) + else + Nil + pushReceiver ++ castIfNeeded ++ wasmArgs ++ + List( + CALL(FuncIdx(info.name)) + ) } else { // virtual dispatch val (methodIdx, info) = ctx .calculateVtable(receiverClassName) @@ -401,6 +415,17 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti case BinaryOp.Long_>>> => longShiftOp(I64_SHR_U) case BinaryOp.Long_>> => longShiftOp(I64_SHR_S) + // New in 1.11 + case BinaryOp.String_charAt => + transformTree(binary.lhs) ++ // push the string + List( + STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)), // get the array + ) ++ + transformTree(binary.rhs) ++ // push the index + List( + ARRAY_GET_U(TypeIdx(WasmArrayTypeName.stringData)) // access the element of the array + ) + case _ => transformElementaryBinaryOp(binary) } } @@ -479,9 +504,6 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti case BinaryOp.Double_<= => F64_LE case BinaryOp.Double_> => F64_GT case BinaryOp.Double_>= => F64_GE - - // // New in 1.11 - case BinaryOp.String_charAt => ??? // TODO } lhsInstrs ++ rhsInstrs :+ operation } @@ -539,6 +561,24 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti } } + private def transformAsInstanceOf(tree: IRTrees.AsInstanceOf): List[WasmInstr] = { + val exprInstrs = transformTree(tree.expr) + + val sourceTpe = tree.expr.tpe + val targetTpe = tree.tpe + + if (IRTypes.isSubtype(sourceTpe, targetTpe)(isSubclass(_, _))) { + // Common case where no cast is necessary + exprInstrs + } else { + println(tree) + ??? + } + } + + private def isSubclass(subClass: IRNames.ClassName, superClass: IRNames.ClassName): Boolean = + ctx.getClassInfo(subClass).ancestors.contains(superClass) + private def transformVarRef(r: IRTrees.VarRef): LOCAL_GET = { val name = WasmLocalName.fromIR(r.ident.name) LOCAL_GET(LocalIdx(name)) diff --git a/wasm/src/main/scala/wasm4s/WasmContext.scala b/wasm/src/main/scala/wasm4s/WasmContext.scala index 60bc3b13..e70c33c6 100644 --- a/wasm/src/main/scala/wasm4s/WasmContext.scala +++ b/wasm/src/main/scala/wasm4s/WasmContext.scala @@ -129,7 +129,8 @@ object WasmContext { private var _methods: List[WasmFunctionInfo], private val fields: List[WasmFieldName], val superClass: Option[IRNames.ClassName], - val interfaces: List[IRNames.ClassName] + val interfaces: List[IRNames.ClassName], + val ancestors: List[IRNames.ClassName] ) { def isInterface = kind == ClassKind.Interface @@ -145,6 +146,12 @@ object WasmContext { } } + def getMethodInfo(methodName: IRNames.MethodName): WasmFunctionInfo = { + methods.find(_.name.methodName == methodName.nameString).getOrElse { + throw new IllegalArgumentException(s"Cannot find method ${methodName.nameString} in class ${name.nameString}") + } + } + def getFieldIdx(name: WasmFieldName): WasmImmediate.StructFieldIdx = fields.indexWhere(_ == name) match { case i if i < 0 => throw new Error(s"Field not found: $name")