Skip to content

Commit

Permalink
fix div 0 error in conv1/2/3 (#49999)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liyulingyue authored Feb 7, 2023
1 parent 36e5de8 commit 7a0fdeb
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 0 deletions.
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,13 @@ void ConvInferMeta(const MetaTensor& input,
const bool channel_last = (config.is_run_mkldnn_kernel == false) &&
(data_format == "NHWC" || data_format == "NDHWC");

for (int i = 0; i < 2; ++i) {
PADDLE_ENFORCE_NE(in_dims[i],
0,
phi::errors::InvalidArgument(
"The size of Op(Conv) inputs should not be 0."));
}

PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5,
true,
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/test_functional_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,17 @@ def setUp(self):
self.data_format = "NCL"


class TestFunctionalConv1DErrorCase2(TestFunctionalConv1DError):
def setUp(self):
self.input = np.random.randn(0, 0, 0)
self.filter = np.random.randn(1, 0, 0)
self.bias = None
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.data_format = "NCL"


if __name__ == "__main__":
unittest.main()
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/test_functional_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,20 @@ def setUp(self):
self.data_format = "NCHW"


class TestFunctionalConv2DErrorCase14(TestFunctionalConv2DErrorCase12):
def setUp(self):
self.input = np.random.randn(0, 0, 0, 0)
self.filter = np.random.randn(1, 0, 0, 0)
self.num_filters = 0
self.filter_size = 0
self.bias = None
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.data_format = "NCHW"


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/test_functional_conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,20 @@ def setUp(self):
self.data_format = "NCDHW"


class TestFunctionalConv3DErrorCase13(TestFunctionalConv3DErrorCase11):
def setUp(self):
self.input = np.random.randn(0, 0, 0, 0, 0)
self.filter = np.random.randn(1, 0, 0, 0, 0)
self.num_filters = 1
self.filter_size = 1
self.bias = None
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.data_format = "NCDHW"


if __name__ == "__main__":
paddle.enable_static()
unittest.main()

0 comments on commit 7a0fdeb

Please sign in to comment.