diff --git a/tfjs-backend-wasm/src/cc/BUILD.bazel b/tfjs-backend-wasm/src/cc/BUILD.bazel index d19d41b7fa9..9cf54ca5ff8 100644 --- a/tfjs-backend-wasm/src/cc/BUILD.bazel +++ b/tfjs-backend-wasm/src/cc/BUILD.bazel @@ -245,6 +245,11 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "pool3d_impl", + hdrs = ["pool3d_impl.h"], +) + tfjs_cc_library( name = "prelu_impl", srcs = ["prelu_impl.cc"], @@ -302,6 +307,7 @@ tfjs_cc_library( ":ArgMax", ":Atan", ":AvgPool", + ":AvgPool3D", ":BatchMatMul", ":Ceil", ":ClipByValue", @@ -332,6 +338,7 @@ tfjs_cc_library( ":LessEqual", ":Max", ":MaxPool", + ":MaxPool3D", ":Maximum", ":Min", ":Minimum", @@ -491,6 +498,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "AvgPool3D", + srcs = ["kernels/AvgPool3D.cc"], + deps = [ + ":backend", + ":pool3d_impl", + ], +) + tfjs_unit_test( name = "AvgPool_test", srcs = ["kernels/AvgPool_test.cc"], @@ -950,6 +966,15 @@ tfjs_unit_test( ], ) +tfjs_cc_library( + name = "MaxPool3D", + srcs = ["kernels/MaxPool3D.cc"], + deps = [ + ":backend", + ":pool3d_impl", + ], +) + tfjs_cc_library( name = "Min", srcs = ["kernels/Min.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/AvgPool3D.cc b/tfjs-backend-wasm/src/cc/kernels/AvgPool3D.cc new file mode 100644 index 00000000000..44e35047fe7 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/AvgPool3D.cc @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "tfjs-backend-wasm/src/cc/backend.h" +#include "tfjs-backend-wasm/src/cc/pool3d_impl.h" + +namespace tfjs::wasm { + +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +// REQUIRES: +// - Tensor `x` and `out` must have dtype float32 (checked in tfjs-core) +// - Tensor `x` and `out` must have data format 'NDHWC' (checked in tfjs-core) +void AvgPool3D(int x_id, int out_id, int batch_size, int channel_size, + int in_depth, int in_height, int in_width, int out_depth, + int out_height, int out_width, int stride_depth, + int stride_height, int stride_width, int dilation_depth, + int dilation_height, int dilation_width, + int effective_filter_depth, int effective_filter_height, + int effective_filter_width, int pad_front, int pad_top, + int pad_left) { + const TensorInfo& x_info = backend::get_tensor_info(x_id); + TensorInfo& out_info = backend::get_tensor_info_out(out_id); + + NDHWCPool3DImpl(x_info.f32(), out_info.f32_write(), + NDHWCPool3DInfo{ + .batch_size = batch_size, + .channel_size = channel_size, + .in_depth = in_depth, + .in_height = in_height, + .in_width = in_width, + .out_depth = out_depth, + .out_height = out_height, + .out_width = out_width, + .stride_depth = stride_depth, + .stride_height = stride_height, + .stride_width = stride_width, + .dilation_depth = dilation_depth, + .dilation_height = dilation_height, + .dilation_width = dilation_width, + .effective_filter_depth = effective_filter_depth, + .effective_filter_height = effective_filter_height, + .effective_filter_width = effective_filter_width, + .pad_front = pad_front, + .pad_top = pad_top, + .pad_left = pad_left, + }, + /*filter_init=*/ + []() -> std::pair { + return {0.0, 0}; + }, + /*filter_apply=*/ + [](std::pair& data, const float& val) { + data.first += val; + 
++data.second; + }, + /*filter_aggregate=*/ + [](const std::pair& data) { + return data.first / + static_cast(std::max(data.second, 1)); + }); +} + +} // extern "C" +} // namespace tfjs::wasm diff --git a/tfjs-backend-wasm/src/cc/kernels/MaxPool3D.cc b/tfjs-backend-wasm/src/cc/kernels/MaxPool3D.cc new file mode 100644 index 00000000000..49c841012fe --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/MaxPool3D.cc @@ -0,0 +1,84 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "tfjs-backend-wasm/src/cc/backend.h" +#include "tfjs-backend-wasm/src/cc/pool3d_impl.h" + +namespace tfjs::wasm { + +// We use C-style API to interface with Javascript. 
+extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +// REQUIRES: +// - Tensor `x` and `out` must have dtype float32 (checked in tfjs-core) +// - Tensor `x` and `out` must have data format 'NDHWC' (checked in tfjs-core) +void MaxPool3D(int x_id, int out_id, int batch_size, int channel_size, + int in_depth, int in_height, int in_width, int out_depth, + int out_height, int out_width, int stride_depth, + int stride_height, int stride_width, int dilation_depth, + int dilation_height, int dilation_width, + int effective_filter_depth, int effective_filter_height, + int effective_filter_width, int pad_front, int pad_top, + int pad_left) { + const TensorInfo& x_info = backend::get_tensor_info(x_id); + TensorInfo& out_info = backend::get_tensor_info_out(out_id); + + NDHWCPool3DImpl( + x_info.f32(), out_info.f32_write(), + NDHWCPool3DInfo{ + .batch_size = batch_size, + .channel_size = channel_size, + .in_depth = in_depth, + .in_height = in_height, + .in_width = in_width, + .out_depth = out_depth, + .out_height = out_height, + .out_width = out_width, + .stride_depth = stride_depth, + .stride_height = stride_height, + .stride_width = stride_width, + .dilation_depth = dilation_depth, + .dilation_height = dilation_height, + .dilation_width = dilation_width, + .effective_filter_depth = effective_filter_depth, + .effective_filter_height = effective_filter_height, + .effective_filter_width = effective_filter_width, + .pad_front = pad_front, + .pad_top = pad_top, + .pad_left = pad_left, + }, + /*filter_init=*/ + []() -> float { return std::numeric_limits::min(); }, + /*filter_apply=*/ + [](float& data, const float& val) { data = std::max(data, val); }, + /*filter_aggregate=*/ + [](const float& data) { return data; }); +} + +} // extern "C" +} // namespace tfjs::wasm diff --git a/tfjs-backend-wasm/src/cc/pool3d_impl.h b/tfjs-backend-wasm/src/cc/pool3d_impl.h new file mode 100644 index 00000000000..143574675c2 --- /dev/null +++ 
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

