Skip to content

Commit

Permalink
WASM: embedding models working
Browse files Browse the repository at this point in the history
  • Loading branch information
jpohhhh committed Aug 24, 2024
1 parent d2c5d09 commit 47291fc
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 45 deletions.
43 changes: 40 additions & 3 deletions docs/fonnx_minilm_worker.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/esm/ort.min.js';
import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.19.0/dist/ort.all.min.mjs';

let session = null;

Expand All @@ -8,6 +8,29 @@ let session = null;
const cores = navigator.hardwareConcurrency;
ort.env.wasm.numThreads = Math.max(1, Math.min(Math.floor(cores / 2), cores));

function toBigInt64Array(wordpieces) {
// Create a buffer with the correct size
const buffer = new ArrayBuffer(wordpieces.length * 8); // 8 bytes per BigInt64
const view = new BigInt64Array(buffer);

for (let i = 0; i < wordpieces.length; i++) {
const value = wordpieces[i];
// console.log(`Original value at index ${i}:`, value, "of type", typeof value);

if (typeof value === 'bigint') {
view[i] = value;
} else if (typeof value === 'number') {
view[i] = BigInt(Math.floor(value)); // Ensure integer
} else {
throw new Error(`Unsupported type at index ${i}: ${typeof value}`);
}

// console.log(`Converted value at index ${i}:`, view[i], "of type", typeof view[i]);
}

return view;
}

self.onmessage = async e => {
const { action, modelArrayBuffer, wordpieces, messageId } = e.data;
try {
Expand All @@ -18,6 +41,7 @@ self.onmessage = async e => {
});
console.log('MiniLm model loaded');
self.postMessage({ messageId, action: 'modelLoaded' });
console.log('New log line appearing');
} else if (action === 'runInference') {
if (!session) {
console.error('Session does not exist');
Expand All @@ -31,20 +55,33 @@ self.onmessage = async e => {
}
// Prepare tensors and run the inference session
const shape = [1, wordpieces.length];
const inputIdsTensor = new ort.Tensor('int64', wordpieces.map(x => BigInt(x)), shape);
console.time("Creating tensors");
console.time("inputIdsTensor");
const inputIdsTensor = new ort.Tensor('int64', toBigInt64Array(wordpieces), shape);
console.timeEnd("inputIdsTensor");
console.time("tokenTypeIdsTensor");
const tokenTypeIdsTensor = new ort.Tensor('int64', new BigInt64Array(shape[0] * shape[1]).fill(0n), shape);
console.timeEnd("tokenTypeIdsTensor");
console.time("attentionMaskTensor");
const attentionMaskTensor = new ort.Tensor('int64', new BigInt64Array(shape[0] * shape[1]).fill(1n), shape);
console.timeEnd("attentionMaskTensor");
console.timeEnd("Creating tensors");

console.time("Inference");

const results = await session.run({
input_ids: inputIdsTensor,
token_type_ids: tokenTypeIdsTensor,
attention_mask: attentionMaskTensor,
});
console.timeEnd("Inference");
console.time("Posting result");
const embeddings = results.embeddings.data;
const message = { messageId, action: 'inferenceResult', embeddings };
self.postMessage(message);
console.timeEnd("Posting result");
}
} catch (error) {
self.postMessage({ messageId, action: 'error', error: error.message });
}
};
};
30 changes: 30 additions & 0 deletions example/.metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# This file tracks properties of this Flutter project.
# Used by Flutter tool to assess capabilities and perform upgrades etc.
#
# This file should be version controlled and should not be manually edited.

version:
revision: "5874a72aa4c779a02553007c47dacbefba2374dc"
channel: "stable"

project_type: app

# Tracks metadata for the flutter migrate command
migration:
platforms:
- platform: root
create_revision: 5874a72aa4c779a02553007c47dacbefba2374dc
base_revision: 5874a72aa4c779a02553007c47dacbefba2374dc
- platform: web
create_revision: 5874a72aa4c779a02553007c47dacbefba2374dc
base_revision: 5874a72aa4c779a02553007c47dacbefba2374dc

# User provided section

