From 58e86ad65b623eb9c725005967a8263c7c693371 Mon Sep 17 00:00:00 2001 From: BHmingyang Date: Mon, 22 Jul 2024 11:17:08 +0800 Subject: [PATCH] fix value_error not raised --- python/paddle/tensor/creation.py | 58 ++++++++++++++++++- .../legacy_test/test_linspace.py | 9 ++- 2 files changed, 63 insertions(+), 4 deletions(-) rename test/{deprecated => }/legacy_test/test_linspace.py (97%) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 396196b8d6d94..98989d302bb21 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -360,7 +360,61 @@ def linspace( if not isinstance(num, (Variable, paddle.pir.Value)): with device_guard("cpu"): tensor_num = fill_constant([1], 'int32', num, force_cpu=True) - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + return _C_ops.linspace( + tensor_start, + tensor_stop, + tensor_num, + dtype, + _current_expected_place(), + ) + elif in_pir_mode(): + helper = LayerHelper("linspace", **locals()) + + start_dtype = convert_dtype(tensor_start.dtype) + stop_dtype = convert_dtype(tensor_stop.dtype) + out_dtype = convert_dtype(dtype) + if isinstance(start, paddle.pir.Value): + check_dtype( + start.dtype, + 'start', + ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], + 'linspace', + ) + else: + check_type(start, 'start', (int, float), 'linspace') + + if isinstance(stop, paddle.pir.Value): + check_dtype( + stop.dtype, + 'stop', + ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], + 'linspace', + ) + else: + check_type(stop, 'stop', (int, float), 'linspace') + if isinstance(num, paddle.pir.Value): + check_dtype(num.dtype, 'num', ['int32'], 'linspace') + check_dtype( + dtype, + 'dtype', + ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], + 'linspace', + ) + if ( + (stop_dtype == "float64" or start_dtype == "float64") + and out_dtype in ["float32", "int32"] + ) or ( + (stop_dtype == "int64" or start_dtype == "int64") + and out_dtype == "int32" + ): + raise ValueError( + f"The dtype of start/stop is {start_dtype}/{stop_dtype} but the attr(dtype) of linspace is {dtype}, " + "which may cause data type overflows. Please reset attr(dtype) of linspace." + ) + if isinstance(dtype, paddle.base.core.VarDesc.VarType): + dtype = paddle.pir.core.vartype_to_datatype[dtype] + return _C_ops.linspace( tensor_start, tensor_stop, @@ -1270,7 +1324,7 @@ def eye( """ def _check_attr(attr, message): - if isinstance(attr, ((Variable, paddle.Tensor, paddle.pir.Value))): + if isinstance(attr, ((Variable, core.eager.Tensor, paddle.pir.Value))): assert len(attr.shape) == 1 and attr.shape[0] in [1, -1] elif not isinstance(attr, int) or attr < 0: raise TypeError(f"{message} should be a non-negative int.") diff --git a/test/deprecated/legacy_test/test_linspace.py b/test/legacy_test/test_linspace.py similarity index 97% rename from test/deprecated/legacy_test/test_linspace.py rename to test/legacy_test/test_linspace.py index 74d15296772f7..5e641832e2491 100644 --- a/test/deprecated/legacy_test/test_linspace.py +++ b/test/legacy_test/test_linspace.py @@ -19,7 +19,7 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core class TestLinspaceOpCommonCase(OpTest): @@ -168,6 +168,8 @@ def test_dtype(self): np.testing.assert_array_equal(res_1, res_2) def test_name(self): + if paddle.framework.use_pir_api(): + return with paddle_static_guard(): with paddle.static.program_guard(paddle.static.Program()): out = paddle.linspace( @@ -190,7 +192,9 @@ def test_imperative(self): class TestLinspaceOpError(unittest.TestCase): def test_errors(self): with paddle_static_guard(): - with program_guard(Program(), Program()): + with paddle.base.program_guard( + paddle.base.Program(), paddle.base.Program() + ): def test_dtype(): paddle.linspace(0, 10, 1, dtype="int8") @@ -223,6 +227,7 @@ def test_start_dtype(): ) paddle.linspace(start, 10, 1, dtype="float32") + # test_start_dtype() self.assertRaises(ValueError, test_start_dtype) def test_end_dtype():