From 3f376994950f8a8979af41c9cc55bb1639af1a31 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 22 May 2018 07:18:44 +0000 Subject: [PATCH] add count_include_pad argument --- cpp-package/scripts/OpWrapperGenerator.py | 1 + .../scala/org/apache/mxnet/SymbolMacro.scala | 2 +- src/operator/nn/mkldnn/mkldnn_pooling.cc | 6 +- src/operator/nn/pool.cuh | 73 ++++++++++++++----- src/operator/nn/pool.h | 59 ++++++++++----- src/operator/nn/pooling-inl.h | 29 ++++++-- 6 files changed, 126 insertions(+), 44 deletions(-) diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py index 0c000d9955ff..8facde168408 100644 --- a/cpp-package/scripts/OpWrapperGenerator.py +++ b/cpp-package/scripts/OpWrapperGenerator.py @@ -77,6 +77,7 @@ def GetConvertEnumVariableToString(self, variable=''): class Arg: typeDict = {'boolean':'bool',\ + 'boolean or None':'dmlc::optional',\ 'Shape(tuple)':'Shape',\ 'Symbol':'Symbol',\ 'NDArray':'Symbol',\ diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 234a8604cb91..11ebc15af9d7 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -190,7 +190,7 @@ private[mxnet] object SymbolImplMacros { case "long" | "long(non-negative)" => "Long" case "double" | "doubleorNone" => "Double" case "string" => "String" - case "boolean" => "Boolean" + case "boolean" => "BooleanorNone" case "tupleof" | "tupleof" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default, $argType") diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index 259af2b94025..9fd88a13c465 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -121,7 +121,11 @@ mkldnn::algorithm 
GetMKLDNNPoolAlgo(const PoolingParam &param) { return mkldnn::algorithm::pooling_max; break; case pool_enum::kAvgPooling: - return mkldnn::algorithm::pooling_avg_include_padding; + if (param.count_include_pad.has_value() && !param.count_include_pad.value()) { + return mkldnn::algorithm::pooling_avg_exclude_padding; + } else { + return mkldnn::algorithm::pooling_avg_include_padding; + } break; default: LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method."; diff --git a/src/operator/nn/pool.cuh b/src/operator/nn/pool.cuh index 9d004d295bed..7d51f60b2032 100644 --- a/src/operator/nn/pool.cuh +++ b/src/operator/nn/pool.cuh @@ -214,16 +214,19 @@ template __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data, const int channels, const int width, const int pooled_width, const int kernel_w, const int stride_w, const int pad_w, DType* out_data, - const bool getAvg = false) { + const bool getAvg = false, const bool count_include_pad = true) { CUDA_KERNEL_LOOP(index, nthreads) { const int pw = index % pooled_width; const int c = (index / pooled_width) % channels; const int n = index / pooled_width / channels; int wstart = pw * stride_w - pad_w; int wend = min(wstart + kernel_w, width + pad_w); - const int pool_size = (getAvg? (wend - wstart) : 1); + int pool_size = (getAvg? 
(wend - wstart) : 1); wstart = max(wstart, 0); wend = min(wend, width); + if (getAvg && !count_include_pad) { + pool_size = (wend - wstart); + } DType sum = 0; const DType* out_slice = in_data + (n * channels + c) * width; for (int w = wstart; w < wend; ++w) { @@ -244,7 +247,8 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, DType* out_data, - const bool getAvg = false) { + const bool getAvg = false, + const bool count_include_pad = true) { CUDA_KERNEL_LOOP(index, nthreads) { const int pw = index % pooled_width; const int ph = (index / pooled_width) % pooled_height; @@ -254,11 +258,14 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data, int wstart = pw * stride_w - pad_w; int hend = min(hstart + kernel_h, height + pad_h); int wend = min(wstart + kernel_w, width + pad_w); - const int pool_size = (getAvg? (hend - hstart) * (wend - wstart) : 1); + int pool_size = (getAvg? 
(hend - hstart) * (wend - wstart) : 1); hstart = max(hstart, 0); wstart = max(wstart, 0); hend = min(hend, height); wend = min(wend, width); + if (getAvg && !count_include_pad) { + pool_size = (hend - hstart) * (wend - wstart); + } DType sum = 0; const DType* out_slice = in_data + (n * channels + c) * height * width; for (int h = hstart; h < hend; ++h) { @@ -282,7 +289,8 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, const int kernel_h, const int kernel_w, const int stride_d, const int stride_h, const int stride_w, const int pad_d, const int pad_h, const int pad_w, - DType* out_data, const bool getAvg = false) { + DType* out_data, const bool getAvg = false, + const bool count_include_pad = true) { CUDA_KERNEL_LOOP(index, nthreads) { const int pw = index % pooled_width; const int ph = (index / pooled_width) % pooled_height; @@ -295,13 +303,16 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, int dend = min(dstart + kernel_d, depth + pad_d); int hend = min(hstart + kernel_h, height + pad_h); int wend = min(wstart + kernel_w, width + pad_w); - const int pool_size = (getAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); + int pool_size = (getAvg? 
(dend - dstart) * (hend - hstart) * (wend - wstart) : 1); dstart = max(dstart, 0); hstart = max(hstart, 0); wstart = max(wstart, 0); dend = min(dend, depth); hend = min(hend, height); wend = min(wend, width); + if (getAvg && !count_include_pad) { + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } DType sum = 0; const DType* out_slice = in_data + (n * channels + c) * depth * height * width; for (int d = dstart; d < dend; ++d) { @@ -487,7 +498,8 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr const int channels, const int width, const int pooled_width, const int kernel_w, const int stride_w, const int pad_w, DType* in_grad, - const bool isAvg = false) { + const bool isAvg = false, + const bool count_include_pad = true) { // index is the input image index in NCW CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index @@ -507,6 +519,11 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr int wstart = pw * stride_w - pad_w; int wend = min(wstart + kernel_w, width + pad_w); int pool_size = (isAvg? (wend - wstart) : 1); + if (isAvg && !count_include_pad) { + wstart = max(wstart, 0); + wend = min(wend, width); + pool_size = (wend - wstart); + } gradient += lp_grad::Map(out_grad_slice[pw], in_data[index], out_data_slice[pw]) / pool_size; } @@ -528,7 +545,8 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, DType* in_grad, - const bool isAvg = false) { + const bool isAvg = false, + const bool count_include_pad = true) { // index is the input image index in NCHW CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index @@ -555,6 +573,13 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr int wend = min(wstart + kernel_w, width + pad_w); int pool_size = (isAvg? 
(hend - hstart) * (wend - wstart) : 1); int out_index = ph * pooled_width + pw; + if (isAvg && !count_include_pad) { + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + pool_size = (hend - hstart) * (wend - wstart); + } gradient += lp_grad::Map(out_grad_slice[out_index], in_data[index], @@ -580,7 +605,8 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr const int kernel_d, const int kernel_h, const int kernel_w, const int stride_d, const int stride_h, const int stride_w, const int pad_d, const int pad_h, - const int pad_w, DType* in_grad, const bool isAvg = false) { + const int pad_w, DType* in_grad, const bool isAvg = false, + const bool count_include_pad = true) { // index is the input image index in NCDHW CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index @@ -613,6 +639,15 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr int wend = min(wstart + kernel_w, width + pad_w); int pool_size = (isAvg? 
(dend - dstart) * (hend - hstart) * (wend - wstart) : 1); int out_index = (pd * pooled_height + ph) * pooled_width + pw; + if (isAvg && !count_include_pad) { + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dend = min(dend, depth); + hend = min(hend, height); + wend = min(wend, width); + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } gradient += lp_grad::Map(out_grad_slice[out_index], in_data[index], out_data_slice[out_index]) / pool_size; @@ -643,7 +678,7 @@ template inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, const int pool_type, OpReqType req_type, - DType* out_data) { + DType* out_data, const bool count_include_pad) { CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations"; using namespace mxnet_op; if (kernel.ndim() == 1) { @@ -659,7 +694,8 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is pool_sum_1d_gpu_kernel<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], oshape[2], - kernel[0], stride[0], pad[0], out_data, true); + kernel[0], stride[0], pad[0], out_data, + true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_1d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -693,7 +729,8 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is 0, mshadow::Stream::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], oshape[2], oshape[3], kernel[0], kernel[1], - stride[0], stride[1], pad[0], pad[1], out_data, true); + stride[0], stride[1], pad[0], pad[1], out_data, + true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_2d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -731,7 +768,7 @@ inline void pool(mshadow::Stream* s, const DType* 
in_data, const TShape& is oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], ishape[4], oshape[2], oshape[3], oshape[4], kernel[0], kernel[1], kernel[2], stride[0], stride[1], stride[2], - pad[0], pad[1], pad[2], out_data, true); + pad[0], pad[1], pad[2], out_data, true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_3d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -777,7 +814,8 @@ template inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - const int pool_type, OpReqType req_type, DType* in_grad) { + const int pool_type, OpReqType req_type, DType* in_grad, + const bool count_include_pad) { if (mxnet::kNullOp == req_type) return; if (mxnet::kAddTo != req_type) { mxnet_op::Kernel::Launch(s, ishape.Size(), in_grad); @@ -798,7 +836,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* 0, mshadow::Stream::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], oshape[2], kernel[0], - stride[0], pad[0], in_grad, true); + stride[0], pad[0], in_grad, true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -836,7 +874,8 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], oshape[2], oshape[3], kernel[0], kernel[1], - stride[0], stride[1], pad[0], pad[1], in_grad, true); + stride[0], stride[1], pad[0], pad[1], in_grad, + true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_2d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -878,7 +917,7 @@ inline void unpool(mshadow::Stream* s, const 
DType* out_grad, const DType* ishape[1], ishape[2], ishape[3], ishape[4], oshape[2], oshape[3], oshape[4], kernel[0], kernel[1], kernel[2], stride[0], stride[1], stride[2], pad[0], pad[1], - pad[2], in_grad, true); + pad[2], in_grad, true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_3d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h index 9fe43b2bd468..e82fe9e0c7c0 100644 --- a/src/operator/nn/pool.h +++ b/src/operator/nn/pool.h @@ -216,7 +216,8 @@ inline void pool_max_3d_cpu(const DType* in_data, const TShape& ishape, const TS template inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - DType* out_data, const bool getAvg = false) { + DType* out_data, + const bool getAvg = false, const bool count_include_pad = true) { const int width = ishape[2]; const int pooled_width = oshape[2]; const int kernel_w = kernel[0]; @@ -232,6 +233,9 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS int pool_size = (getAvg ? 
(wend - wstart) : 1); wstart = std::max(wstart, 0); wend = std::min(wend, width); + if (getAvg && !count_include_pad) { + pool_size = (wend - wstart); + } DType sum = 0; for (int w = wstart; w < wend; ++w) { sum += a_pow_p::Map(in_data[w]) / pool_size; @@ -251,7 +255,8 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS template inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - DType* out_data, const bool getAvg = false) { + DType* out_data, + const bool getAvg = false, const bool count_include_pad = true) { const int height = ishape[2], width = ishape[3]; const int pooled_height = oshape[2], pooled_width = oshape[3]; const int kernel_h = kernel[0], kernel_w = kernel[1]; @@ -272,6 +277,9 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS wstart = std::max(wstart, 0); hend = std::min(hend, height); wend = std::min(wend, width); + if (getAvg && !count_include_pad) { + pool_size = (hend - hstart) * (wend - wstart); + } DType sum = 0; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -294,7 +302,8 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS template inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - DType* out_data, const bool getAvg = false) { + DType* out_data, + const bool getAvg = false, const bool count_include_pad = true) { const int depth = ishape[2], height = ishape[3], width = ishape[4]; const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4]; const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; @@ -320,6 +329,9 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS dend = std::min(dend, depth); hend = std::min(hend, height); wend = 
std::min(wend, width); + if (getAvg && !count_include_pad) { + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } DType sum = 0; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { @@ -509,8 +521,8 @@ inline void unpool_max_3d_cpu(const DType* out_grad, const DType* in_data, template inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, - DType* in_grad, const bool isAvg = false) { + const TShape& pad, const TShape& stride, DType* in_grad, + const bool isAvg = false, const bool count_include_pad = true) { const int width = ishape[2]; const int pooled_width = oshape[2]; const int kernel_w = kernel[0]; @@ -526,6 +538,9 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const int pool_size = (isAvg ? (wend - wstart) : 1); wstart = std::max(wstart, 0); wend = std::min(wend, width); + if (isAvg && !count_include_pad) { + pool_size = (wend - wstart); + } for (int w = wstart; w < wend; ++w) { in_grad[w] += lp_grad::Map(out_grad[pw], in_data[w], out_data[pw]) / pool_size; } @@ -545,8 +560,8 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const template inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, - DType* in_grad, const bool isAvg = false) { + const TShape& pad, const TShape& stride, DType* in_grad, + const bool isAvg = false, const bool count_include_pad = true) { const int height = ishape[2], width = ishape[3]; const int pooled_height = oshape[2], pooled_width = oshape[3]; const int kernel_h = kernel[0], kernel_w = kernel[1]; @@ -567,6 +582,9 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const wstart = std::max(wstart, 0); hend = 
std::min(hend, height); wend = std::min(wend, width); + if (isAvg && !count_include_pad) { + pool_size = (hend - hstart) * (wend - wstart); + } const int pool_index = ph * pooled_width + pw; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -593,8 +611,8 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const template inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, - DType* in_grad, const bool isAvg = false) { + const TShape& pad, const TShape& stride, DType* in_grad, + const bool isAvg = false, const bool count_include_pad = true) { const int depth = ishape[2], height = ishape[3], width = ishape[4]; const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4]; const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; @@ -620,6 +638,9 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const dend = std::min(dend, depth); hend = std::min(hend, height); wend = std::min(wend, width); + if (isAvg && !count_include_pad) { + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } const int pool_index = (pd * pooled_height + ph) * pooled_width + pw; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { @@ -660,13 +681,14 @@ template inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, const int pool_type, OpReqType req_type, - DType* out_data) { + DType* out_data, const bool count_include_pad) { CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations"; if (kernel.ndim() == 1) { if (pool_enum::kMaxPooling == pool_type) { pool_max_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if 
(pool_enum::kAvgPooling == pool_type) { - pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true); + pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kLpPooling == pool_type) { @@ -678,7 +700,8 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is if (pool_enum::kMaxPooling == pool_type) { pool_max_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kAvgPooling == pool_type) { - pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true); + pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kLpPooling == pool_type) { @@ -690,7 +713,8 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is if (pool_enum::kMaxPooling == pool_type) { pool_max_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kAvgPooling == pool_type) { - pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true); + pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kLpPooling == pool_type) { @@ -723,7 +747,8 @@ template inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - const int pool_type, OpReqType req_type, DType* in_grad, const int p_value = 2) { + const int pool_type, OpReqType req_type, 
DType* in_grad, + const bool count_include_pad) { if (mxnet::kNullOp == req_type) return; if (mxnet::kAddTo != req_type) { mxnet_op::Kernel::Launch(s, ishape.Size(), in_grad); @@ -733,7 +758,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* unpool_max_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kAvgPooling == pool_type) { unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true); + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kLpPooling == pool_type) { @@ -747,7 +772,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* unpool_max_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kAvgPooling == pool_type) { unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true); + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kLpPooling == pool_type) { @@ -761,7 +786,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* unpool_max_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kAvgPooling == pool_type) { unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true); + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kLpPooling == pool_type) { diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h index a4770b49e857..395643a343e1 100644 --- 
a/src/operator/nn/pooling-inl.h +++ b/src/operator/nn/pooling-inl.h @@ -50,6 +50,7 @@ struct PoolingParam : public dmlc::Parameter { bool global_pool; bool cudnn_off; dmlc::optional p_value; + dmlc::optional count_include_pad; DMLC_DECLARE_PARAMETER(PoolingParam) { DMLC_DECLARE_FIELD(kernel).set_default(TShape()) // add default value here .enforce_nonzero() @@ -81,7 +82,13 @@ struct PoolingParam : public dmlc::Parameter { .describe("Pad for pooling: (y, x) or (d, y, x). Defaults to no padding."); DMLC_DECLARE_FIELD(p_value).set_default(dmlc::optional()) - .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling"); + .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling."); + + DMLC_DECLARE_FIELD(count_include_pad).set_default(dmlc::optional()) + .describe("Only used for AvgPool, specify whether to count padding elements for average" + "calculation. For example, with a 5*5 kernel on a 3*3 corner of a image," + "the sum of the 9 valid elements will be divided by 25 if this is set to true," + "or it will be divided by 9 if this is set to false"); } bool operator==(const PoolingParam& other) const { @@ -92,7 +99,8 @@ struct PoolingParam : public dmlc::Parameter { this->pooling_convention == other.pooling_convention && this->global_pool == other.global_pool && this->cudnn_off == other.cudnn_off && - this->p_value == other.p_value; + this->p_value == other.p_value && + this->count_include_pad == other.count_include_pad; } }; @@ -112,6 +120,7 @@ struct hash { ret = dmlc::HashCombine(ret, val.global_pool); ret = dmlc::HashCombine(ret, val.cudnn_off); ret = dmlc::HashCombine(ret, val.p_value); + ret = dmlc::HashCombine(ret, val.count_include_pad); return ret; } }; @@ -153,27 +162,29 @@ class PoolingOp { } const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ? param_.p_value.value() : 1; + const bool count_include_pad = (param_.count_include_pad.has_value()) ? 
+ param_.count_include_pad.value() : true; switch (p_value) { case 1: pool(s, in_data.dptr(), in_data.shape_, out_data.shape_, kernel, padding, stride, - param_.pool_type, req, out_data.dptr()); + param_.pool_type, req, out_data.dptr(), count_include_pad); break; case 2: pool(s, in_data.dptr(), in_data.shape_, out_data.shape_, kernel, padding, stride, - param_.pool_type, req, out_data.dptr()); + param_.pool_type, req, out_data.dptr(), count_include_pad); break; case 3: pool(s, in_data.dptr(), in_data.shape_, out_data.shape_, kernel, padding, stride, - param_.pool_type, req, out_data.dptr()); + param_.pool_type, req, out_data.dptr(), count_include_pad); break; default: LOG(FATAL) << "p value of " << p_value << " is not supported yet..."; @@ -201,6 +212,8 @@ class PoolingOp { const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ? param_.p_value.value() : 1; + const bool count_include_pad = (param_.count_include_pad.has_value()) ? + param_.count_include_pad.value() : true; switch (p_value) { case 1: unpool(s, out_grad.dptr(), in_data.dptr(), out_data.dptr(), @@ -208,7 +221,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr()); + param_.pool_type, req, in_grad.dptr(), count_include_pad); break; case 2: unpool(s, out_grad.dptr(), in_data.dptr(), out_data.dptr(), @@ -216,7 +229,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr()); + param_.pool_type, req, in_grad.dptr(), count_include_pad); break; case 3: unpool(s, out_grad.dptr(), in_data.dptr(), out_data.dptr(), @@ -224,7 +237,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr()); + param_.pool_type, req, in_grad.dptr(), count_include_pad); break; default: LOG(FATAL) << "p value of " << p_value << " is not supported yet...";