From 1f968a8479910a7820b3486e821e36ddec3faec4 Mon Sep 17 00:00:00 2001 From: Aron Homberg Date: Sat, 20 Jul 2024 18:22:28 +0200 Subject: [PATCH] feat: allow WASM modules to be imported via a user-land defined loader function #20876 --- js/common/lib/env.ts | 10 ++++++++++ js/web/lib/wasm/wasm-factory.ts | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index dbb5f8118363f..31c5d674b8f5d 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {env as envImpl} from './env-impl.js'; +import type {OrtWasmModule} from 'onnxruntime-web'; export declare namespace Env { export type WasmPathPrefix = string; @@ -80,6 +81,15 @@ export declare namespace Env { * @defaultValue `false` */ proxy?: boolean; + + /** + * Set a custom function to import the WebAssembly module from user-land code (Inversion of Control, IoC). + * + * @defaultValue undefined + */ + importWasmModule?: (mjsPathOverride: string, wasmPrefixOverride: string, numThreads: boolean) => Promise<[ + string | undefined, EmscriptenModuleFactory + ]>; } export interface WebGLFlags { diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index fb068ab42d04c..24dda41c1af86 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -108,8 +108,9 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const mjsPathOverride = (mjsPathOverrideFlag as URL)?.href ?? mjsPathOverrideFlag; const wasmPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.wasm; const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? wasmPathOverrideFlag; + const importFunction = typeof flags.importWasmModule === "function" ? flags.importWasmModule : importWasmModule; - const [objectUrl, ortWasmFactory] = (await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1)); + const [objectUrl, ortWasmFactory] = (await importFunction(mjsPathOverride, wasmPrefixOverride, numThreads > 1)); let isTimeout = false;