diff --git a/tfjs-backend-wasm/src/cc/BUILD.bazel b/tfjs-backend-wasm/src/cc/BUILD.bazel index 48b423f7754..2ab146dc8c8 100644 --- a/tfjs-backend-wasm/src/cc/BUILD.bazel +++ b/tfjs-backend-wasm/src/cc/BUILD.bazel @@ -661,6 +661,15 @@ tfjs_unit_test( ], ) +tfjs_cc_library( + name = "Diag", + srcs = ["kernels/Diag.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "RealDiv", srcs = ["kernels/RealDiv.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Diag.cc b/tfjs-backend-wasm/src/cc/kernels/Diag.cc new file mode 100644 index 00000000000..99cf94d0641 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Diag.cc @@ -0,0 +1,71 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "tfjs-backend-wasm/src/cc/backend.h" +#include "tfjs-backend-wasm/src/cc/util.h" + +namespace tfjs { +namespace wasm { + +namespace { + +template +inline void DiagImpl(const T* x_buf, int32_t x_size, T* out_buf) { + std::fill(out_buf, out_buf + x_size * x_size, 0); + for (int32_t i = 0; i < x_size; ++i) { + out_buf[x_size * i + i] = x_buf[i]; + } +} + +} // namespace + +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Diag(const int32_t x_id, const DType dtype, const int32_t x_size, + const int32_t out_id) { + const TensorInfo& x_info = backend::get_tensor_info(x_id); + TensorInfo& out_info = backend::get_tensor_info_out(out_id); + switch (dtype) { + case DType::float32: + DiagImpl(x_info.f32(), x_size, out_info.f32_write()); + break; + case DType::int32: + DiagImpl(x_info.i32(), x_size, out_info.i32_write()); + break; + case DType::boolean: + DiagImpl(x_info.b(), x_size, out_info.b_write()); + break; + default: + util::warn("Diag for tensor id %d failed. Unsupported dtype %d", x_id, + dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Diag.ts b/tfjs-backend-wasm/src/kernels/Diag.ts new file mode 100644 index 00000000000..c1266d1389a --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Diag.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Diag, DiagInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +import {CppDType} from './types'; + +let wasmDiag: (xId: number, xDType: CppDType, xSize: number, outId: number) => + void; + +function setup(backend: BackendWasm) { + wasmDiag = backend.wasm.cwrap('Diag', null, [ + 'number', // xId + 'number', // xDType, + 'number', // xSize, + 'number', // outId + ]); +} + +export function diag(args: {inputs: DiagInputs, backend: BackendWasm}): + TensorInfo { + const {inputs, backend} = args; + const {x} = inputs; + + const xSize = util.sizeFromShape(x.shape); + const out = backend.makeOutput([...x.shape, ...x.shape], x.dtype); + + wasmDiag( + backend.dataIdMap.get(x.dataId).id, CppDType[x.dtype], xSize, + backend.dataIdMap.get(out.dataId).id); + return out; +} + +export const diagConfig: KernelConfig = { + kernelName: Diag, + backendName: 'wasm', + setupFunc: setup, + kernelFunc: diag as unknown as KernelFunc +}; diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index cc519fce755..74ca2dd8e22 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -47,6 +47,7 @@ import {cumprodConfig} from './kernels/Cumprod'; import {cumsumConfig} from './kernels/Cumsum'; import {depthToSpaceConfig} from './kernels/DepthToSpace'; import {depthwiseConv2dNativeConfig} from './kernels/DepthwiseConv2dNative'; +import {diagConfig} from './kernels/Diag'; import {eluConfig} from './kernels/Elu'; import {equalConfig} from './kernels/Equal'; import {expConfig} from './kernels/Exp'; @@ -166,6 +167,7 @@ const kernelConfigs: KernelConfig[] = [ cumsumConfig, depthToSpaceConfig, depthwiseConv2dNativeConfig, + diagConfig, eluConfig, equalConfig, expConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 34e975ebea2..11cde009c6a 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -407,6 +407,7 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'acosh '}, {include: 'asin '}, {include: 'asinh '}, + {include: 'diag '}, ]; const customInclude = (testName: string) => {