Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Add AvgPool3D and MaxPool3D kernels #7294

Merged
merged 11 commits into from
Jan 24, 2023
25 changes: 25 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ tfjs_cc_library(
],
)

# Header-only target exposing pool3d_impl.h (no srcs); the AvgPool3D and
# MaxPool3D kernel targets list it in their deps.
tfjs_cc_library(
name = "pool3d_impl",
hdrs = ["pool3d_impl.h"],
)

tfjs_cc_library(
name = "prelu_impl",
srcs = ["prelu_impl.cc"],
Expand Down Expand Up @@ -302,6 +307,7 @@ tfjs_cc_library(
":ArgMax",
":Atan",
":AvgPool",
":AvgPool3D",
":BatchMatMul",
":Ceil",
":ClipByValue",
Expand Down Expand Up @@ -332,6 +338,7 @@ tfjs_cc_library(
":LessEqual",
":Max",
":MaxPool",
":MaxPool3D",
":Maximum",
":Min",
":Minimum",
Expand Down Expand Up @@ -491,6 +498,15 @@ tfjs_cc_library(
],
)

# AvgPool3D kernel, compiled from kernels/AvgPool3D.cc. It depends on
# :backend and on the :pool3d_impl header target.
tfjs_cc_library(
name = "AvgPool3D",
srcs = ["kernels/AvgPool3D.cc"],
deps = [
":backend",
":pool3d_impl",
],
)

tfjs_unit_test(
name = "AvgPool_test",
srcs = ["kernels/AvgPool_test.cc"],
Expand Down Expand Up @@ -950,6 +966,15 @@ tfjs_unit_test(
],
)

# MaxPool3D kernel, compiled from kernels/MaxPool3D.cc. It depends on
# :backend and on the :pool3d_impl header target.
tfjs_cc_library(
name = "MaxPool3D",
srcs = ["kernels/MaxPool3D.cc"],
deps = [
":backend",
":pool3d_impl",
],
)

tfjs_cc_library(
name = "Min",
srcs = ["kernels/Min.cc"],
Expand Down
91 changes: 91 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/AvgPool3D.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <algorithm>
#include <utility>

#include "tfjs-backend-wasm/src/cc/backend.h"
#include "tfjs-backend-wasm/src/cc/pool3d_impl.h"

namespace tfjs::wasm {

// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

// REQUIRES:
// - Tensor `x` and `out` must have dtype float32 (checked in tfjs-core)
// - Tensor `x` and `out` must have data format 'NDHWC' (checked in tfjs-core)
void AvgPool3D(int x_id, int out_id, int batch_size, int channel_size,
int in_depth, int in_height, int in_width, int out_depth,
int out_height, int out_width, int stride_depth,
int stride_height, int stride_width, int dilation_depth,
int dilation_height, int dilation_width,
int effective_filter_depth, int effective_filter_height,
int effective_filter_width, int pad_front, int pad_top,
int pad_left) {
const TensorInfo& x_info = backend::get_tensor_info(x_id);
TensorInfo& out_info = backend::get_tensor_info_out(out_id);

NDHWCPool3DImpl(x_info.f32(), out_info.f32_write(),
NDHWCPool3DInfo{
.batch_size = batch_size,
.channel_size = channel_size,
.in_depth = in_depth,
.in_height = in_height,
.in_width = in_width,
.out_depth = out_depth,
.out_height = out_height,
.out_width = out_width,
.stride_depth = stride_depth,
.stride_height = stride_height,
.stride_width = stride_width,
.dilation_depth = dilation_depth,
.dilation_height = dilation_height,
.dilation_width = dilation_width,
.effective_filter_depth = effective_filter_depth,
.effective_filter_height = effective_filter_height,
.effective_filter_width = effective_filter_width,
.pad_front = pad_front,
.pad_top = pad_top,
.pad_left = pad_left,
},
/*filter_init=*/
[]() -> std::pair<float, int> {
return {0.0, 0};
},
/*filter_apply=*/
[](std::pair<float, int>& data, const float& val) {
data.first += val;
++data.second;
},
/*filter_aggregate=*/
[](const std::pair<float, int>& data) {
return data.first /
static_cast<float>(std::max(data.second, 1));
});
}

} // extern "C"
} // namespace tfjs::wasm
84 changes: 84 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/MaxPool3D.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <algorithm>
#include <limits>

#include "tfjs-backend-wasm/src/cc/backend.h"
#include "tfjs-backend-wasm/src/cc/pool3d_impl.h"

namespace tfjs::wasm {

// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

// 3D max pooling over the tensor registered as `x_id`, writing the result
// into the tensor registered as `out_id`.
//
// Each output element is the maximum of the input elements that the pooling
// window actually visits. The running maximum is seeded with negative
// infinity, so a window that visits no element produces -inf.
//
// REQUIRES:
// - Tensor `x` and `out` must have dtype float32 (checked in tfjs-core)
// - Tensor `x` and `out` must have data format 'NDHWC' (checked in tfjs-core)
void MaxPool3D(int x_id, int out_id, int batch_size, int channel_size,
               int in_depth, int in_height, int in_width, int out_depth,
               int out_height, int out_width, int stride_depth,
               int stride_height, int stride_width, int dilation_depth,
               int dilation_height, int dilation_width,
               int effective_filter_depth, int effective_filter_height,
               int effective_filter_width, int pad_front, int pad_top,
               int pad_left) {
  const TensorInfo& x_info = backend::get_tensor_info(x_id);
  TensorInfo& out_info = backend::get_tensor_info_out(out_id);

  NDHWCPool3DImpl(
      x_info.f32(), out_info.f32_write(),
      NDHWCPool3DInfo{
          .batch_size = batch_size,
          .channel_size = channel_size,
          .in_depth = in_depth,
          .in_height = in_height,
          .in_width = in_width,
          .out_depth = out_depth,
          .out_height = out_height,
          .out_width = out_width,
          .stride_depth = stride_depth,
          .stride_height = stride_height,
          .stride_width = stride_width,
          .dilation_depth = dilation_depth,
          .dilation_height = dilation_height,
          .dilation_width = dilation_width,
          .effective_filter_depth = effective_filter_depth,
          .effective_filter_height = effective_filter_height,
          .effective_filter_width = effective_filter_width,
          .pad_front = pad_front,
          .pad_top = pad_top,
          .pad_left = pad_left,
      },
      /*filter_init=*/
      // Seed with -inf, the identity element of max. Note that
      // std::numeric_limits<float>::min() is the smallest POSITIVE normal
      // float (~1.18e-38), so seeding with it would make every window whose
      // inputs are all <= 0 report ~1.18e-38 instead of its true maximum.
      []() -> float { return -std::numeric_limits<float>::infinity(); },
      /*filter_apply=*/
      [](float& data, const float& val) { data = std::max(data, val); },
      /*filter_aggregate=*/
      [](const float& data) { return data; });
}

} // extern "C"
} // namespace tfjs::wasm
128 changes: 128 additions & 0 deletions tfjs-backend-wasm/src/cc/pool3d_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <algorithm>
#include <cstddef>

namespace tfjs::wasm {

namespace {

// Returns `v` unchanged when it is already non-negative. Otherwise returns
// (v % d + d) % d, which for a positive step `d` is the smallest non-negative
// value reachable from `v` by repeatedly adding `d`. The pooling loop uses
// this to find the first in-bounds tap of a dilated window whose corner lies
// in the padding.
inline int AddUntilNonNegative(int v, int d) {
  return v >= 0 ? v : (v % d + d) % d;
}

} // namespace

// Plain aggregate describing one 3D pooling call: tensor extents, window
// stride/dilation/extent, and leading padding. The offset helpers index a
// dense buffer in (batch, depth, height, width, channel) order with channel
// varying fastest.
struct NDHWCPool3DInfo {
  int batch_size;
  // Pool3D ops (AvgPool3D and MaxPool3D) only support a 3D filter, so the
  // input and output channel counts are always equal.
  int channel_size;

  // Input extents.
  int in_depth;
  int in_height;
  int in_width;

  // Output extents.
  int out_depth;
  int out_height;
  int out_width;

  // Step of the window corner between neighbouring output elements.
  int stride_depth;
  int stride_height;
  int stride_width;

  // Step between consecutive taps inside one window.
  int dilation_depth;
  int dilation_height;
  int dilation_width;

  // Extent of the window measured from its corner, in input elements.
  int effective_filter_depth;
  int effective_filter_height;
  int effective_filter_width;

  // Amount subtracted from the window corner along depth/height/width.
  int pad_front;
  int pad_top;
  int pad_left;

  // Flat index of input element (b, d, h, w, c).
  int in_offset(int b, int d, int h, int w, int c) const {
    const int depth_index = d + b * in_depth;
    const int row_index = h + depth_index * in_height;
    const int col_index = w + row_index * in_width;
    return c + col_index * channel_size;
  }

  // Flat index of output element (b, d, h, w, c).
  int out_offset(int b, int d, int h, int w, int c) const {
    const int depth_index = d + b * out_depth;
    const int row_index = h + depth_index * out_height;
    const int col_index = w + row_index * out_width;
    return c + col_index * channel_size;
  }

  // Total number of input elements.
  int in_size() const {
    return batch_size * in_depth * in_height * in_width * channel_size;
  }

  // Total number of output elements.
  int out_size() const {
    return batch_size * out_depth * out_height * out_width * channel_size;
  }
};
// Generic 3D pooling loop shared by the Pool3D kernels.
//
// For every (batch, channel, depth, row, col) output position it:
//   1. creates a fresh accumulator with `filter_init()`;
//   2. calls `filter_apply(accumulator, x)` for each input element the window
//      touches, skipping taps that fall outside the input extents;
//   3. stores `filter_aggregate(accumulator)` at that output position.
//
// `x_buf` and `out_buf` are indexed through `info.in_offset` and
// `info.out_offset` respectively.
template <typename IN, typename OUT, typename FI, typename FAP, typename FAG>
inline void NDHWCPool3DImpl(const IN* x_buf, OUT* out_buf,
                            const NDHWCPool3DInfo& info, const FI& filter_init,
                            const FAP& filter_apply,
                            const FAG& filter_aggregate) {
  for (int b = 0; b < info.batch_size; ++b) {
    for (int c = 0; c < info.channel_size; ++c) {
      for (int out_d = 0; out_d < info.out_depth; ++out_d) {
        // Window corner along depth; negative when it starts in the padding.
        const int d_corner = out_d * info.stride_depth - info.pad_front;
        // First in-bounds tap that stays on the dilation grid.
        const int d_begin =
            AddUntilNonNegative(d_corner, info.dilation_depth);
        // One past the last tap, clipped to the input extent.
        const int d_end = std::min(info.in_depth,
                                   info.effective_filter_depth + d_corner);

        for (int out_h = 0; out_h < info.out_height; ++out_h) {
          const int h_corner = out_h * info.stride_height - info.pad_top;
          const int h_begin =
              AddUntilNonNegative(h_corner, info.dilation_height);
          const int h_end = std::min(
              info.in_height, info.effective_filter_height + h_corner);

          for (int out_w = 0; out_w < info.out_width; ++out_w) {
            const int w_corner = out_w * info.stride_width - info.pad_left;
            const int w_begin =
                AddUntilNonNegative(w_corner, info.dilation_width);
            const int w_end = std::min(
                info.in_width, info.effective_filter_width + w_corner);

            // Reduce every in-bounds tap of this window.
            auto accumulator = filter_init();
            for (int in_d = d_begin; in_d < d_end;
                 in_d += info.dilation_depth) {
              for (int in_h = h_begin; in_h < h_end;
                   in_h += info.dilation_height) {
                for (int in_w = w_begin; in_w < w_end;
                     in_w += info.dilation_width) {
                  const int src = info.in_offset(b, in_d, in_h, in_w, c);
                  filter_apply(accumulator, x_buf[src]);
                }
              }
            }

            const int dst = info.out_offset(b, out_d, out_h, out_w, c);
            out_buf[dst] = filter_aggregate(accumulator);
          }
        }
      }
    }
  }
}

} // namespace tfjs::wasm
Loading