Skip to content

Commit

Permalink
Fix div 0 error of case11: paddle.nn.functional.max_pool1d/max_pool2d…
Browse files Browse the repository at this point in the history
…/max_pool3d (#50010)

* add stride check for MaxPool

* add unittests
  • Loading branch information
RedContritio authored Feb 1, 2023
1 parent e4e94a8 commit 3ab6faa
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 1 deletion.
5 changes: 5 additions & 0 deletions paddle/fluid/operators/pool_with_index_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ inline int MaxPoolOutputSize(int input_size,
int filter_size,
int padding,
int stride) {
PADDLE_ENFORCE_NE(
stride,
0,
phi::errors::InvalidArgument(
"The stride of MaxPool shall not be 0, but received %d.", stride));
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/funcs/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,11 @@ inline int MaxPoolOutputSize(int input_size,
int filter_size,
int padding,
int stride) {
PADDLE_ENFORCE_NE(
stride,
0,
phi::errors::InvalidArgument(
"The stride of MaxPool shall not be 0, but received %d.", stride));
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
Expand Down
14 changes: 13 additions & 1 deletion python/paddle/fluid/tests/unittests/test_pool1d_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_pool1d(self):
self.check_max_dygraph_return_index_results(place)


class TestPool2DError_API(unittest.TestCase):
class TestPool1DError_API(unittest.TestCase):
def test_error_api(self):
def run1():
with fluid.dygraph.guard():
Expand Down Expand Up @@ -417,6 +417,18 @@ def run_stride_out_of_range():

self.assertRaises(ValueError, run_stride_out_of_range)

def run_zero_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1]), dtype='float32'
)
out = F.max_pool1d(
x, 1, stride=0, padding=1, return_mask=True, ceil_mode=True
)

self.assertRaises(ValueError, run_zero_stride)


if __name__ == '__main__':
unittest.main()
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/test_pool2d_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,18 @@ def run_stride_out_of_range():

self.assertRaises(ValueError, run_stride_out_of_range)

def run_zero_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1, 1]), dtype='float32'
)
out = max_pool2d(
x, 1, stride=0, padding=1, return_mask=True, ceil_mode=True
)

self.assertRaises(ValueError, run_zero_stride)


if __name__ == '__main__':
unittest.main()
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/test_pool3d_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,18 @@ def run_size_out_of_range():

self.assertRaises(ValueError, run_size_out_of_range)

def run_zero_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1, 1, 1]), dtype='float32'
)
out = max_pool3d(
x, 1, stride=0, padding=1, return_mask=True, ceil_mode=True
)

self.assertRaises(ValueError, run_zero_stride)


if __name__ == '__main__':
unittest.main()

0 comments on commit 3ab6faa

Please sign in to comment.