Skip to content

Commit

Permalink
fix code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Nov 21, 2024
1 parent 092d803 commit b157596
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
10 changes: 2 additions & 8 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,14 +833,8 @@ def where_(
Inplace version of ``where`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_where`.
"""
if np.isscalar(x):
x = paddle.full([1], x, np.array([x]).dtype.name)

if np.isscalar(y):
y = paddle.full([1], y, np.array([y]).dtype.name)

if x is None and y is None:
return nonzero(condition, as_tuple=True)
if np.isscalar(x) or np.isscalar(y):
raise ValueError("either both or neither of x and y should be given")

if x is None or y is None:
raise ValueError("either both or neither of x and y should be given")
Expand Down
16 changes: 8 additions & 8 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,13 @@ def test_forward_version(self):
self.assertEqual(var.inplace_version, 0)

inplace_var = self.inplace_api_processing(var)
self.assertEqual(var.inplace_version, 2)
self.assertEqual(var.inplace_version, 1)

inplace_var[0] = 2
self.assertEqual(var.inplace_version, 3)
self.assertEqual(var.inplace_version, 2)

inplace_var = self.inplace_api_processing(inplace_var)
self.assertEqual(var.inplace_version, 5)
self.assertEqual(var.inplace_version, 3)

def test_backward_error(self):
# It raises an error because the inplace operator will result
Expand All @@ -295,7 +295,7 @@ def test_backward_error(self):
loss = paddle.nn.functional.relu(var_c)
with self.assertRaisesRegex(
RuntimeError,
f"received tensor_version:{2} != wrapper_version_snapshot:{0}",
f"received tensor_version:{1} != wrapper_version_snapshot:{0}",
):
loss.backward()

Expand Down Expand Up @@ -1298,13 +1298,13 @@ def test_forward_version(self):
self.assertEqual(var.inplace_version, 0)

inplace_var = self.inplace_api_processing(var)
self.assertEqual(var.inplace_version, 2)
self.assertEqual(var.inplace_version, 1)

inplace_var[0] = 2
self.assertEqual(var.inplace_version, 3)
self.assertEqual(var.inplace_version, 2)

inplace_var = self.inplace_api_processing(inplace_var)
self.assertEqual(var.inplace_version, 5)
self.assertEqual(var.inplace_version, 3)

def test_backward_error(self):
# It raises an error because the inplace operator will result
Expand All @@ -1322,7 +1322,7 @@ def test_backward_error(self):
loss = paddle.nn.functional.relu(var_c)
with self.assertRaisesRegex(
RuntimeError,
"received tensor_version:2 != wrapper_version_snapshot:0",
"received tensor_version:1 != wrapper_version_snapshot:0",
):
loss.backward()

Expand Down

0 comments on commit b157596

Please sign in to comment.