Skip to content

Commit

Permalink
enable initialization of webgpu
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 3fb2712 commit ba09337
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 21 deletions.
26 changes: 16 additions & 10 deletions js/web/karma.conf.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
const bundleMode = require('minimist')(process.argv)['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined;
const karmaPlugins = require('minimist')(process.argv)['karma-plugins'] || undefined;
const timeoutMocha = require('minimist')(process.argv)['timeout-mocha'] || 60000;
const forceLocalHost = !!require('minimist')(process.argv)['force-localhost'];
const commonFile = bundleMode === 'dev' ? '../common/dist/ort-common.js' : '../common/dist/ort-common.min.js'
const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js';

Expand All @@ -16,18 +17,20 @@ const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js';
// https://stackoverflow.com/a/8440736
//
function getMachineIpAddress() {
var os = require('os');
var ifaces = os.networkInterfaces();
if (!forceLocalHost) {
var os = require('os');
var ifaces = os.networkInterfaces();

for (const ifname in ifaces) {
for (const iface of ifaces[ifname]) {
if ('IPv4' !== iface.family || iface.internal !== false) {
// skip over internal (i.e. 127.0.0.1) and non-ipv4 addresses
continue;
}
for (const ifname in ifaces) {
for (const iface of ifaces[ifname]) {
if ('IPv4' !== iface.family || iface.internal !== false) {
// skip over internal (i.e. 127.0.0.1) and non-ipv4 addresses
continue;
}

// returns the first available IP address
return iface.address;
// returns the first available IP address
return iface.address;
}
}
}

Expand Down Expand Up @@ -80,6 +83,9 @@ module.exports = function (config) {
ChromeTest: { base: 'ChromeHeadless', flags: ['--enable-features=SharedArrayBuffer'] },
ChromePerf: { base: 'Chrome', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer'] },
ChromeDebug: { debug: true, base: 'Chrome', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer'] },
ChromeCanaryTest: { base: 'ChromeCanaryHeadless', flags: ['--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] },
ChromeCanaryPerf: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] },
ChromeCanaryDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] },

//
// ==== BrowserStack browsers ====
Expand Down
18 changes: 15 additions & 3 deletions js/web/lib/onnxjs/backends/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,23 @@ import {Session} from '../session';
import {WebGpuSessionHandler} from './webgpu/session-handler';

export class WebGpuBackend implements Backend {
initialize(): boolean {
device: GPUDevice;
async initialize(): Promise<boolean> {
try {
// STEP.1 TODO: set up context (one time initialization)
if (!navigator.gpu) {
// WebGPU is not available.
Logger.warning('WebGpuBackend', 'WebGPU is not available.');
return false;
}

// STEP.2 TODO: set up flags
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
Logger.warning('WebGpuBackend', 'Failed to get GPU adapter.');
return false;
}
this.device = await adapter.requestDevice();

// TODO: set up flags

Logger.setWithEnv(env);

Expand Down
27 changes: 19 additions & 8 deletions js/web/script/test-runner-cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ if (shouldLoadSuiteTestData) {

// The default backends and opset version lists. Those will be used in suite tests.
const DEFAULT_BACKENDS: readonly TestRunnerCliArgs.Backend[] =
args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl'];
args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl', 'webgpu'];
const DEFAULT_OPSET_VERSIONS: readonly number[] = [13, 12, 11, 10, 9, 8, 7];

const FILE_CACHE_ENABLED = args.fileCache; // whether to enable file cache
Expand Down Expand Up @@ -454,11 +454,13 @@ function run(config: Test.Config) {
// STEP 5. use Karma to run test
npmlog.info('TestRunnerCli.Run', '(5/5) Running karma to start test runner...');
const karmaCommand = path.join(npmBin, 'karma');
const webgpu = args.backends.indexOf('webgpu') > -1;
const browser = getBrowserNameFromEnv(
args.env,
args.bundleMode === 'perf' ? 'perf' :
args.debug ? 'debug' :
'test');
'test',
webgpu);
const karmaArgs = ['start', `--browsers ${browser}`];
if (args.debug) {
karmaArgs.push('--log-level info --timeout-mocha 9999999');
Expand All @@ -468,6 +470,9 @@ function run(config: Test.Config) {
if (args.noSandbox) {
karmaArgs.push('--no-sandbox');
}
if (webgpu) {
karmaArgs.push('--force-localhost');
}
karmaArgs.push(`--bundle-mode=${args.bundleMode}`);
if (browser === 'Edge') {
// There are currently 2 Edge browser launchers:
Expand Down Expand Up @@ -559,10 +564,11 @@ function saveConfig(config: Test.Config) {
fs.writeJSONSync(path.join(TEST_ROOT, './testdata-config.json'), config);
}

function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test') {

function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean) {
switch (env) {
case 'chrome':
return selectChromeBrowser(mode);
return selectChromeBrowser(mode, webgpu);
case 'edge':
return 'Edge';
case 'firefox':
Expand All @@ -578,13 +584,18 @@ function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'per
}
}

function selectChromeBrowser(mode: 'debug'|'perf'|'test') {
function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean) {
let browserName = 'Chrome';
if (webgpu) {
browserName += 'Canary';
}

switch (mode) {
case 'debug':
return 'ChromeDebug';
return browserName + 'Debug';
case 'perf':
return 'ChromePerf';
return browserName + 'Perf';
default:
return 'ChromeTest';
return browserName + 'Test';
}
}

0 comments on commit ba09337

Please sign in to comment.