[wasm] Fix AvgPool and MaxPool for 1x1 kernels (#6969)
XNNPack does not support 1x1 kernels for AvgPool or MaxPool. Implement these cases manually, including support for strides.
mattsoulanille authored Oct 24, 2022
1 parent 2941e79 commit bdf092e
Showing 8 changed files with 180 additions and 2 deletions.
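For orientation, here is a minimal TypeScript sketch (not part of the commit) of the user-facing case this change covers: a 1x1 pooling window combined with strides on the wasm backend, which XNNPack cannot handle directly, so the backend now falls back to the manual identity_pool path. The input values mirror the new identity_pool_test cases.

import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-wasm';

async function main() {
  await tf.setBackend('wasm');
  // 1x10x15x1 input, as in the '1x1 pool size with strides' test case.
  const x = tf.range(0, 150).reshape([1, 10, 15, 1]) as tf.Tensor4D;
  // A 1x1 window with strides [3, 4]: pooling degenerates to strided
  // subsampling of the input.
  const y = tf.avgPool(x, [1, 1], [3, 4], 'valid');
  console.log(await y.data());  // [0, 4, 8, 12, 45, 49, 53, 57, ...]
}

main();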
10 changes: 10 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/AvgPool.cc
@@ -12,9 +12,11 @@
* limitations under the License.
* ===========================================================================*/

#include <cstring>
#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif
#include <stdio.h>

#include <xnnpack.h>
#include <array>
@@ -54,6 +56,14 @@ void AvgPool(const size_t x_id, const size_t batch_size,
const float* x_buf = reinterpret_cast<float*>(x_info.memory_offset);
float* out_buf = reinterpret_cast<float*>(out_info.memory_offset);

// XNNPack does not support 1x1 filters for AvgPool
if (filter_width == 1 && filter_height == 1) {
tfjs::util::identity_pool(x_id, x_buf, out_buf, out_info.size, batch_size,
input_height, input_width, stride_height,
stride_width, channels);
return;
}

xnn_operator_t avg_pool_op = nullptr;

const uint32_t flags = 0;
8 changes: 8 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/MaxPool.cc
@@ -55,6 +55,14 @@ void MaxPool(const size_t x_id, const size_t batch_size,
const float* x_buf = reinterpret_cast<float*>(x_info.memory_offset);
float* out_buf = reinterpret_cast<float*>(out_info.memory_offset);

// XNNPack does not support 1x1 filters for MaxPool
if (filter_width == 1 && filter_height == 1) {
tfjs::util::identity_pool(x_id, x_buf, out_buf, out_info.size, batch_size,
input_height, input_width, stride_height,
stride_width, input_channels);
return;
}

xnn_operator_t max_pool_op = nullptr;

const uint32_t flags = 0;
49 changes: 49 additions & 0 deletions tfjs-backend-wasm/src/cc/util.cc
@@ -14,6 +14,7 @@

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

#include "tfjs-backend-wasm/src/cc/util.h"
@@ -87,5 +88,53 @@ const std::vector<size_t> get_broadcast_dims(
return dims;
}

const void identity_pool(const size_t x_id, const float* x_buf, float* out_buf,
const size_t out_size, const size_t batch_size,
const size_t input_height, const size_t input_width,
const size_t stride_height, const size_t stride_width,
const size_t channels) {
// Early bailout for the identity case to use memcpy for efficiency.
if (stride_width == 1 && stride_height == 1) {
std::memcpy(out_buf, x_buf, out_size * sizeof(*out_buf));
return;
}

// Values per row and column are determined by the stride size.
// ceil(input_height / stride_height) instead of floor because strides do
// not guarantee that more than one value is available.
// e.g. a stride of 3 would 'partition' the values 1 through 10 into
// [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]
// and would include 10 in the output: [1, 4, 7, 10]
size_t vals_per_col = (input_height + stride_height - 1) / stride_height;
size_t vals_per_row = (input_width + stride_width - 1) / stride_width;

size_t x_batch_vals_count = input_width * input_height;
size_t out_batch_vals_count = vals_per_row * vals_per_col;

// Copy values specified by the strides.
// Only NHWC is currently supported.
for (size_t n = 0; n < batch_size; n++) {
for (size_t h = 0; h < vals_per_col; h++) {
for (size_t w = 0; w < vals_per_row; w++) {
for (size_t c = 0; c < channels; c++) {
size_t x_n_index = n * x_batch_vals_count;
size_t x_hw_index = h * stride_height * input_width
+ w * stride_width;
size_t x_nhw_index = x_n_index + x_hw_index;
size_t x_nhwc_index = c + channels * x_nhw_index;

size_t out_n_index = n * out_batch_vals_count;
size_t out_hw_index = h * vals_per_row + w;
size_t out_nhw_index = out_n_index + out_hw_index;
size_t out_nhwc_index = c + channels * out_nhw_index;

out_buf[out_nhwc_index] = x_buf[x_nhwc_index];
}
}
}
}
return;
}

} // namespace util
} // namespace tfjs
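As a cross-check of the index arithmetic above, the same NHWC strided copy can be sketched in TypeScript over plain arrays (identityPoolJs is a hypothetical name for illustration only; the commit implements this in C++):

// Hypothetical TypeScript mirror of identity_pool's strided NHWC copy.
function identityPoolJs(
    x: Float32Array, batchSize: number, inputHeight: number,
    inputWidth: number, strideHeight: number, strideWidth: number,
    channels: number): Float32Array {
  // Same ceil division as the C++ code: a trailing partial window still
  // contributes one output value.
  const valsPerCol = Math.ceil(inputHeight / strideHeight);
  const valsPerRow = Math.ceil(inputWidth / strideWidth);
  const out = new Float32Array(batchSize * valsPerCol * valsPerRow * channels);
  for (let n = 0; n < batchSize; n++) {
    for (let h = 0; h < valsPerCol; h++) {
      for (let w = 0; w < valsPerRow; w++) {
        for (let c = 0; c < channels; c++) {
          const xNhw = n * inputHeight * inputWidth +
              h * strideHeight * inputWidth + w * strideWidth;
          const outNhw = n * valsPerCol * valsPerRow + h * valsPerRow + w;
          out[c + channels * outNhw] = x[c + channels * xNhw];
        }
      }
    }
  }
  return out;
}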
12 changes: 12 additions & 0 deletions tfjs-backend-wasm/src/cc/util.h
@@ -152,6 +152,18 @@ const std::vector<size_t> assert_and_get_broadcast_shape(
const std::vector<size_t> get_broadcast_dims(
const std::vector<size_t> in_shape, const std::vector<size_t> out_shape);

// Generates the output for AvgPool, MaxPool, etc., where XNNPack does not
// support a 1x1 filter. Applies batching, channels, and strides.
// TODONT(mattsoulanille): Padding support is not necessary because it is
// meaningless for a 1x1 kernel. It would be undefined for regions where the
// kernel does not overlap the input tensor.
// https://www.tensorflow.org/api_docs/python/tf/nn#difference_between_convolution_and_pooling_layers_2
const void identity_pool(const size_t x_id, const float* x_buf, float* out_buf,
const size_t out_size, const size_t batch_size,
const size_t input_height, const size_t input_width,
const size_t stride_height, const size_t stride_width,
const size_t channels);

} // namespace util
} // namespace tfjs
#endif // UTIL_H_
2 changes: 2 additions & 0 deletions tfjs-core/src/ops/avg_pool.ts
@@ -48,6 +48,8 @@ import {reshape} from './reshape';
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function avgPool_<T extends Tensor3D|Tensor4D>(
x: T|TensorLike, filterSize: [number, number]|number,
3 changes: 3 additions & 0 deletions tfjs-core/src/ops/avg_pool_test.ts
@@ -18,6 +18,7 @@
import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';
import {identityPoolTest} from './identity_pool_test';

describeWithFlags('avgPool', ALL_ENVS, () => {
it('x=[1,1,1] f=[1,1] s=1 [0] => [0]', async () => {
@@ -271,4 +272,6 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
const result = tf.avgPool(a, 1, 1, 0);
expectArraysClose(await result.data(), [0]);
});

identityPoolTest(tf.avgPool);
});
91 changes: 91 additions & 0 deletions tfjs-core/src/ops/identity_pool_test.ts
@@ -0,0 +1,91 @@
/**
* @license
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '../index';
import {expectArraysClose} from '../test_util';

/**
* Test utility for testing AvgPool, MaxPool, etc where kernel size is 1x1,
* effectively making them act as the identity function except where strides
* affect the output.
*/
export function identityPoolTest(pool: typeof tf.avgPool) {
it('1x1 pool size (identity)', async () => {
// tslint:disable-next-line: no-unnecessary-type-assertion
const a = tf.range(0, 10).reshape([1, 1, 1, 10]) as tf.Tensor4D;
const result = pool(a, [1, 1], [1, 1], 'valid');
expectArraysClose(await result.data(), await a.data());
});

it('1x1 pool size with strides', async () => {
// tslint:disable-next-line: no-unnecessary-type-assertion
const a = tf.range(0, 150).reshape([1, 10, 15, 1]) as tf.Tensor4D;
const result = pool(a, [1, 1], [3, 4], 'valid');
expectArraysClose(await result.data(), [
0, 4, 8, 12,
45, 49, 53, 57,
90, 94, 98, 102,
135, 139, 143, 147,
]);
});

it('1x1 pool size batched', async () => {
// 7 batches of 3 x 4
const shape = [7, 3, 4, 1];
const size = shape.reduce((a, b) => a * b, 1);
// tslint:disable-next-line: no-unnecessary-type-assertion
const a = tf.range(0, size).reshape(shape) as tf.Tensor4D;
const result = pool(a, [1, 1], [1, 1], 'valid');
expectArraysClose(await result.data(), await a.data());
});

it('1x1 pool size batched with strides', async () => {
// tslint:disable-next-line: no-unnecessary-type-assertion
const a = tf.range(0, 300).reshape([2, 10, 15, 1]) as tf.Tensor4D;
const result = pool(a, [1, 1], [3, 4], 'valid');
expectArraysClose(await result.data(), [
// Batch 0
0, 4, 8, 12,
45, 49, 53, 57,
90, 94, 98, 102,
135, 139, 143, 147,
// Batch 1
150, 154, 158, 162,
195, 199, 203, 207,
240, 244, 248, 252,
285, 289, 293, 297,
]);
});

it('1x1 pool size batched with strides and channels', async () => {
// tslint:disable-next-line: no-unnecessary-type-assertion
const a = tf.range(0, 900).reshape([2, 10, 15, 3]) as tf.Tensor4D;
const result = pool(a, [1, 1], [3, 4], 'valid');
expectArraysClose(await result.data(), [
// Batch 0
0, 1, 2, 12, 13, 14, 24, 25, 26, 36, 37, 38,
135, 136, 137, 147, 148, 149, 159, 160, 161, 171, 172, 173,
270, 271, 272, 282, 283, 284, 294, 295, 296, 306, 307, 308,
405, 406, 407, 417, 418, 419, 429, 430, 431, 441, 442, 443,
// Batch 1
450, 451, 452, 462, 463, 464, 474, 475, 476, 486, 487, 488,
585, 586, 587, 597, 598, 599, 609, 610, 611, 621, 622, 623,
720, 721, 722, 732, 733, 734, 744, 745, 746, 756, 757, 758,
855, 856, 857, 867, 868, 869, 879, 880, 881, 891, 892, 893,
]);
});
}
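The expected arrays in the strided cases above follow directly from that subsampling: the output element at (h, w) comes from input index h * strideHeight * inputWidth + w * strideWidth. A quick sketch (not part of the test file) that regenerates the expected values for the 1x10x15x1, strides [3, 4] case:

const expected: number[] = [];
for (let h = 0; h < Math.ceil(10 / 3); h++) {    // input rows 0, 3, 6, 9
  for (let w = 0; w < Math.ceil(15 / 4); w++) {  // input cols 0, 4, 8, 12
    expected.push(h * 3 * 15 + w * 4);
  }
}
// expected: [0, 4, 8, 12, 45, 49, 53, 57, 90, 94, 98, 102, 135, 139, 143, 147]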
7 changes: 5 additions & 2 deletions tfjs-core/src/ops/max_pool_test.ts
@@ -18,14 +18,15 @@
import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';
import {identityPoolTest} from './identity_pool_test';

describeWithFlags('maxPool', ALL_ENVS, () => {
it('x=[1,1,1] f=[1,1] s=1 [0] => [0]', async () => {
-  const x = tf.tensor3d([0], [1, 1, 1]);
+  const x = tf.tensor3d([123], [1, 1, 1]);

   const result = tf.maxPool(x, 1, 1, 0);

-  expectArraysClose(await result.data(), [0]);
+  expectArraysClose(await result.data(), [123]);
});

it('x=[3,3,1] f=[2,2] s=1, p=0', async () => {
@@ -208,6 +209,8 @@ describeWithFlags('maxPool', ALL_ENVS, () => {
const result = tf.maxPool(x, 1, 1, 0);
expectArraysClose(await result.data(), [0]);
});

identityPoolTest(tf.maxPool);
});

describeWithFlags('maxPoolBackprop', ALL_ENVS, () => {
