Skip to content

Commit

Permalink
fix div 0 error of fftfreq (#49954)
Browse files Browse the repository at this point in the history
* fix div 0 error of fftfreq

* fix div 0 error of fftfreq

* bug fix

* add 'n' value check
  • Loading branch information
Liyulingyue authored Jan 26, 2023
1 parent f43cb3b commit 7a0b0da
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,8 @@ def fftfreq(n, d=1.0, dtype=None, name=None):
# Tensor(shape=[5], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [ 0. , 0.40000001, 0.80000001, -0.80000001, -0.40000001])
"""
if d * n == 0:
raise ValueError("d or n should not be 0.")

dtype = paddle.framework.get_default_dtype()
val = 1.0 / (n * d)
Expand Down
17 changes: 17 additions & 0 deletions python/paddle/fluid/tests/unittests/fft/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,23 @@ def test_fftfreq(self):
)


@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'n', 'd', 'dtype', 'expect_exception'),
[
('test_with_0_0', 0, 0, 'float32', ValueError),
('test_with_n_0', 20, 0, 'float32', ValueError),
('test_with_0_d', 0, 20, 'float32', ValueError),
],
)
class TestFftFreqException(unittest.TestCase):
def test_fftfreq2(self):
"""Test fftfreq with d = 0"""
with paddle.fluid.dygraph.guard(self.place):
with self.assertRaises(self.expect_exception):
paddle.fft.fftfreq(self.n, self.d, self.dtype)


@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'n', 'd', 'dtype'),
Expand Down

0 comments on commit 7a0b0da

Please sign in to comment.