#ifndef POOL3D_IMPL_H_
#define POOL3D_IMPL_H_

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <algorithm>

namespace tfjs::wasm {

// Returns `v` unchanged when it is already non-negative; otherwise returns the
// smallest non-negative value congruent to `v` modulo `d`, i.e. the result of
// repeatedly adding `d` to `v` until it is >= 0.
//
// REQUIRES: d > 0 whenever v < 0.
//
// Deliberately a plain `inline` function rather than a member of an anonymous
// namespace: this is a header, and the external-linkage template
// NDHWCPool3DImpl below calls it, so it must name the same entity in every
// translation unit.
inline int AddUntilNonNegative(int v, int d) {
  if (v >= 0) {
    return v;
  }
  // `v % d` is in (-d, 0] for negative v; shift it into [0, d).
  return (v % d + d) % d;
}

// Geometry of one 3D pooling operation over NDHWC-ordered buffers.
struct NDHWCPool3DInfo {
  int batch_size;
  // Since Pool3D ops (AvgPool3D and MaxPool3D) support 3D filter only, in
  // channels should always equal to out channels.
  int channel_size;
  int in_depth;
  int in_height;
  int in_width;
  int out_depth;
  int out_height;
  int out_width;

  int stride_depth;
  int stride_height;
  int stride_width;
  int dilation_depth;
  int dilation_height;
  int dilation_width;
  int effective_filter_depth;
  int effective_filter_height;
  int effective_filter_width;
  int pad_front;
  int pad_top;
  int pad_left;

