Skip to content

Commit

Permalink
fix test_multiply (PaddlePaddle#66492)
Browse files Browse the repository at this point in the history
  • Loading branch information
BHmingyang authored and inaomIIsfarell committed Jul 31, 2024
1 parent c94528e commit 91521c9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
1 change: 0 additions & 1 deletion test/legacy_test/test_linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ def test_start_dtype():
)
paddle.linspace(start, 10, 1, dtype="float32")

# test_start_dtype()
self.assertRaises(ValueError, test_start_dtype)

def test_end_dtype():
Expand Down
15 changes: 8 additions & 7 deletions test/legacy_test/test_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import numpy as np

import paddle
from paddle import tensor
from paddle.static import Program, program_guard
from paddle import static, tensor
from paddle.base.framework import in_pir_mode


class TestMultiplyApi(unittest.TestCase):
def _run_static_graph_case(self, x_data, y_data):
with program_guard(Program(), Program()):
with static.program_guard(static.Program(), static.Program()):
paddle.enable_static()
x = paddle.static.data(
name='x', shape=x_data.shape, dtype=x_data.dtype
Expand Down Expand Up @@ -110,13 +110,14 @@ class TestMultiplyError(unittest.TestCase):
def test_errors(self):
# test static computation graph: dtype can not be int8
paddle.enable_static()
with program_guard(Program(), Program()):
with static.program_guard(static.Program(), static.Program()):
x = paddle.static.data(name='x', shape=[100], dtype=np.int8)
y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
self.assertRaises(TypeError, tensor.multiply, x, y)
if not in_pir_mode():
self.assertRaises(TypeError, tensor.multiply, x, y)

# test static computation graph: inputs must be broadcastable
with program_guard(Program(), Program()):
with static.program_guard(static.Program(), static.Program()):
x = paddle.static.data(name='x', shape=[20, 50], dtype=np.float64)
y = paddle.static.data(name='y', shape=[20], dtype=np.float64)
self.assertRaises(ValueError, tensor.multiply, x, y)
Expand Down Expand Up @@ -183,7 +184,7 @@ def test_errors(self):

class TestMultiplyInplaceApi(TestMultiplyApi):
def _run_static_graph_case(self, x_data, y_data):
with program_guard(Program(), Program()):
with static.program_guard(static.Program(), static.Program()):
paddle.enable_static()
x = paddle.static.data(
name='x', shape=x_data.shape, dtype=x_data.dtype
Expand Down

0 comments on commit 91521c9

Please sign in to comment.