diff --git a/tfjs-backend-wasm/src/cc/kernels/AvgPool.cc b/tfjs-backend-wasm/src/cc/kernels/AvgPool.cc index 87968c79dfe..477f55d98c5 100644 --- a/tfjs-backend-wasm/src/cc/kernels/AvgPool.cc +++ b/tfjs-backend-wasm/src/cc/kernels/AvgPool.cc @@ -12,9 +12,11 @@ * limitations under the License. * ===========================================================================*/ +#include #ifdef __EMSCRIPTEN__ #include #endif +#include #include #include @@ -54,6 +56,14 @@ void AvgPool(const size_t x_id, const size_t batch_size, const float* x_buf = reinterpret_cast(x_info.memory_offset); float* out_buf = reinterpret_cast(out_info.memory_offset); + // XNNPack does not support 1x1 filters for AvgPool + if (filter_width == 1 && filter_height == 1) { + tfjs::util::identity_pool(x_id, x_buf, out_buf, out_info.size, batch_size, + input_height, input_width, stride_height, + stride_width, channels); + return; + } + xnn_operator_t avg_pool_op = nullptr; const uint32_t flags = 0; diff --git a/tfjs-backend-wasm/src/cc/kernels/MaxPool.cc b/tfjs-backend-wasm/src/cc/kernels/MaxPool.cc index 23cad472ddf..841f273a09b 100644 --- a/tfjs-backend-wasm/src/cc/kernels/MaxPool.cc +++ b/tfjs-backend-wasm/src/cc/kernels/MaxPool.cc @@ -55,6 +55,14 @@ void MaxPool(const size_t x_id, const size_t batch_size, const float* x_buf = reinterpret_cast(x_info.memory_offset); float* out_buf = reinterpret_cast(out_info.memory_offset); + // XNNPack does not support 1x1 filters for MaxPool + if (filter_width == 1 && filter_height == 1) { + tfjs::util::identity_pool(x_id, x_buf, out_buf, out_info.size, batch_size, + input_height, input_width, stride_height, + stride_width, input_channels); + return; + } + xnn_operator_t max_pool_op = nullptr; const uint32_t flags = 0; diff --git a/tfjs-backend-wasm/src/cc/util.cc b/tfjs-backend-wasm/src/cc/util.cc index 4816a71882a..4eec35e58ea 100644 --- a/tfjs-backend-wasm/src/cc/util.cc +++ b/tfjs-backend-wasm/src/cc/util.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include "tfjs-backend-wasm/src/cc/util.h" @@ -87,5 +88,53 @@ const std::vector get_broadcast_dims( return dims; } +const void identity_pool(const size_t x_id, const float* x_buf, float* out_buf, + const size_t out_size, const size_t batch_size, + const size_t input_height, const size_t input_width, + const size_t stride_height, const size_t stride_width, + const size_t channels) { + // Early bailout for the identity case to use memcpy for efficiency. + if (stride_width == 1 && stride_height == 1) { + std::memcpy(out_buf, x_buf, out_size * sizeof(*out_buf)); + return; + } + + // Values per row and column are determined by the stride size. + // ceil(input_height / stride_height) instead of floor because strides do + // not guarantee that more than one value is available. + // e.g. a stride of 3 would 'partition' range(1, 10) into + // [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]] + // and would include 10 in the output: [1, 4, 7, 10] + size_t vals_per_col = (input_height + stride_height - 1) / stride_height; + size_t vals_per_row = (input_width + stride_width - 1) / stride_width; + + size_t x_batch_vals_count = input_width * input_height; + size_t out_batch_vals_count = vals_per_row * vals_per_col; + + // Copy values specified by the strides. + // Only NHWC is currently supported. + for (size_t n = 0; n < batch_size; n++) { + for (size_t h = 0; h < vals_per_col; h++) { + for (size_t w = 0; w < vals_per_row; w++) { + for (size_t c = 0; c < channels; c++) { + size_t x_n_index = n * x_batch_vals_count; + size_t x_hw_index = h * stride_height * input_width + + w * stride_width; + size_t x_nhw_index = x_n_index + x_hw_index; + size_t x_nhwc_index = c + channels * x_nhw_index; + + size_t out_n_index = n * out_batch_vals_count; + size_t out_hw_index = h * vals_per_row + w; + size_t out_nhw_index = out_n_index + out_hw_index; + size_t out_nhwc_index = c + channels * out_nhw_index; + + out_buf[out_nhwc_index] = x_buf[x_nhwc_index]; + } + } + } + } + return; +} + } // namespace util } // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/util.h b/tfjs-backend-wasm/src/cc/util.h index beb118ea58a..34c6bc4b8cc 100644 --- a/tfjs-backend-wasm/src/cc/util.h +++ b/tfjs-backend-wasm/src/cc/util.h @@ -152,6 +152,18 @@ const std::vector assert_and_get_broadcast_shape( const std::vector get_broadcast_dims( const std::vector in_shape, const std::vector out_shape); +// Generates the output for AvgPool, MaxPool, etc where xnnpack does not support +// a 1x1 filter. Applies batching, channels, and strides. +// TODONT(mattsoulanille): Padding support is not necessary because it is +// meaningless for a 1x1 kernel. It would be undefined for regions where the +// kernel does not overlap the input tensor. +// https://www.tensorflow.org/api_docs/python/tf/nn#difference_between_convolution_and_pooling_layers_2 +const void identity_pool(const size_t x_id, const float* x_buf, float* out_buf, + const size_t out_size, const size_t batch_size, + const size_t input_height, const size_t input_width, + const size_t stride_height, const size_t stride_width, + const size_t channels); + } // namespace util } // namespace tfjs #endif // UTIL_H_ diff --git a/tfjs-core/src/ops/avg_pool.ts b/tfjs-core/src/ops/avg_pool.ts index f29853c63a8..06cd89d3496 100644 --- a/tfjs-core/src/ops/avg_pool.ts +++ b/tfjs-core/src/ops/avg_pool.ts @@ -48,6 +48,8 @@ import {reshape} from './reshape'; * https://www.tensorflow.org/api_docs/python/tf/nn/convolution) * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is * provided, it will default to truncate. + * + * @doc {heading: 'Operations', subheading: 'Convolution'} */ function avgPool_( x: T|TensorLike, filterSize: [number, number]|number, diff --git a/tfjs-core/src/ops/avg_pool_test.ts b/tfjs-core/src/ops/avg_pool_test.ts index aea4b3590d0..2f9a19181d0 100644 --- a/tfjs-core/src/ops/avg_pool_test.ts +++ b/tfjs-core/src/ops/avg_pool_test.ts @@ -18,6 +18,7 @@ import * as tf from '../index'; import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose} from '../test_util'; +import {identityPoolTest} from './identity_pool_test'; describeWithFlags('avgPool', ALL_ENVS, () => { it('x=[1,1,1] f=[1,1] s=1 [0] => [0]', async () => { @@ -271,4 +272,6 @@ describeWithFlags('avgPool', ALL_ENVS, () => { const result = tf.avgPool(a, 1, 1, 0); expectArraysClose(await result.data(), [0]); }); + + identityPoolTest(tf.avgPool); }); diff --git a/tfjs-core/src/ops/identity_pool_test.ts b/tfjs-core/src/ops/identity_pool_test.ts new file mode 100644 index 00000000000..8611e8bc00b --- /dev/null +++ b/tfjs-core/src/ops/identity_pool_test.ts @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {expectArraysClose} from '../test_util'; + +/** + * Test utility for testing AvgPool, MaxPool, etc where kernel size is 1x1, + * effectively making them act as the identity function except where strides + * affect the output. + */ +export function identityPoolTest(pool: typeof tf.avgPool) { + it('1x1 pool size (identity)', async () => { + // tslint:disable-next-line: no-unnecessary-type-assertion + const a = tf.range(0, 10).reshape([1, 1, 1, 10]) as tf.Tensor4D; + const result = pool(a, [1, 1], [1, 1], 'valid'); + expectArraysClose(await result.data(), await a.data()); + }); + + it('1x1 pool size with strides', async () => { + // tslint:disable-next-line: no-unnecessary-type-assertion + const a = tf.range(0, 150).reshape([1, 10, 15, 1]) as tf.Tensor4D; + const result = pool(a, [1, 1], [3, 4], 'valid'); + expectArraysClose(await result.data(), [ + 0, 4, 8, 12, + 45, 49, 53, 57, + 90, 94, 98, 102, + 135, 139, 143, 147, + ]); + }); + + it('1x1 pool size batched', async () => { + // 7 batches of 3 x 4 + const shape = [7, 3, 4, 1]; + const size = shape.reduce((a, b) => a * b, 1); + // tslint:disable-next-line: no-unnecessary-type-assertion + const a = tf.range(0, size).reshape(shape) as tf.Tensor4D; + const result = pool(a, [1, 1], [1, 1], 'valid'); + expectArraysClose(await result.data(), await a.data()); + }); + + it('1x1 pool size batched with strides', async () => { + // tslint:disable-next-line: no-unnecessary-type-assertion + const a = tf.range(0, 300).reshape([2, 10, 15, 1]) as tf.Tensor4D; + const result = pool(a, [1, 1], [3, 4], 'valid'); + expectArraysClose(await result.data(), [ + // Batch 0 + 0, 4, 8, 12, + 45, 49, 53, 57, + 90, 94, 98, 102, + 135, 139, 143, 147, + // Batch 1 + 150, 154, 158, 162, + 195, 199, 203, 207, + 240, 244, 248, 252, + 285, 289, 293, 297, + ]); + }); + + it('1x1 pool size batched with strides and channels', async () => { + // tslint:disable-next-line: no-unnecessary-type-assertion + const a = tf.range(0, 900).reshape([2, 10, 15, 3]) as tf.Tensor4D; + const result = pool(a, [1, 1], [3, 4], 'valid'); + expectArraysClose(await result.data(), [ + // Batch 0 + 0, 1, 2, 12, 13, 14, 24, 25, 26, 36, 37, 38, + 135, 136, 137, 147, 148, 149, 159, 160, 161, 171, 172, 173, + 270, 271, 272, 282, 283, 284, 294, 295, 296, 306, 307, 308, + 405, 406, 407, 417, 418, 419, 429, 430, 431, 441, 442, 443, + // Batch 1 + 450, 451, 452, 462, 463, 464, 474, 475, 476, 486, 487, 488, + 585, 586, 587, 597, 598, 599, 609, 610, 611, 621, 622, 623, + 720, 721, 722, 732, 733, 734, 744, 745, 746, 756, 757, 758, + 855, 856, 857, 867, 868, 869, 879, 880, 881, 891, 892, 893, + ]); + }); +} diff --git a/tfjs-core/src/ops/max_pool_test.ts b/tfjs-core/src/ops/max_pool_test.ts index 2827f4211f5..ba982369865 100644 --- a/tfjs-core/src/ops/max_pool_test.ts +++ b/tfjs-core/src/ops/max_pool_test.ts @@ -18,14 +18,15 @@ import * as tf from '../index'; import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose} from '../test_util'; +import { identityPoolTest } from './identity_pool_test'; describeWithFlags('maxPool', ALL_ENVS, () => { it('x=[1,1,1] f=[1,1] s=1 [0] => [0]', async () => { - const x = tf.tensor3d([0], [1, 1, 1]); + const x = tf.tensor3d([123], [1, 1, 1]); const result = tf.maxPool(x, 1, 1, 0); - expectArraysClose(await result.data(), [0]); + expectArraysClose(await result.data(), [123]); }); it('x=[3,3,1] f=[2,2] s=1, p=0', async () => { @@ -208,6 +209,8 @@ describeWithFlags('maxPool', ALL_ENVS, () => { const result = tf.maxPool(x, 1, 1, 0); expectArraysClose(await result.data(), [0]); }); + + identityPoolTest(tf.maxPool); }); describeWithFlags('maxPoolBackprop', ALL_ENVS, () => {