  // Flat index of element (b, d, h, w, c) in the input buffer.
  inline int in_offset(int b, int d, int h, int w, int c) const {
    return c +
           (w + (h + (d + b * in_depth) * in_height) * in_width) * channel_size;
  }
  // Flat index of element (b, d, h, w, c) in the output buffer.
  inline int out_offset(int b, int d, int h, int w, int c) const {
    return c + (w + (h + (d + b * out_depth) * out_height) * out_width) *
                   channel_size;
  }
  // Total number of elements in the input buffer.
  inline int in_size() const {
    return batch_size * in_depth * in_height * in_width * channel_size;
  }
  // Total number of elements in the output buffer.
  inline int out_size() const {
    return batch_size * out_depth * out_height * out_width * channel_size;
  }
};

// Generic 3D pooling driver.
//
// For every output element, the window's top-left-front corner is
// `y * stride - pad` along each axis. The window is then clipped to the input:
// the lower bound is advanced in steps of the dilation until it is
// non-negative, and the upper bound is capped at the input extent. Padding
// positions are therefore never read. The in-bounds elements are folded as:
//
//   auto acc = filter_init();
//   filter_apply(acc, x) for each in-bounds x, stepping by the dilation
//   out = filter_aggregate(acc);
//
// `x_buf` is indexed with `info.in_offset` and `out_buf` with
// `info.out_offset`.
template <typename IN, typename OUT, typename FI, typename FAP, typename FAG>
inline void NDHWCPool3DImpl(const IN* x_buf, OUT* out_buf,
                            const NDHWCPool3DInfo& info, const FI& filter_init,
                            const FAP& filter_apply,
                            const FAG& filter_aggregate) {
  for (int batch = 0; batch < info.batch_size; ++batch) {
    for (int channel = 0; channel < info.channel_size; ++channel) {
      for (int y_depth = 0; y_depth < info.out_depth; ++y_depth) {
        const int x_depth_corner =
            y_depth * info.stride_depth - info.pad_front;
        const int x_depth_min =
            AddUntilNonNegative(x_depth_corner, info.dilation_depth);
        const int x_depth_max = std::min(
            info.in_depth, info.effective_filter_depth + x_depth_corner);

        for (int y_row = 0; y_row < info.out_height; ++y_row) {
          const int x_row_corner = y_row * info.stride_height - info.pad_top;
          const int x_row_min =
              AddUntilNonNegative(x_row_corner, info.dilation_height);
          const int x_row_max = std::min(
              info.in_height, info.effective_filter_height + x_row_corner);

          for (int y_col = 0; y_col < info.out_width; ++y_col) {
            const int x_col_corner = y_col * info.stride_width - info.pad_left;
            const int x_col_min =
                AddUntilNonNegative(x_col_corner, info.dilation_width);
            const int x_col_max = std::min(
                info.in_width, info.effective_filter_width + x_col_corner);

            // Apply the filter
            auto filter_data = filter_init();
            for (int x_depth = x_depth_min; x_depth < x_depth_max;
                 x_depth += info.dilation_depth) {
              for (int x_row = x_row_min; x_row < x_row_max;
                   x_row += info.dilation_height) {
                for (int x_col = x_col_min; x_col < x_col_max;
                     x_col += info.dilation_width) {
                  const int x_offset =
                      info.in_offset(batch, x_depth, x_row, x_col, channel);
                  filter_apply(filter_data, x_buf[x_offset]);
                }
              }
            }
            const int out_offset =
                info.out_offset(batch, y_depth, y_row, y_col, channel);
            out_buf[out_offset] = filter_aggregate(filter_data);
          }
        }
      }
    }
  }
}

}  // namespace tfjs::wasm

