Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Redesign top-level export implementation to handle mutable vars.
Browse files Browse the repository at this point in the history
The previous strategy for top-level export could not lead to any
workable support of mutable vars. We now follow a strategy similar
to what the JS backend uses in its NoModule configuration.

Previously, we compiled top-level exports as Wasm global exports.
We then had a postprocessing step in the loader to extract the
`.value` of the `WAGlobal` instances. This correctly captured the
exported value right after the start function has finished.
However, it could not update the exported value after a mutable
static field was reassigned.

Now, we turn everything everything around. Instead of the JS loader
and wrapper *pulling* values from Wasm, the Wasm code *pushes*
updates to top-level exports to JavaScript. That means we declare
setter functions from JavaScript, and that we (counter-intuitively)
*import* those setters into Wasm. Wasm calls the setters in the
start function for all top-level exports, and again when assigning
static fields that are exported.
  • Loading branch information
sjrd committed May 14, 2024
1 parent 199fe7d commit 208202a
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,10 @@ class ClassEmitter(coreSpec: CoreSpec) {
def transformTopLevelExport(
topLevelExport: LinkedTopLevelExport
)(implicit ctx: WasmContext): Unit = {
genTopLevelExportSetter(topLevelExport.exportName)
topLevelExport.tree match {
case d: TopLevelJSClassExportDef => genDelayedTopLevelExport(d.exportName)
case d: TopLevelModuleExportDef => genDelayedTopLevelExport(d.exportName)
case d: TopLevelMethodExportDef => transformTopLevelMethodExportDef(d)
case d: TopLevelFieldExportDef => transformTopLevelFieldExportDef(d)
case d: TopLevelMethodExportDef => transformTopLevelMethodExportDef(d)
case _ => ()
}
}

Expand Down Expand Up @@ -1089,6 +1088,21 @@ class ClassEmitter(coreSpec: CoreSpec) {
fb.buildAndAddToModule()
}

/** Generates the function import for a top-level export setter. */
private def genTopLevelExportSetter(exportedName: String)(implicit ctx: WasmContext): Unit = {
val functionName = genFunctionName.forTopLevelExportSetter(exportedName)
val functionSig = wamod.FunctionSignature(List(watpe.RefType.anyref), Nil)
val functionType = ctx.moduleBuilder.signatureToTypeName(functionSig)

ctx.moduleBuilder.addImport(
wamod.Import(
"__scalaJSExportSetters",
exportedName,
wamod.ImportDesc.Func(functionName, functionType)
)
)
}

private def transformTopLevelMethodExportDef(
exportDef: TopLevelMethodExportDef
)(implicit ctx: WasmContext): Unit = {
Expand All @@ -1108,48 +1122,6 @@ class ClassEmitter(coreSpec: CoreSpec) {
method.body,
resultType = AnyType
)

/* We cannot directly export the function because it would not be considered
* a `function`. Instead, we will explicitly create a closure wrapper in the
* start function and export that instead.
*/
genDelayedTopLevelExport(exportedName)
}

private def transformTopLevelFieldExportDef(
exportDef: TopLevelFieldExportDef
)(implicit ctx: WasmContext): Unit = {
val exprt = wamod.Export.Global(
exportDef.exportName,
genGlobalName.forStaticField(exportDef.field.name)
)
ctx.addExport(exprt)
}

/** Generates a delayed top-level export global, to be initialized in the `start` function.
*
* Some top-level exports need to be initialized by run-time code because they need to call
* initializing functions:
*
* - methods with a `...rest` need to be initialized with the `closureRestNoArg` helper.
* - JS classes need to be initialized with their `loadJSClass` helper.
* - JS modules need to be initialized with their `loadModule` helper.
*
* For all of those, we use `genDelayedTopLevelExport` to generate a Wasm global initialized with
* `null` and to export it. We actually initialize the global in the `start` function (see
* `genStartFunction()` in `WasmContext`).
*/
private def genDelayedTopLevelExport(exportedName: String)(implicit ctx: WasmContext): Unit = {
val globalName = genGlobalName.forTopLevelExport(exportedName)
ctx.addGlobal(
wamod.Global(
globalName,
watpe.RefType.anyref,
wamod.Expr(List(wa.REF_NULL(watpe.HeapType.None))),
isMutable = true
)
)
ctx.addExport(wamod.Export.Global(exportedName, globalName))
}

private def genFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,12 @@ final class Emitter(config: Emitter.Config) {
// Initialize the top-level exports that require it

for (tle <- topLevelExportDefs) {
// Load the (initial) exported value on the stack
tle.tree match {
case TopLevelJSClassExportDef(_, exportName) =>
instrs += wa.CALL(genFunctionName.loadJSClass(tle.owningClass))
instrs += wa.GLOBAL_SET(genGlobalName.forTopLevelExport(tle.exportName))
case TopLevelModuleExportDef(_, exportName) =>
instrs += wa.CALL(genFunctionName.loadModule(tle.owningClass))
instrs += wa.GLOBAL_SET(genGlobalName.forTopLevelExport(tle.exportName))
case TopLevelMethodExportDef(_, methodDef) =>
instrs += ctx.refFuncWithDeclaration(genFunctionName.forExport(tle.exportName))
if (methodDef.restParam.isDefined) {
Expand All @@ -197,11 +196,15 @@ final class Emitter(config: Emitter.Config) {
} else {
instrs += wa.CALL(genFunctionName.makeExportedDef)
}
instrs += wa.GLOBAL_SET(genGlobalName.forTopLevelExport(tle.exportName))
case TopLevelFieldExportDef(_, _, _) =>
// Nothing to do
()
case TopLevelFieldExportDef(_, _, fieldIdent) =>
/* Usually redundant, but necessary if the static field is never
* explicitly set and keeps its default (zero) value instead.
*/
instrs += wa.GLOBAL_GET(genGlobalName.forStaticField(fieldIdent.name))
}

// Call the export setter
instrs += wa.CALL(genFunctionName.forTopLevelExportSetter(tle.exportName))
}

// Emit the module initializers
Expand Down Expand Up @@ -283,24 +286,27 @@ final class Emitter(config: Emitter.Config) {
(moduleImport, item)
}).unzip

/* TODO This is not correct for exported *vars*, since they won't receive
* updates from mutations after loading.
*/
val reExportStats = for {
val (exportDecls, exportSetters) = (for {
exportName <- module.topLevelExports.map(_.exportName)
} yield {
s"export let $exportName = __exports.$exportName;"
}
val identName = s"exported$exportName"
val decl = s"let $identName;\nexport { $identName as $exportName };"
val setter = s" $exportName: (x) => $identName = x,"
(decl, setter)
}).unzip

s"""
|${moduleImports.mkString("\n")}
|
|import { load as __load } from './${config.loaderModuleName}';
|const __exports = await __load('./${wasmFileName}', {
|
|${exportDecls.mkString("\n")}
|
|await __load('./${wasmFileName}', {
|${importedModulesItems.mkString("\n")}
|}, {
|${exportSetters.mkString("\n")}
|});
|
|${reExportStats.mkString("\n")}
""".stripMargin.trim() + "\n"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,19 @@ private class FunctionEmitter private (
}

case sel: SelectStatic =>
val fieldName = sel.field.name
val globalName = genGlobalName.forStaticField(fieldName)

genTree(t.rhs, sel.tpe)
instrs += wa.GLOBAL_SET(genGlobalName.forStaticField(sel.field.name))
instrs += wa.GLOBAL_SET(globalName)

// Update top-level export mirrors
val classInfo = ctx.getClassInfo(fieldName.className)
val mirrors = classInfo.staticFieldMirrors.getOrElse(fieldName, Nil)
for (exportedName <- mirrors) {
instrs += wa.GLOBAL_GET(globalName)
instrs += wa.CALL(genFunctionName.forTopLevelExportSetter(exportedName))
}

case sel: ArraySelect =>
genTreeAuto(sel.array)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,12 @@ const scalaJSHelpers = {
},
}

export async function load(wasmFileURL, importedModules) {
export async function load(wasmFileURL, importedModules, exportSetters) {
const myScalaJSHelpers = { ...scalaJSHelpers, idHashCodeMap: new WeakMap() };
const importsObj = {
"__scalaJSHelpers": myScalaJSHelpers,
"__scalaJSImports": importedModules,
"__scalaJSExportSetters": exportSetters,
};
const resolvedURL = new URL(wasmFileURL, import.meta.url);
var wasmModulePromise;
Expand All @@ -326,24 +327,7 @@ export async function load(wasmFileURL, importedModules) {
} else {
wasmModulePromise = WebAssembly.instantiateStreaming(fetch(resolvedURL), importsObj);
}
const wasmModule = await wasmModulePromise;
const exports = wasmModule.instance.exports;

const userExports = Object.create(null);
for (const exportName of Object.getOwnPropertyNames(exports)) {
const exportValue = exports[exportName];
if (exportValue instanceof WebAssembly.Global) {
Object.defineProperty(userExports, exportName, {
configurable: true,
enumerable: true,
get: () => exportValue.value,
});
} else {
userExports[exportName] = exportValue;
}
}
Object.freeze(userExports);
return userExports;
await wasmModulePromise;
}
"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ object Preprocessor {
def preprocess(classes: List[LinkedClass], tles: List[LinkedTopLevelExport])(implicit
ctx: WasmContext
): Unit = {
val staticFieldMirrors = computeStaticFieldMirrors(tles)

for (clazz <- classes)
preprocess(clazz)
preprocess(clazz, staticFieldMirrors.getOrElse(clazz.className, Map.empty))

val collector = new AbstractMethodCallCollector(ctx)
for (clazz <- classes)
Expand All @@ -29,7 +31,28 @@ object Preprocessor {
ctx.assignBuckets(classes)
}

private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
private def computeStaticFieldMirrors(
tles: List[LinkedTopLevelExport]
): Map[ClassName, Map[FieldName, List[String]]] = {
var result = Map.empty[ClassName, Map[FieldName, List[String]]]
for (tle <- tles) {
tle.tree match {
case TopLevelFieldExportDef(_, exportName, FieldIdent(fieldName)) =>
val className = tle.owningClass
val mirrors = result.getOrElse(className, Map.empty)
val newExportNames = exportName :: mirrors.getOrElse(fieldName, Nil)
val newMirrors = mirrors.updated(fieldName, newExportNames)
result = result.updated(className, newMirrors)

case _ =>
}
}
result
}

private def preprocess(clazz: LinkedClass, staticFieldMirrors: Map[FieldName, List[String]])(
implicit ctx: WasmContext
): Unit = {
val kind = clazz.kind

val allFieldDefs: List[FieldDef] =
Expand Down Expand Up @@ -96,6 +119,7 @@ object Preprocessor {
hasRuntimeTypeInfo,
clazz.jsNativeLoadSpec,
clazz.jsNativeMembers.map(m => m.name.name -> m.jsNativeLoadSpec).toMap,
staticFieldMirrors,
_itableIdx = -1
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ object VarGen {
def forStaticField(fieldName: IRFieldName): GlobalName =
GlobalName(s"static.${fieldName.nameString}")

def forTopLevelExport(exportName: String): GlobalName =
GlobalName(s"export.$exportName")

def forJSPrivateField(fieldName: IRFieldName): GlobalName =
GlobalName(s"jspfield.${fieldName.nameString}")

Expand Down Expand Up @@ -103,6 +100,8 @@ object VarGen {

def forExport(exportedName: String): FunctionName =
make("export", exportedName)
def forTopLevelExportSetter(exportedName: String): FunctionName =
make("setexport", exportedName)

def loadModule(clazz: ClassName): FunctionName =
make("loadModule", clazz.nameString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ object WasmContext {
val hasRuntimeTypeInfo: Boolean,
val jsNativeLoadSpec: Option[JSNativeLoadSpec],
val jsNativeMembers: Map[MethodName, JSNativeLoadSpec],
val staticFieldMirrors: Map[FieldName, List[String]],
private var _itableIdx: Int
) {
private val fieldIdxByName: Map[FieldName, Int] =
Expand Down

0 comments on commit 208202a

Please sign in to comment.