Skip to content

Commit

Permalink
Fix Python IndexError of case7: paddle.static.nn.spectral_norm (Paddl…
Browse files Browse the repository at this point in the history
…ePaddle#49988)

* add dim check for spectral_norm

* add unittest out of range for spectral_norm

* use ValueError when dim out of range for spectral_norm

* update dim limit and add unittest for spectral_norm
  • Loading branch information
RedContritio authored and pangengzheng committed Feb 2, 2023
1 parent b3d98c6 commit d4139e6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
18 changes: 18 additions & 0 deletions python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,24 @@ def test_weight_dtype():
# the data type of type must be float32 or float64
self.assertRaises(TypeError, test_weight_dtype)

def test_dim_out_of_range_1():
weight_3 = np.random.random((2, 4)).astype("float32")
tensor_3 = paddle.to_tensor(weight_3)
paddle.static.nn.spectral_norm(
tensor_3, dim=1382376303, power_iters=2
)

# the dim must be 0 or 1
self.assertRaises(ValueError, test_dim_out_of_range_1)

def test_dim_out_of_range_2():
weight_4 = np.random.random((2, 4)).astype("float32")
tensor_4 = paddle.to_tensor(weight_4)
paddle.static.nn.spectral_norm(tensor_4, dim=-1, power_iters=2)

# the dim must be 0 or 1
self.assertRaises(ValueError, test_dim_out_of_range_2)


class TestDygraphSpectralNormOpError(unittest.TestCase):
def test_errors(self):
Expand Down
11 changes: 6 additions & 5 deletions python/paddle/static/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3418,11 +3418,12 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
# create intput and parameters
input_shape = weight.shape
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
assert dim < len(input_shape), (
"The input `dim` should be less than the "
"rank of `weight`, but received dim="
"{}".format(dim)
)

if dim not in [0, 1]:
raise ValueError(
f"The input `dim` must be 0 (if weight in fc) or 1 (if weight in conv), but received dim={dim}"
)

h = input_shape[dim]
w = np.prod(input_shape) // h

Expand Down

0 comments on commit d4139e6

Please sign in to comment.