From ceb69e631c2e220daa13513a030bfe3dcdba537c Mon Sep 17 00:00:00 2001 From: chunnienc <121328115+chunnienc@users.noreply.github.com> Date: Mon, 23 Jan 2023 10:56:36 -0800 Subject: [PATCH] [wasm] Add BroadcastArgs kernel (#7290) * Add DenseBincount kernel * Remove redundant checks * Reduce buf reset calls * Update to C++17 * Add BroadcastArgs kernel * Fix lint * Update tfjs-backend-wasm/src/backend_wasm.ts Co-authored-by: Matthew Soulanille Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com> Co-authored-by: Matthew Soulanille --- tfjs-backend-wasm/src/backend_wasm.ts | 72 ++++++++++--------- .../src/kernels/BroadcastArgs.ts | 43 +++++++++++ tfjs-backend-wasm/src/register_all_kernels.ts | 2 + tfjs-backend-wasm/src/setup_test.ts | 1 + 4 files changed, 84 insertions(+), 34 deletions(-) create mode 100644 tfjs-backend-wasm/src/kernels/BroadcastArgs.ts diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index f56bcbf03c5..9c16f688d37 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -20,7 +20,7 @@ import {backend_util, BackendTimingInfo, DataStorage, DataType, deprecationWarn, import {BackendWasmModule, WasmFactoryConfig} from '../wasm-out/tfjs-backend-wasm'; import {BackendWasmThreadedSimdModule} from '../wasm-out/tfjs-backend-wasm-threaded-simd'; -import * as wasmFactoryThreadedSimd_import from '../wasm-out/tfjs-backend-wasm-threaded-simd.js'; +import * as wasmFactoryThreadedSimd_import from '../wasm-out/tfjs-backend-wasm-threaded-simd.js'; // @ts-ignore import {wasmWorkerContents} from '../wasm-out/tfjs-backend-wasm-threaded-simd.worker.js'; import * as wasmFactory_import from '../wasm-out/tfjs-backend-wasm.js'; @@ -29,11 +29,11 @@ import * as wasmFactory_import from '../wasm-out/tfjs-backend-wasm.js'; // the node bundle (for testing). This would not be necessary if we // flipped esModuleInterop to true, but we likely can't do that since // google3 does not use it. -const wasmFactoryThreadedSimd = (wasmFactoryThreadedSimd_import.default - || wasmFactoryThreadedSimd_import) as -typeof wasmFactoryThreadedSimd_import.default; -const wasmFactory = (wasmFactory_import.default - || wasmFactory_import) as typeof wasmFactory_import.default; +const wasmFactoryThreadedSimd = (wasmFactoryThreadedSimd_import.default || + wasmFactoryThreadedSimd_import) as + typeof wasmFactoryThreadedSimd_import.default; +const wasmFactory = (wasmFactory_import.default || wasmFactory_import) as + typeof wasmFactory_import.default; interface TensorData { id: number; @@ -59,7 +59,8 @@ export class BackendWasm extends KernelBackend { this.dataIdMap = new DataStorage(this, engine()); } - override write(values: backend_util.BackendValues, shape: number[], + override write( + values: backend_util.BackendValues|null, shape: number[], dtype: DataType): DataId { const dataId = {id: this.dataIdNextNumber++}; this.move(dataId, values, shape, dtype, 1); @@ -78,7 +79,7 @@ export class BackendWasm extends KernelBackend { } override move( - dataId: DataId, values: backend_util.BackendValues, shape: number[], + dataId: DataId, values: backend_util.BackendValues|null, shape: number[], dtype: DataType, refCount: number): void { const id = this.dataIdNextNumber++; if (dtype === 'string') { @@ -196,11 +197,12 @@ export class BackendWasm extends KernelBackend { * is present, the memory was allocated elsewhere (in c++) and we just record * the pointer where that memory lives. */ - makeOutput(shape: number[], dtype: DataType, memoryOffset?: number): - TensorInfo { + makeOutput( + shape: number[], dtype: DataType, memoryOffset?: number, + values?: backend_util.BackendValues): TensorInfo { let dataId: {}; if (memoryOffset == null) { - dataId = this.write(null /* values */, shape, dtype); + dataId = this.write(values ?? null, shape, dtype); } else { const id = this.dataIdNextNumber++; dataId = {id}; @@ -366,29 +368,31 @@ export async function init(): Promise<{wasm: BackendWasmModule}> { // failed fetch, result in this promise being rejected. These are // caught and re-rejected below. wasm.then((module) => { - initialized = true; - initAborted = false; - - const voidReturnType: string = null; - // Using the tfjs namespace to avoid conflict with emscripten's API. - module.tfjs = { - init: module.cwrap('init', null, []), - initWithThreadsCount: - module.cwrap('init_with_threads_count', null, ['number']), - getThreadsCount: module.cwrap('get_threads_count', 'number', []), - registerTensor: module.cwrap( - 'register_tensor', null, - [ - 'number', // id - 'number', // size - 'number', // memoryOffset - ]), - disposeData: module.cwrap('dispose_data', voidReturnType, ['number']), - dispose: module.cwrap('dispose', voidReturnType, []), - }; - - resolve({wasm: module}); - }).catch(reject); + initialized = true; + initAborted = false; + + const voidReturnType: string = null; + // Using the tfjs namespace to avoid conflict with emscripten's API. + module.tfjs = { + init: module.cwrap('init', null, []), + initWithThreadsCount: + module.cwrap('init_with_threads_count', null, ['number']), + getThreadsCount: module.cwrap('get_threads_count', 'number', []), + registerTensor: module.cwrap( + 'register_tensor', null, + [ + 'number', // id + 'number', // size + 'number', // memoryOffset + ]), + disposeData: + module.cwrap('dispose_data', voidReturnType, ['number']), + dispose: module.cwrap('dispose', voidReturnType, []), + }; + + resolve({wasm: module}); + }) + .catch(reject); }); } diff --git a/tfjs-backend-wasm/src/kernels/BroadcastArgs.ts b/tfjs-backend-wasm/src/kernels/BroadcastArgs.ts new file mode 100644 index 00000000000..d98bb8713ca --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/BroadcastArgs.ts @@ -0,0 +1,43 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {backend_util, BroadcastArgs, BroadcastArgsInputs, KernelConfig, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +export function broadcastArgs(args: { + inputs: BroadcastArgsInputs, + backend: BackendWasm, +}): TensorInfo { + const {inputs, backend} = args; + const {s0, s1} = inputs; + + const s0Vals = backend.typedArrayFromHeap(s0); + const s1Vals = backend.typedArrayFromHeap(s1); + + const broadcastShape = backend_util.assertAndGetBroadcastShape( + Array.from(s0Vals), Array.from(s1Vals)); + + return backend.makeOutput( + [broadcastShape.length], 'int32', /*memoryOffset=*/undefined, + /*values=*/new Int32Array(broadcastShape)); +} + +export const broadcastArgsConfig: KernelConfig = { + kernelName: BroadcastArgs, + backendName: 'wasm', + kernelFunc: broadcastArgs +}; diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index d9017380aa2..2e6b3056f8c 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -34,6 +34,7 @@ import {atanConfig} from './kernels/Atan'; import {avgPoolConfig} from './kernels/AvgPool'; import {batchMatMulConfig} from './kernels/BatchMatMul'; import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND'; +import {broadcastArgsConfig} from './kernels/BroadcastArgs'; import {castConfig} from './kernels/Cast'; import {ceilConfig} from './kernels/Ceil'; import {clipByValueConfig} from './kernels/ClipByValue'; @@ -155,6 +156,7 @@ const kernelConfigs: KernelConfig[] = [ avgPoolConfig, batchMatMulConfig, batchToSpaceNDConfig, + broadcastArgsConfig, castConfig, ceilConfig, clipByValueConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 0726b2f2bb4..b212dc1475e 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -409,6 +409,7 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'asinh '}, {include: 'diag '}, {include: 'denseBincount '}, + {include: 'broadcastArgs '}, ]; const customInclude = (testName: string) => {