Skip to content

Commit

Permalink
[wasm] Add BroadcastArgs kernel (#7290)
Browse files Browse the repository at this point in the history
* Add DenseBincount kernel

* Remove redundant checks

* Reduce buf reset calls

* Update to C++17

* Add BroadcastArgs kernel

* Fix lint

* Update tfjs-backend-wasm/src/backend_wasm.ts

Co-authored-by: Matthew Soulanille <[email protected]>

Co-authored-by: Ping Yu <[email protected]>
Co-authored-by: Matthew Soulanille <[email protected]>
  • Loading branch information
3 people authored Jan 23, 2023
1 parent 8704ae4 commit ceb69e6
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 34 deletions.
72 changes: 38 additions & 34 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {backend_util, BackendTimingInfo, DataStorage, DataType, deprecationWarn,

import {BackendWasmModule, WasmFactoryConfig} from '../wasm-out/tfjs-backend-wasm';
import {BackendWasmThreadedSimdModule} from '../wasm-out/tfjs-backend-wasm-threaded-simd';
import * as wasmFactoryThreadedSimd_import from '../wasm-out/tfjs-backend-wasm-threaded-simd.js';
import * as wasmFactoryThreadedSimd_import from '../wasm-out/tfjs-backend-wasm-threaded-simd.js';
// @ts-ignore
import {wasmWorkerContents} from '../wasm-out/tfjs-backend-wasm-threaded-simd.worker.js';
import * as wasmFactory_import from '../wasm-out/tfjs-backend-wasm.js';
Expand All @@ -29,11 +29,11 @@ import * as wasmFactory_import from '../wasm-out/tfjs-backend-wasm.js';
// the node bundle (for testing). This would not be necessary if we
// flipped esModuleInterop to true, but we likely can't do that since
// google3 does not use it.
const wasmFactoryThreadedSimd = (wasmFactoryThreadedSimd_import.default
|| wasmFactoryThreadedSimd_import) as
typeof wasmFactoryThreadedSimd_import.default;
const wasmFactory = (wasmFactory_import.default
|| wasmFactory_import) as typeof wasmFactory_import.default;
const wasmFactoryThreadedSimd = (wasmFactoryThreadedSimd_import.default ||
wasmFactoryThreadedSimd_import) as
typeof wasmFactoryThreadedSimd_import.default;
const wasmFactory = (wasmFactory_import.default || wasmFactory_import) as
typeof wasmFactory_import.default;

interface TensorData {
id: number;
Expand All @@ -59,7 +59,8 @@ export class BackendWasm extends KernelBackend {
this.dataIdMap = new DataStorage(this, engine());
}

override write(values: backend_util.BackendValues, shape: number[],
override write(
values: backend_util.BackendValues|null, shape: number[],
dtype: DataType): DataId {
const dataId = {id: this.dataIdNextNumber++};
this.move(dataId, values, shape, dtype, 1);
Expand All @@ -78,7 +79,7 @@ export class BackendWasm extends KernelBackend {
}

override move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dataId: DataId, values: backend_util.BackendValues|null, shape: number[],
dtype: DataType, refCount: number): void {
const id = this.dataIdNextNumber++;
if (dtype === 'string') {
Expand Down Expand Up @@ -196,11 +197,12 @@ export class BackendWasm extends KernelBackend {
* is present, the memory was allocated elsewhere (in c++) and we just record
* the pointer where that memory lives.
*/
makeOutput(shape: number[], dtype: DataType, memoryOffset?: number):
TensorInfo {
makeOutput(
shape: number[], dtype: DataType, memoryOffset?: number,
values?: backend_util.BackendValues): TensorInfo {
let dataId: {};
if (memoryOffset == null) {
dataId = this.write(null /* values */, shape, dtype);
dataId = this.write(values ?? null, shape, dtype);
} else {
const id = this.dataIdNextNumber++;
dataId = {id};
Expand Down Expand Up @@ -366,29 +368,31 @@ export async function init(): Promise<{wasm: BackendWasmModule}> {
// failed fetch, result in this promise being rejected. These are
// caught and re-rejected below.
wasm.then((module) => {
initialized = true;
initAborted = false;

const voidReturnType: string = null;
// Using the tfjs namespace to avoid conflict with emscripten's API.
module.tfjs = {
init: module.cwrap('init', null, []),
initWithThreadsCount:
module.cwrap('init_with_threads_count', null, ['number']),
getThreadsCount: module.cwrap('get_threads_count', 'number', []),
registerTensor: module.cwrap(
'register_tensor', null,
[
'number', // id
'number', // size
'number', // memoryOffset
]),
disposeData: module.cwrap('dispose_data', voidReturnType, ['number']),
dispose: module.cwrap('dispose', voidReturnType, []),
};

resolve({wasm: module});
}).catch(reject);
initialized = true;
initAborted = false;

const voidReturnType: string = null;
// Using the tfjs namespace to avoid conflict with emscripten's API.
module.tfjs = {
init: module.cwrap('init', null, []),
initWithThreadsCount:
module.cwrap('init_with_threads_count', null, ['number']),
getThreadsCount: module.cwrap('get_threads_count', 'number', []),
registerTensor: module.cwrap(
'register_tensor', null,
[
'number', // id
'number', // size
'number', // memoryOffset
]),
disposeData:
module.cwrap('dispose_data', voidReturnType, ['number']),
dispose: module.cwrap('dispose', voidReturnType, []),
};

resolve({wasm: module});
})
.catch(reject);
});
}

Expand Down
43 changes: 43 additions & 0 deletions tfjs-backend-wasm/src/kernels/BroadcastArgs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {backend_util, BroadcastArgs, BroadcastArgsInputs, KernelConfig, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

export function broadcastArgs(args: {
inputs: BroadcastArgsInputs,
backend: BackendWasm,
}): TensorInfo {
const {inputs, backend} = args;
const {s0, s1} = inputs;

const s0Vals = backend.typedArrayFromHeap(s0);
const s1Vals = backend.typedArrayFromHeap(s1);

const broadcastShape = backend_util.assertAndGetBroadcastShape(
Array.from(s0Vals), Array.from(s1Vals));

return backend.makeOutput(
[broadcastShape.length], 'int32', /*memoryOffset=*/undefined,
/*values=*/new Int32Array(broadcastShape));
}

export const broadcastArgsConfig: KernelConfig = {
kernelName: BroadcastArgs,
backendName: 'wasm',
kernelFunc: broadcastArgs
};
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {atanConfig} from './kernels/Atan';
import {avgPoolConfig} from './kernels/AvgPool';
import {batchMatMulConfig} from './kernels/BatchMatMul';
import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND';
import {broadcastArgsConfig} from './kernels/BroadcastArgs';
import {castConfig} from './kernels/Cast';
import {ceilConfig} from './kernels/Ceil';
import {clipByValueConfig} from './kernels/ClipByValue';
Expand Down Expand Up @@ -155,6 +156,7 @@ const kernelConfigs: KernelConfig[] = [
avgPoolConfig,
batchMatMulConfig,
batchToSpaceNDConfig,
broadcastArgsConfig,
castConfig,
ceilConfig,
clipByValueConfig,
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ const TEST_FILTERS: TestFilter[] = [
{include: 'asinh '},
{include: 'diag '},
{include: 'denseBincount '},
{include: 'broadcastArgs '},
];

const customInclude = (testName: string) => {
Expand Down

0 comments on commit ceb69e6

Please sign in to comment.