Skip to content

Commit

Permalink
Enable Sync API for WebNN GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Nov 8, 2023
1 parent 65f66aa commit 41e4d21
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 16 deletions.
7 changes: 3 additions & 4 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@
</div>
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="layoutBtns">
<!-- <label class="btn btn-outline-info active" id='nchw-label'>
<label class="btn btn-outline-info active" id='nchw-label'>
<input type="radio" name="layout" id="nchw" autocomplete="off" checked>NCHW
</label> -->
<label class="btn btn-outline-info btn-sm active">
</label>
<label class="btn btn-outline-info btn-sm">
<input type="radio" name="layout" id="nhwc" autocomplete="off">NHWC
</label>
</div>
Expand Down Expand Up @@ -199,7 +199,6 @@ <h2 class="text-uppercase text-info">No model selected</h2>
<div id="badge"></div>
<p>&copy;2022 <a href="https://webmachinelearning.github.io/">WebNN API</a></p>
</footer>
<script src="../sw.js"></script>
<script>
// This workaround is to fix jquery loading issue in electron.
// Refer to https://stackoverflow.com/questions/32621988/electron-jquery-is-not-defined.
Expand Down
40 changes: 30 additions & 10 deletions image_classification/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const imgElement = document.getElementById('feedElement');
imgElement.src = './images/test.jpg';
const camElement = document.getElementById('feedMediaElement');
let modelName = '';
let layout = 'nhwc';
let layout = 'nchw';
let instanceType = modelName + layout;
let rafReq;
let isFirstTimeLoad = true;
Expand All @@ -28,13 +28,31 @@ let stream = null;
let loadTime = 0;
let buildTime = 0;
let computeTime = 0;
const inputOptions = {
mean: [127.5, 127.5, 127.5],
std: [127.5, 127.5, 127.5],
inputLayout: 'nhwc',
labelUrl: './labels/labels1001.txt',
inputDimensions: [1, 224, 224, 3],
let inputOptions;
let outputDimensions;
const nhwcOptions = {
inputOptions: {
mean: [127.5, 127.5, 127.5],
std: [127.5, 127.5, 127.5],
inputLayout: "nhwc",

Check failure on line 37 in image_classification/main.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Strings must use singlequote
labelUrl: "./labels/labels1001.txt",
inputDimensions: [1, 224, 224, 3],
},
outputDimensions: [1, 1001],
};

const nchwOptions = {
inputOptions: {
mean: [0.485, 0.456, 0.406],
std: [0.229, 0.224, 0.225],
norm: true,
inputLayout: "nchw",
labelUrl: "./labels/labels1000.txt",
inputDimensions: [1, 3, 224, 224],
},
outputDimensions: [1, 1000],
};

let outputBuffer;
let deviceType = '';
let lastdeviceType = '';
Expand All @@ -51,7 +69,7 @@ async function fetchLabels(url) {
$(document).ready(() => {
$('.icdisplay').hide();
if (utils.isWebNN()) {
$('#webnn_cpu').click();
$('#webnn_gpu').click();
} else {
$('#polyfill_gpu').click();
}
Expand Down Expand Up @@ -228,10 +246,12 @@ async function main() {
lastBackend = lastBackend != backend ? backend : lastBackend;
}

inputOptions = layout == 'nchw' ? nchwOptions.inputOptions : nhwcOptions.inputOptions;
outputDimensions = layout == 'nchw' ? nchwOptions.outputDimensions : nhwcOptions.outputDimensions;
instanceType = modelName + layout;
labels = await fetchLabels(inputOptions.labelUrl);
outputBuffer =
new Float32Array(utils.sizeOfShape([1, 1001]));
new Float32Array(utils.sizeOfShape(outputDimensions));
isFirstTimeLoad = false;
console.log(`- Model name: ${modelName}, Model layout: ${layout} -`);
// UI shows model loading progress
Expand All @@ -242,7 +262,7 @@ async function main() {
contextOptions['powerPreference'] = powerPreference;
}

loadTime = await postAndListenMessage({action: 'load', options: contextOptions});
loadTime = await postAndListenMessage({action: 'load', options: {contextOptions, layout}});
console.log(` done in ${loadTime} ms.`);
// UI shows model building progress
await ui.showProgressComponent('done', 'current', 'pending');
Expand Down
137 changes: 137 additions & 0 deletions image_classification/mobilenet_nchw_sync.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
'use strict';

importScripts('../common/utils_worker.js');

// MobileNet V2 model with 'nchw' input layout
class MobileNetV2NchwSync {
constructor() {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.weightsUrl_ = '../test-data/models/mobilenetv2_nchw/weights/';
this.inputOptions = {
mean: [0.485, 0.456, 0.406],
std: [0.229, 0.224, 0.225],
norm: true,
inputLayout: 'nchw',
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
};
this.outputDimensions = [1, 1000];
}

async buildConv_(input, name, relu6 = true, options = {}) {
const prefix = this.weightsUrl_ + 'conv_' + name;
const weightsName = prefix + '_weight.npy';
const weights =
await buildConstantByNpy(this.builder_, weightsName);
const biasName = prefix + '_bias.npy';
const bias =
await buildConstantByNpy(this.builder_, biasName);
options.bias = bias;
if (relu6) {
return this.builder_.clamp(
this.builder_.conv2d(input, weights, options),
{minValue: 0, maxValue: 6});
} else {
return this.builder_.conv2d(input, weights, options);
}
}

async buildGemm_(input, name) {
const prefix = this.weightsUrl_ + 'gemm_' + name;
const weightsName = prefix + '_weight.npy';
const weights = await buildConstantByNpy(this.builder_, weightsName);
const biasName = prefix + '_bias.npy';
const bias = await buildConstantByNpy(this.builder_, biasName);
const options = {c: bias, bTranspose: true};
return this.builder_.gemm(input, weights, options);
}

async buildLinearBottleneck_(
input, convNameArray, group, stride, shortcut = true) {
const conv1x1Relu6 = await this.buildConv_(input, convNameArray[0]);
const options = {
padding: [1, 1, 1, 1],
groups: group,
strides: [stride, stride],
};
const dwise3x3Relu6 = await this.buildConv_(
conv1x1Relu6, convNameArray[1], true, options);
const conv1x1Linear = await this.buildConv_(
dwise3x3Relu6, convNameArray[2], false);

if (shortcut) {
return this.builder_.add(input, conv1x1Linear);
}
return conv1x1Linear;
}

async load(contextOptions) {
this.context_ = navigator.ml.createContextSync(contextOptions);
this.builder_ = new MLGraphBuilder(this.context_);
const data = this.builder_.input('input',
{type: 'float32', dimensions: this.inputOptions.inputDimensions});
const conv0 = await this.buildConv_(
data, '0', true, {padding: [1, 1, 1, 1], strides: [2, 2]});
const conv1 = await this.buildConv_(
conv0, '2', true, {padding: [1, 1, 1, 1], groups: 32});
const conv2 = await this.buildConv_(conv1, '4', false);
const bottleneck0 = await this.buildLinearBottleneck_(
conv2, ['5', '7', '9'], 96, 2, false);
const bottleneck1 = await this.buildLinearBottleneck_(
bottleneck0, ['10', '12', '14'], 144, 1);
const bottleneck2 = await this.buildLinearBottleneck_(
bottleneck1, ['16', '18', '20'], 144, 2, false);
const bottleneck3 = await this.buildLinearBottleneck_(
bottleneck2, ['21', '23', '25'], 192, 1);
const bottleneck4 = await this.buildLinearBottleneck_(
bottleneck3, ['27', '29', '31'], 192, 1);
const bottleneck5 = await this.buildLinearBottleneck_(
bottleneck4, ['33', '35', '37'], 192, 2, false);
const bottleneck6 = await this.buildLinearBottleneck_(
bottleneck5, ['38', '40', '42'], 384, 1);
const bottleneck7 = await this.buildLinearBottleneck_(
bottleneck6, ['44', '46', '48'], 384, 1);
const bottleneck8 = await this.buildLinearBottleneck_(
bottleneck7, ['50', '52', '54'], 384, 1);
const bottleneck9 = await this.buildLinearBottleneck_(
bottleneck8, ['56', '58', '60'], 384, 1, false);
const bottleneck10 = await this.buildLinearBottleneck_(
bottleneck9, ['61', '63', '65'], 576, 1);
const bottleneck11 = await this.buildLinearBottleneck_(
bottleneck10, ['67', '69', '71'], 576, 1);
const bottleneck12 = await this.buildLinearBottleneck_(
bottleneck11, ['73', '75', '77'], 576, 2, false);
const bottleneck13 = await this.buildLinearBottleneck_(
bottleneck12, ['78', '80', '82'], 960, 1);
const bottleneck14 = await this.buildLinearBottleneck_(
bottleneck13, ['84', '86', '88'], 960, 1);
const bottleneck15 = await this.buildLinearBottleneck_(
bottleneck14, ['90', '92', '94'], 960, 1, false);

const conv3 = await this.buildConv_(bottleneck15, '95', true);
const pool = this.builder_.averagePool2d(conv3);
const reshape = this.builder_.reshape(pool, [1, null]);
const gemm = await this.buildGemm_(reshape, '104');
return this.builder_.softmax(gemm);
}

build(outputOperand) {
this.graph_ = this.builder_.buildSync({'output': outputOperand});
}

// Release the constant tensors of a model
dispose() {
// dispose() is only available in webnn-polyfill
if (this.graph_ !== null && 'dispose' in this.graph_) {
this.graph_.dispose();
}
}

compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputs = {'output': outputBuffer};
this.context_.computeSync(this.graph_, inputs, outputs);
}
}
6 changes: 4 additions & 2 deletions image_classification/worker.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
'use strict';

importScripts('./mobilenet_nhwc_sync.js');
importScripts('./mobilenet_nchw_sync.js');
importScripts('../common/utils_worker.js');

let netInstance = null;
Expand All @@ -19,12 +20,13 @@ onmessage = async (message) => {
switch (message.data.action) {
case 'load':
const loatStart = performance.now();
const contextOptions = message.data.options;
const contextOptions = message.data.options.contextOptions;
const layout = message.data.options.layout;
if (!isWebNN) {
// Set WebNN polyfill backend
await setPolyfillBackend(contextOptions.deviceType);
}
netInstance = new MobileNetV2NhwcSync();
netInstance = layout == 'nhwc' ? new MobileNetV2NhwcSync() : new MobileNetV2NchwSync();
outputOperand = await netInstance.load(contextOptions);
const loadTime = (performance.now() - loatStart).toFixed(2);
postMessage(loadTime);
Expand Down

0 comments on commit 41e4d21

Please sign in to comment.