#endif  // POOL3D_IMPL_H_
import {AvgPool3D, AvgPool3DAttrs, AvgPool3DInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

/** Call signature of the `AvgPool3D` function exported by the WASM module. */
type WasmAvgPool3DFn =
    (xId: number, outId: number, batchSize: number, channelSize: number,
     inDepth: number, inHeight: number, inWidth: number, outDepth: number,
     outHeight: number, outWidth: number, strideDepth: number,
     strideHeight: number, strideWidth: number, dilationDepth: number,
     dilationHeight: number, dilationWidth: number,
     effectiveFilterDepth: number, effectiveFilterHeight: number,
     effectiveFilterWidth: number, padFront: number, padTop: number,
     padLeft: number) => void;

// Bound by `setup` before the kernel runs.
let wasmAvgPool3D: WasmAvgPool3DFn;

/** Binds the native `AvgPool3D` export; every argument is a plain number. */
function setup(backend: BackendWasm) {
  const argTypes = [
    'number',  // xId
    'number',  // outId
    'number',  // batchSize
    'number',  // channelSize
    'number',  // inDepth
    'number',  // inHeight
    'number',  // inWidth
    'number',  // outDepth
    'number',  // outHeight
    'number',  // outWidth
    'number',  // strideDepth
    'number',  // strideHeight
    'number',  // strideWidth
    'number',  // dilationDepth
    'number',  // dilationHeight
    'number',  // dilationWidth
    'number',  // effectiveFilterDepth
    'number',  // effectiveFilterHeight
    'number',  // effectiveFilterWidth
    'number',  // padFront
    'number',  // padTop
    'number',  // padLeft
  ];
  wasmAvgPool3D = backend.wasm.cwrap('AvgPool3D', null, argTypes);
}

/**
 * AvgPool3D kernel: derives the pooling geometry with
 * `backend_util.computePool3DInfo` (dilations fixed to 1), allocates the
 * output tensor and hands both tensor ids plus the geometry to the WASM
 * implementation.
 */
export function avgPool3D(args: {
  inputs: AvgPool3DInputs,
  attrs: AvgPool3DAttrs,
  backend: BackendWasm,
}): TensorInfo {
  const {backend} = args;
  const {x} = args.inputs;
  const {filterSize, strides, pad, dimRoundingMode, dataFormat} = args.attrs;

  const poolInfo = backend_util.computePool3DInfo(
      x.shape as [number, number, number, number, number], filterSize, strides,
      /*dilations=*/1, pad, dimRoundingMode, dataFormat);
  const out = backend.makeOutput(poolInfo.outShape, x.dtype);

  const xId = backend.dataIdMap.get(x.dataId).id;
  const outId = backend.dataIdMap.get(out.dataId).id;
  // Since Pool3D ops (AvgPool3D and MaxPool3D) support 3D filter only, in
  // channels should always equal to out channels.
  const channelSize = poolInfo.inChannels;

  wasmAvgPool3D(
      xId,
      outId,
      poolInfo.batchSize,
      channelSize,
      poolInfo.inDepth,
      poolInfo.inHeight,
      poolInfo.inWidth,
      poolInfo.outDepth,
      poolInfo.outHeight,
      poolInfo.outWidth,
      poolInfo.strideDepth,
      poolInfo.strideHeight,
      poolInfo.strideWidth,
      poolInfo.dilationDepth,
      poolInfo.dilationHeight,
      poolInfo.dilationWidth,
      poolInfo.effectiveFilterDepth,
      poolInfo.effectiveFilterHeight,
      poolInfo.effectiveFilterWidth,
      poolInfo.padInfo.front,
      poolInfo.padInfo.top,
      poolInfo.padInfo.left,
  );
  return out;
}

export const avgPool3DConfig: KernelConfig = {
  kernelName: AvgPool3D,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: avgPool3D as unknown as KernelFunc
};
import {backend_util, KernelConfig, KernelFunc, MaxPool3D, MaxPool3DAttrs, MaxPool3DInputs, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

/** Call signature of the `MaxPool3D` function exported by the WASM module. */
type WasmMaxPool3DFn =
    (xId: number, outId: number, batchSize: number, channelSize: number,
     inDepth: number, inHeight: number, inWidth: number, outDepth: number,
     outHeight: number, outWidth: number, strideDepth: number,
     strideHeight: number, strideWidth: number, dilationDepth: number,
     dilationHeight: number, dilationWidth: number,
     effectiveFilterDepth: number, effectiveFilterHeight: number,
     effectiveFilterWidth: number, padFront: number, padTop: number,
     padLeft: number) => void;

// Bound by `setup` before the kernel runs.
let wasmMaxPool3D: WasmMaxPool3DFn;

/** Binds the native `MaxPool3D` export; every argument is a plain number. */
function setup(backend: BackendWasm) {
  const argTypes = [
    'number',  // xId
    'number',  // outId
    'number',  // batchSize
    'number',  // channelSize
    'number',  // inDepth
    'number',  // inHeight
    'number',  // inWidth
    'number',  // outDepth
    'number',  // outHeight
    'number',  // outWidth
    'number',  // strideDepth
    'number',  // strideHeight
    'number',  // strideWidth
    'number',  // dilationDepth
    'number',  // dilationHeight
    'number',  // dilationWidth
    'number',  // effectiveFilterDepth
    'number',  // effectiveFilterHeight
    'number',  // effectiveFilterWidth
    'number',  // padFront
    'number',  // padTop
    'number',  // padLeft
  ];
  wasmMaxPool3D = backend.wasm.cwrap('MaxPool3D', null, argTypes);
}

/**
 * MaxPool3D kernel: derives the pooling geometry with
 * `backend_util.computePool3DInfo` (dilations fixed to 1), allocates the
 * output tensor and hands both tensor ids plus the geometry to the WASM
 * implementation.
 */
export function maxPool3D(args: {
  inputs: MaxPool3DInputs,
  attrs: MaxPool3DAttrs,
  backend: BackendWasm,
}): TensorInfo {
  const {backend} = args;
  const {x} = args.inputs;
  const {filterSize, strides, pad, dimRoundingMode, dataFormat} = args.attrs;

  const poolInfo = backend_util.computePool3DInfo(
      x.shape as [number, number, number, number, number], filterSize, strides,
      /*dilations=*/1, pad, dimRoundingMode, dataFormat);
  const out = backend.makeOutput(poolInfo.outShape, x.dtype);

  const xId = backend.dataIdMap.get(x.dataId).id;
  const outId = backend.dataIdMap.get(out.dataId).id;
  // Since Pool3D ops (AvgPool3D and MaxPool3D) support 3D filter only, in
  // channels should always equal to out channels.
  const channelSize = poolInfo.inChannels;

  wasmMaxPool3D(
      xId,
      outId,
      poolInfo.batchSize,
      channelSize,
      poolInfo.inDepth,
      poolInfo.inHeight,
      poolInfo.inWidth,
      poolInfo.outDepth,
      poolInfo.outHeight,
      poolInfo.outWidth,
      poolInfo.strideDepth,
      poolInfo.strideHeight,
      poolInfo.strideWidth,
      poolInfo.dilationDepth,
      poolInfo.dilationHeight,
      poolInfo.dilationWidth,
      poolInfo.effectiveFilterDepth,
      poolInfo.effectiveFilterHeight,
      poolInfo.effectiveFilterWidth,
      poolInfo.padInfo.front,
      poolInfo.padInfo.top,
      poolInfo.padInfo.left,
  );
  return out;
}

export const maxPool3DConfig: KernelConfig = {
  kernelName: MaxPool3D,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: maxPool3D as unknown as KernelFunc
};
+156,7 @@ const kernelConfigs: KernelConfig[] = [ asinhConfig, atanConfig, avgPoolConfig, + avgPool3DConfig, batchMatMulConfig, batchToSpaceNDConfig, broadcastArgsConfig, @@ -200,6 +203,7 @@ const kernelConfigs: KernelConfig[] = [ maxConfig, maximumConfig, maxPoolConfig, + maxPool3DConfig, meanConfig, minConfig, minimumConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index b212dc1475e..65c74d85f09 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -66,8 +66,7 @@ const TEST_FILTERS: TestFilter[] = [ { include: 'avgPool', excludes: [ - 'gradient', // Not yet implemented. - 'avgPool3d', // Not yet implemented. + 'gradient', // Not yet implemented. ] }, { @@ -85,11 +84,9 @@ const TEST_FILTERS: TestFilter[] = [ include: 'maxPool', excludes: [ 'maxPoolBackprop', // Not yet implemented. - 'maxPool3d', // Not yet implemented. 'maxPool3dBackprop', // Not yet implemented. 'ignores NaNs', // Actual != expected. 'maxPoolWithArgmax' // Not yet implemented. - ] }, {include: 'cropAndResize'},