# List of Local paths (relative to this file) that should be
# ignored by the migrate tool.
#
# Files that are not part of the templates will be ignored by default.
unmanaged_files:
- 'lib/main.dart'
- 'ios/Runner.xcodeproj/project.pbxproj'
11 changes: 8 additions & 3 deletions example/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@ dependencies:
url: https://github.com/Telosnex/libmonet.git
ref: main
record: ^5.0.4
audioplayers: ^5.2.1
file_picker: ^6.1.1
audioplayers: ^6.0.0
file_picker: ^8.1.2
collection: ^1.18.0

dependency_overrides:
audioplayers_web:
git:
url: https://github.com/bluefireteam/audioplayers.git
path: packages/audioplayers_web
ref: main
dev_dependencies:
integration_test:
sdk: flutter
Expand Down
41 changes: 17 additions & 24 deletions example/web/index.html
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
<!DOCTYPE html>
<html>

<head>
<!--
If you are serving your web app in a path other than the root, change the
href value below to reflect the base path you are serving from.
The path provided below has to start and end with a slash "/" in order for
it to work correctly.
For more details:
* https://developer.mozilla.org/en-US/docs/Web/HTML/Element/base
This is a placeholder for base href that will be replaced by the value of
the `--base-href` argument provided to `flutter build`.
-->
<base href="$FLUTTER_BASE_HREF">

<meta charset="UTF-8">
Expand All @@ -21,30 +33,10 @@
<link rel="manifest" href="manifest.json">
<!-- Enables WASM on GH pages. Via https://github.com/orgs/community/discussions/46419 -->
<script src="enable-threads.js"></script>
<script>
// The value below is injected by flutter build, do not touch.
const serviceWorkerVersion = null;
</script>
<!-- This script adds the flutter initialization JS code -->
<script src="flutter.js" defer></script>

</head>

<body>
<script>
window.addEventListener('load', function (ev) {
// Download main.dart.js
_flutter.loader.loadEntrypoint({
serviceWorker: {
serviceWorkerVersion: serviceWorkerVersion,
},
onEntrypointLoaded: function (engineInitializer) {
engineInitializer.initializeEngine().then(function (appRunner) {
appRunner.runApp();
});
}
});
});
</script>
<script src="flutter_bootstrap.js" async></script>
<!-- FONNX implementations start here -->
<!-- REQUIRED FOR: Magika -->
<script type="module">
Expand All @@ -62,6 +54,7 @@
<script type="module">
import './silero_vad_init.js';
</script>
<!-- FONNX implementations end here -->
</body>

</html>
</html>
43 changes: 40 additions & 3 deletions example/web/minilm_worker.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/esm/ort.min.js';
import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.19.0/dist/ort.all.min.mjs';

let session = null;

Expand All @@ -8,6 +8,29 @@ let session = null;
const cores = navigator.hardwareConcurrency;
ort.env.wasm.numThreads = Math.max(1, Math.min(Math.floor(cores / 2), cores));

function toBigInt64Array(wordpieces) {
// Create a buffer with the correct size
const buffer = new ArrayBuffer(wordpieces.length * 8); // 8 bytes per BigInt64
const view = new BigInt64Array(buffer);

for (let i = 0; i < wordpieces.length; i++) {
const value = wordpieces[i];
// console.log(`Original value at index ${i}:`, value, "of type", typeof value);

if (typeof value === 'bigint') {
view[i] = value;
} else if (typeof value === 'number') {
view[i] = BigInt(Math.floor(value)); // Ensure integer
} else {
throw new Error(`Unsupported type at index ${i}: ${typeof value}`);
}

// console.log(`Converted value at index ${i}:`, view[i], "of type", typeof view[i]);
}

return view;
}

