Multithreaded inference on web

- Inference is defined as an async method, but it blocks. After a couple of
days of trying all avenues and looking at sample apps, it appears to be
effectively synchronous: it consumes the attention of the thread that
`await session.run` is called on.
- Using Squadron to handle multi-threading didn't work. Now that
the JS function in index.html is loading the model and passing it
to a worker, it's possible it might.
- In any case, this shows exactly how to set up a worker that
A) does inference without blocking UI rendering, and
B) allows Dart code to `await` the result without blocking the UI.
- This process was frustrating and fraught; there's a surprising lack of info
and examples around ONNX on the web. Most consumers seem to reach it via
diffusers.js/transformers.js. ONNX web was a separate library from the rest of
the ONNX runtime until sometime around late 2022. The examples still use that
library, and they use simple enough models that it's hard to catch whether
they are blocking without falling back to dev tools.
- It's absolutely crucial when debugging speed locally to make sure you're loading
the ONNX runtime build you expect (i.e. wasm AND threaded AND simd). The easiest
way to check is the Network tab in Dev Tools: sort by size and confirm the .wasm
file A) is loaded and B) includes wasm, simd, and threaded in its filename.
- Two things can prevent that:
-- CORS nonsense with Flutter serving itself in debug mode:
--- see nagadomi/nunif#34
--- note that the extension discussed there became adware; set up its Chrome
permissions so it doesn't run until you click it. Also note that you have to
redo that each time the Flutter web app's debug-mode port changes.
-- MIME type issues:
--- Even after that, I would see errors in the console logs about the .wasm
having the wrong MIME type and starting with the wrong bytes. That, again,
seems due to Flutter serving the web app locally. To work around it, download
the WASM files from the same CDN folder that hosts ort.min.js (see worker.js)
and, also in worker.js, uncomment ort.env.wasm.wasmPaths = "". That setting
indicates you've placed the WASM files next to index.html, which is where
they should go. Note you just need the 4 .wasm files from the CDN, no more;
see the sketch after this list.
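A minimal sketch of that local override, assuming the four .wasm binaries in
the onnxruntime-web dist folder are named ort-wasm.wasm, ort-wasm-simd.wasm,
ort-wasm-threaded.wasm, and ort-wasm-simd-threaded.wasm and have been copied
next to index.html:

    // Assumed configuration, placed before ort.InferenceSession.create (see worker.js):
    ort.env.wasm.wasmPaths = "";   // "" = no CDN prefix; fetch the .wasm files next to the page
    ort.env.wasm.simd = true;      // allow the simd build (this is the default)
    ort.env.wasm.numThreads = 4;   // >1 selects the threaded build, which also requires
                                   // SharedArrayBuffer, i.e. a cross-origin-isolated page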

Some performance review notes:
- `webgpu` as an execution provider errors out completely with "JS executor
not supported in the ONNX version" (1.16.3)
- `webgl` throws "Cannot read properties of null (reading 'irVersion')"
- Tested performance by varying wasm / simd / threading and the thread count on an M2 MacBook Air (16 GB RAM), Chrome 120
- Landed on simd with thread count = half the cores as the best performing combination (a timing sketch follows these notes)
-- first # is minilm l6v2, second is minilm l6v3; average inference time for 200 / 400 words
-- 4 threads: 526 ms / 2196 ms
-- simd, 4 threads: 86 ms / 214 ms
-- simd, 8 threads: 106 ms / 260 ms
-- simd, 128 threads: 2879 ms / skipped
-- simd, navigator.hardwareConcurrency threads (8): 107 ms / 222 ms
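For reference, a minimal sketch of how one of these timings can be reproduced;
the model path and the dummy 200-token input are placeholders, and the real
inference path lives in worker.js below:

    // Assumed harness, run inside an async function after ort.min.js has loaded.
    ort.env.wasm.simd = true;
    ort.env.wasm.numThreads = Math.max(1, Math.floor(navigator.hardwareConcurrency / 2));
    const session = await ort.InferenceSession.create('minilm-l6-v2.onnx', { // placeholder path
      executionProviders: ['wasm'],
    });
    const ids = BigInt64Array.from({ length: 200 }, () => 100n); // dummy 200-token input
    const shape = [1, ids.length];
    console.time('one embedding inference');
    const results = await session.run({
      input_ids: new ort.Tensor('int64', ids, shape),
      token_type_ids: new ort.Tensor('int64', new BigInt64Array(ids.length), shape),
      attention_mask: new ort.Tensor('int64', new BigInt64Array(ids.length).fill(1n), shape),
    });
    console.timeEnd('one embedding inference');
    console.log(results.embeddings.data.length); // 'embeddings' is the output name used here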
jpohhhh committed Dec 15, 2023
1 parent 7e960d9 commit b449c06
Showing 2 changed files with 130 additions and 42 deletions.
115 changes: 73 additions & 42 deletions example/web/index.html
@@ -62,52 +62,83 @@
-->
<!-- REQUIRED FOR: ALL models. -->
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<!-- REQUIRED FOR: Mini LM L6 V2 -->
<!-- REQUIRED FOR: MiniLM models. -->
<script>
let cachedModelPath = null
let session = null

// use an async context to call onnxruntime functions.
window.miniLmL6V2 = async function (modelPath, wordpieces) {
try {
if (modelPath !== cachedModelPath) {
console.time("model loading")
const options = {
// WebGPU would be good, but it's broken currently as of 2023 10 10 (Chrome 117.0.5938.149)
// Generally it seems like ONNX might not support it quite fully, yet.
executionProviders: ["wasm", "cpu"],
};
session = await ort.InferenceSession.create(modelPath, options);
cachedModelPath = modelPath
console.timeEnd("model loading")
}
// Get the number of logical processors available.
const cores = navigator.hardwareConcurrency;

console.time("one embedding inference")
// Check if wordpieces is a single array or array of arrays
let bigIntWordpieces;
if (Array.isArray(wordpieces[0])) {
bigIntWordpieces = wordpieces.flat().map(x => BigInt(x));
} else {
bigIntWordpieces = wordpieces.map(x => BigInt(x));
}
// Ensure at least 1 and at most half the number of hardwareConcurrency.
// Testing showed using all cores was 10% slower than using half.
// Tested on MBA M2 with a natural value of 8 for navigator.hardwareConcurrency.
ort.env.wasm.numThreads = Math.max(1, Math.min(Math.floor(cores / 2), cores));
let cachedModelPath = null;
let modelPromise = null;

const shape = [1, bigIntWordpieces.length]
const inputsIdsKey = "input_ids";
const inputsIdsTensor = new ort.Tensor('int64', bigIntWordpieces, shape);
const tokenTypeIdsKey = 'token_type_ids';
const tokenTypeIdsTensor = new ort.Tensor('int64', BigInt64Array.from(new Array(shape[0] * shape[1]).fill(0n)), shape);
const attentionMaskKey = "attention_mask";
const attentionMaskTensor = new ort.Tensor('int64', BigInt64Array.from(new Array(shape[0] * shape[1]).fill(1n)), shape);
const inputs = { inputsIdsKey: inputsIdsTensor, attentionMaskKey: attentionMaskTensor, tokenTypeIdsKey: tokenTypeIdsTensor };
const results = await session.run({ 'input_ids': inputsIdsTensor, 'attention_mask': attentionMaskTensor, 'token_type_ids': tokenTypeIdsTensor });
const embeddings = results.embeddings.data;
console.timeEnd("one embedding inference")

return embeddings;
} catch (e) {
console.log(`failed to inference ONNX model: ${e}.`);
return null;
const worker = new Worker('worker.js');

// Simplified logs for brevity; can be extended to log each property if required.
worker.onmessage = function (e) {
const { messageId, action, embeddings, error } = e.data;
if (action === "inferenceResult" && messageIdToResolve.has(messageId)) {
messageIdToResolve.get(messageId)(embeddings);
cleanup(messageId);
} else if (action === "error" && messageIdToReject.has(messageId)) {
messageIdToReject.get(messageId)(new Error(error));
cleanup(messageId);
}
};

const messageIdToResolve = new Map();
const messageIdToReject = new Map();

function cleanup(messageId) {
messageIdToResolve.delete(messageId);
messageIdToReject.delete(messageId);
}

function miniLmL6V2(modelPath, wordpieces) {
return new Promise((resolve, reject) => {
const messageId = Math.random().toString(36).substring(2);

messageIdToResolve.set(messageId, resolve);
messageIdToReject.set(messageId, reject);

// If model path has changed or model is not yet loaded, fetch and load the model.
if (cachedModelPath !== modelPath || !modelPromise) {
cachedModelPath = modelPath;
modelPromise = fetch(modelPath)
.then(response => response.arrayBuffer())
.then(modelArrayBuffer => {
return new Promise((resolveLoad) => {
// Post the load model message to the worker.
worker.postMessage({
action: 'loadModel',
modelArrayBuffer,
messageId
}, [modelArrayBuffer]);

// Setup a one-time message listener for the "modelLoaded" message.
const onModelLoaded = (e) => {
if (e.data.action === 'modelLoaded' && e.data.messageId === messageId) {
worker.removeEventListener('message', onModelLoaded);
resolveLoad();
}
};
worker.addEventListener('message', onModelLoaded);
});
})
.catch(reject);
}

modelPromise.then(() => {
// Once the model is loaded, send the run inference message to the worker.
worker.postMessage({
action: 'runInference',
wordpieces,
messageId
});
}).catch(reject);
});
}
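// Illustrative usage (the model path here is a placeholder):
//   const embeddings = await miniLmL6V2('assets/models/miniLmL6V2.onnx', wordpieces);
// From Dart, the returned Promise can be awaited with js_util.promiseToFuture,
// which is what lets the UI isolate stay unblocked while the worker runs.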
</script>
</body>
57 changes: 57 additions & 0 deletions example/web/worker.js
@@ -0,0 +1,57 @@
importScripts("https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js");

let session = null;
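// Message protocol between index.html and this worker (as implemented below):
//   { action: 'loadModel', modelArrayBuffer, messageId } -> { action: 'modelLoaded', messageId }
//   { action: 'runInference', wordpieces, messageId }    -> { action: 'inferenceResult', embeddings, messageId }
//   on failure                                           -> { action: 'error', error, messageId }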

// Loading them locally preserves the mime-type. Flutter Web, at least in
// debug mode running locally, complains that the WASM was actually text/html
// and falls back to plain old CPU, which is very slow.
//
// However, in production, this requires hosting a 10 MB file.
// Therefore the default is to load from CDN, and the commented-out line
// is left in so there is a signal when developing.
// ort.env.wasm.wasmPaths = "";

// Ensure at least 1 and at most half the number of hardwareConcurrency.
// Testing showed using all cores was 15% slower than using half.
// Tested on MBA M2 with a value of 8 for navigator.hardwareConcurrency.
const cores = navigator.hardwareConcurrency;
ort.env.wasm.numThreads = Math.max(1, Math.min(Math.floor(cores / 2), cores));

self.onmessage = async e => {
const { action, modelArrayBuffer, wordpieces, messageId } = e.data;
try {
if (action === 'loadModel' && modelArrayBuffer) {
session = await ort.InferenceSession.create(modelArrayBuffer, {
executionProviders: ['webgl', 'wasm', 'cpu'],
});
self.postMessage({ messageId, action: 'modelLoaded' });
} else if (action === 'runInference') {
if (!session) {
console.error('Session does not exist');
self.postMessage({ messageId, action: 'error', error: 'Session does not exist' });
return;
}
if (!wordpieces) {
console.error('Wordpieces are not provided');
self.postMessage({ messageId, action: 'error', error: 'Wordpieces are not provided' });
return;
}
// Prepare tensors and run the inference session
const shape = [1, wordpieces.length];
const inputIdsTensor = new ort.Tensor('int64', wordpieces.map(x => BigInt(x)), shape);
const tokenTypeIdsTensor = new ort.Tensor('int64', new BigInt64Array(shape[0] * shape[1]).fill(0n), shape);
const attentionMaskTensor = new ort.Tensor('int64', new BigInt64Array(shape[0] * shape[1]).fill(1n), shape);

const results = await session.run({
input_ids: inputIdsTensor,
token_type_ids: tokenTypeIdsTensor,
attention_mask: attentionMaskTensor,
});
const embeddings = results.embeddings.data;
const message = { messageId, action: 'inferenceResult', embeddings };
self.postMessage(message);
}
} catch (error) {
self.postMessage({ messageId, action: 'error', error: error.message });
}
};
