diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 549b16fb675732..29177d9f8aeb62 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -833,14 +833,8 @@ def where_( Inplace version of ``where`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_where`. """ - if np.isscalar(x): - x = paddle.full([1], x, np.array([x]).dtype.name) - - if np.isscalar(y): - y = paddle.full([1], y, np.array([y]).dtype.name) - - if x is None and y is None: - return nonzero(condition, as_tuple=True) + if np.isscalar(x) or np.isscalar(y): + raise ValueError("either both or neither of x and y should be given") if x is None or y is None: raise ValueError("either both or neither of x and y should be given") diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index bb704f600857ae..8095110678f9a9 100755 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -271,13 +271,13 @@ def test_forward_version(self): self.assertEqual(var.inplace_version, 0) inplace_var = self.inplace_api_processing(var) - self.assertEqual(var.inplace_version, 2) + self.assertEqual(var.inplace_version, 1) inplace_var[0] = 2 - self.assertEqual(var.inplace_version, 3) + self.assertEqual(var.inplace_version, 2) inplace_var = self.inplace_api_processing(inplace_var) - self.assertEqual(var.inplace_version, 5) + self.assertEqual(var.inplace_version, 3) def test_backward_error(self): # It raises an error because the inplace operator will result @@ -295,7 +295,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{2} != wrapper_version_snapshot:{0}", + f"received tensor_version:{1} != wrapper_version_snapshot:{0}", ): loss.backward() @@ -1298,13 +1298,13 @@ def test_forward_version(self): self.assertEqual(var.inplace_version, 0) inplace_var = self.inplace_api_processing(var) - self.assertEqual(var.inplace_version, 2) + self.assertEqual(var.inplace_version, 1) inplace_var[0] = 2 - self.assertEqual(var.inplace_version, 3) + self.assertEqual(var.inplace_version, 2) inplace_var = self.inplace_api_processing(inplace_var) - self.assertEqual(var.inplace_version, 5) + self.assertEqual(var.inplace_version, 3) def test_backward_error(self): # It raises an error because the inplace operator will result @@ -1322,7 +1322,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - "received tensor_version:2 != wrapper_version_snapshot:0", + "received tensor_version:1 != wrapper_version_snapshot:0", ): loss.backward()