self.onmessage = async e => {
const { action, modelArrayBuffer, wordpieces, messageId } = e.data;
try {
Expand All @@ -18,6 +41,7 @@ self.onmessage = async e => {
});
console.log('MiniLm model loaded');
self.postMessage({ messageId, action: 'modelLoaded' });
console.log('New log line appearing');
} else if (action === 'runInference') {
if (!session) {
console.error('Session does not exist');
Expand All @@ -31,20 +55,33 @@ self.onmessage = async e => {
}
// Prepare tensors and run the inference session
const shape = [1, wordpieces.length];
const inputIdsTensor = new ort.Tensor('int64', wordpieces.map(x => BigInt(x)), shape);
console.time("Creating tensors");
console.time("inputIdsTensor");
const inputIdsTensor = new ort.Tensor('int64', toBigInt64Array(wordpieces), shape);
console.timeEnd("inputIdsTensor");
console.time("tokenTypeIdsTensor");
const tokenTypeIdsTensor = new ort.Tensor('int64', new BigInt64Array(shape[0] * shape[1]).fill(0n), shape);
console.timeEnd("tokenTypeIdsTensor");
console.time("attentionMaskTensor");
const attentionMaskTensor = new ort.Tensor('int64', new BigInt64Array(shape[0] * shape[1]).fill(1n), shape);
console.timeEnd("attentionMaskTensor");
console.timeEnd("Creating tensors");

console.time("Inference");

const results = await session.run({
input_ids: inputIdsTensor,
token_type_ids: tokenTypeIdsTensor,
attention_mask: attentionMaskTensor,
});
console.timeEnd("Inference");
console.time("Posting result");
const embeddings = results.embeddings.data;
const message = { messageId, action: 'inferenceResult', embeddings };
self.postMessage(message);
console.timeEnd("Posting result");
}
} catch (error) {
self.postMessage({ messageId, action: 'error', error: error.message });
}
};
};
11 changes: 5 additions & 6 deletions lib/models/minilml6v2/mini_lm_l6_v2_web.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import 'package:ml_linalg/linalg.dart';
MiniLmL6V2 getMiniLmL6V2(String path) => MiniLmL6V2Web(path);

@JS('window.miniLmL6V2')
external JSPromise<JSAny?> sbertJs(
String modelPath, List<int> wordpieces);
external JSPromise<JSAny?> sbertJs(JSString modelPath, JSInt16Array wordpieces);

class MiniLmL6V2Web implements MiniLmL6V2 {
final String modelPath;
Expand All @@ -17,15 +16,15 @@ class MiniLmL6V2Web implements MiniLmL6V2 {

@override
Future<Vector> getEmbeddingAsVector(List<int> tokens) async {
final jsObject = await sbertJs(modelPath, tokens).toDart;
final jsObject =
await sbertJs(modelPath.toJS, Int16List.fromList(tokens).toJS).toDart;

if (jsObject == null) {
throw Exception('Embeddings returned from JS code are null');
}

final jsList = jsObject as List<dynamic>;
final jsList = jsObject as JSFloat32Array;
final vector = Vector.fromList(
Float32List.fromList(jsList.cast()),
jsList.toDart,
dtype: DType.float32,
).normalize();
return vector;
Expand Down
10 changes: 4 additions & 6 deletions lib/models/msmarcoMiniLmL6V3/msmarco_mini_lm_l6_v3_web.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ MsmarcoMiniLmL6V3 getMsmarcoMiniLmL6V3(String path) =>
MsmarcoMiniLmL6V3Web(path);

@JS('window.miniLmL6V2')
external JSPromise<JSAny?> sbertJs(
String modelPath, JSUint8Array wordpieces);
external JSPromise<JSAny?> sbertJs(JSString modelPath, JSInt16Array wordpieces);

class MsmarcoMiniLmL6V3Web implements MsmarcoMiniLmL6V3 {
final String modelPath;
Expand All @@ -19,15 +18,14 @@ class MsmarcoMiniLmL6V3Web implements MsmarcoMiniLmL6V3 {
@override
Future<Vector> getEmbeddingAsVector(List<int> tokens) async {
final jsObject =
await sbertJs(modelPath, Uint8List.fromList(tokens).toJS).toDart;
await sbertJs(modelPath.toJS, Int16List.fromList(tokens).toJS).toDart;

if (jsObject == null) {
throw Exception('Embeddings returned from JS code are null');
}

final jsList = (jsObject as List<dynamic>);
final jsList = jsObject as JSFloat32Array;
final vector = Vector.fromList(
Float32List.fromList(jsList.cast()),
jsList.toDart,
dtype: DType.float32,
).normalize();
return vector;
Expand Down

0 comments on commit 47291fc

Please sign in to comment.