diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index 58a20c37661cf..0d75bb92a3813 100755
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -770,58 +770,83 @@ def where(
     if x is None or y is None:
         raise ValueError("either both or neither of x and y should be given")
 
+    # NOTE: We might need to adapt broadcast_shape and broadcast_to for dynamic shapes
+    # so that the dynamic and PIR branches can be merged into one code block
     condition_shape = list(condition.shape)
     x_shape = list(x.shape)
     y_shape = list(y.shape)
 
-    if x_shape == y_shape and condition_shape == x_shape:
-        broadcast_condition = condition
+    if in_dynamic_mode():
+        broadcast_shape = paddle.broadcast_shape(x_shape, y_shape)
+        broadcast_shape = paddle.broadcast_shape(
+            broadcast_shape, condition_shape
+        )
+
+        broadcast_x = x
         broadcast_y = y
-    else:
-        zeros_like_x = paddle.zeros_like(x)
-        zeros_like_y = paddle.zeros_like(y)
-        zeros_like_condition = paddle.zeros_like(condition)
-        zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
-        cast_cond = paddle.cast(condition, x.dtype)
-
-        broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
-        broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
-        broadcast_x = paddle.add(x, broadcast_zeros)
-        broadcast_y = paddle.add(y, broadcast_zeros)
-        broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
-        broadcast_condition = paddle.cast(broadcast_condition, 'bool')
+        broadcast_condition = condition
+
+        if condition_shape != broadcast_shape:
+            broadcast_condition = paddle.broadcast_to(
+                broadcast_condition, broadcast_shape
+            )
+        if x_shape != broadcast_shape:
+            broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape)
+        if y_shape != broadcast_shape:
+            broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape)
 
-    if in_dynamic_or_pir_mode():
         return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
+
     else:
-        check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
-        check_variable_and_dtype(
-            x,
-            'x',
-            ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
-            'where',
-        )
-        check_variable_and_dtype(
-            y,
-            'y',
-            ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
-            'where',
-        )
-        helper = LayerHelper("where", **locals())
-        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        # for PIR and old IR
+        if x_shape == y_shape and condition_shape == x_shape:
+            broadcast_condition = condition
+            broadcast_x = x
+            broadcast_y = y
+        else:
+            zeros_like_x = paddle.zeros_like(x)
+            zeros_like_y = paddle.zeros_like(y)
+            zeros_like_condition = paddle.zeros_like(condition)
+            zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
+            cast_cond = paddle.cast(condition, x.dtype)
+
+            broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
+            broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
+            broadcast_x = paddle.add(x, broadcast_zeros)
+            broadcast_y = paddle.add(y, broadcast_zeros)
+            broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
+            broadcast_condition = paddle.cast(broadcast_condition, 'bool')
 
-        helper.append_op(
-            type='where',
-            inputs={
-                'Condition': broadcast_condition,
-                'X': broadcast_x,
-                'Y': broadcast_y,
-            },
-            outputs={'Out': [out]},
-        )
+        if in_pir_mode():
+            return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
+        else:
+            check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
+            check_variable_and_dtype(
+                x,
+                'x',
+                ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'where',
+            )
+            check_variable_and_dtype(
+                y,
+                'y',
+                ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'where',
+            )
+            helper = LayerHelper("where", **locals())
+            out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+            helper.append_op(
+                type='where',
+                inputs={
+                    'Condition': broadcast_condition,
+                    'X': broadcast_x,
+                    'Y': broadcast_y,
+                },
+                outputs={'Out': [out]},
+            )
-    return out
+            return out
 
 
 @inplace_apis_in_dygraph_only
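The dynamic-mode branch above boils down to: fold the three operand shapes into one target shape, then expand only the tensors whose shape differs from it. A minimal standalone sketch of that logic (the `_broadcast_all` helper name is illustrative, not part of the patch):

```python
import paddle


def _broadcast_all(cond, x, y):
    # Fold the three shapes pairwise into one common broadcast shape;
    # paddle.broadcast_shape operates on plain shape lists, not tensors.
    shape = paddle.broadcast_shape(list(x.shape), list(y.shape))
    shape = paddle.broadcast_shape(shape, list(cond.shape))

    def expand(t):
        # Only materialize a broadcast copy when the shape actually differs.
        return t if list(t.shape) == shape else paddle.broadcast_to(t, shape)

    return expand(cond), expand(x), expand(y)


cond = paddle.to_tensor([[True], [False]])  # shape [2, 1]
x = paddle.to_tensor([[1.0, 2.0]])          # shape [1, 2]
y = paddle.zeros([2, 2])                    # shape [2, 2]
c, bx, by = _broadcast_all(cond, x, y)
print(c.shape, bx.shape, by.shape)          # [2, 2] [2, 2] [2, 2]
```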
@@ -844,23 +869,22 @@ def where_(
     condition_shape = list(condition.shape)
     x_shape = list(x.shape)
     y_shape = list(y.shape)
-    if x_shape == y_shape and condition_shape == x_shape:
-        broadcast_condition = condition
-        broadcast_x = x
-        broadcast_y = y
-    else:
-        zeros_like_x = paddle.zeros_like(x)
-        zeros_like_y = paddle.zeros_like(y)
-        zeros_like_condition = paddle.zeros_like(condition)
-        zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
-        cast_cond = paddle.cast(condition, x.dtype)
-
-        broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
-        broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
-        broadcast_x = x.add_(broadcast_zeros)
-        broadcast_y = paddle.add(y, broadcast_zeros)
-        broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
-        broadcast_condition = paddle.cast(broadcast_condition, 'bool')
+
+    broadcast_shape = paddle.broadcast_shape(x_shape, y_shape)
+    broadcast_shape = paddle.broadcast_shape(broadcast_shape, condition_shape)
+
+    broadcast_x = x
+    broadcast_y = y
+    broadcast_condition = condition
+
+    if condition_shape != broadcast_shape:
+        broadcast_condition = paddle.broadcast_to(
+            broadcast_condition, broadcast_shape
+        )
+    if x_shape != broadcast_shape:
+        broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape)
+    if y_shape != broadcast_shape:
+        broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape)
 
     if in_dynamic_mode():
         return _C_ops.where_(broadcast_condition, broadcast_x, broadcast_y)
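Dropping the hidden `x.add_(broadcast_zeros)` also halves the in-place writes per `where_` call: the old path bumped `x`'s inplace version twice (once for `add_`, once for `_C_ops.where_`), the new path only once, which is what the updated test assertions below encode. A rough dygraph check of the expected counts, assuming same-shape inputs so no broadcast copies are needed:

```python
import paddle

cond = paddle.to_tensor([True, False, True])
x = paddle.rand([3])
y = paddle.rand([3])

assert x.inplace_version == 0
paddle.where_(cond, x, y)      # single in-place write to x
assert x.inplace_version == 1  # previously 2: add_ plus where_
```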
diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py
index bb704f600857a..8095110678f9a 100755
--- a/test/legacy_test/test_inplace.py
+++ b/test/legacy_test/test_inplace.py
@@ -271,13 +271,13 @@ def test_forward_version(self):
             self.assertEqual(var.inplace_version, 0)
 
             inplace_var = self.inplace_api_processing(var)
-            self.assertEqual(var.inplace_version, 2)
+            self.assertEqual(var.inplace_version, 1)
 
             inplace_var[0] = 2
-            self.assertEqual(var.inplace_version, 3)
+            self.assertEqual(var.inplace_version, 2)
 
             inplace_var = self.inplace_api_processing(inplace_var)
-            self.assertEqual(var.inplace_version, 5)
+            self.assertEqual(var.inplace_version, 3)
 
     def test_backward_error(self):
         # It raises an error because the inplace operator will result
@@ -295,7 +295,7 @@ def test_backward_error(self):
             loss = paddle.nn.functional.relu(var_c)
             with self.assertRaisesRegex(
                 RuntimeError,
-                f"received tensor_version:{2} != wrapper_version_snapshot:{0}",
+                f"received tensor_version:{1} != wrapper_version_snapshot:{0}",
             ):
                 loss.backward()
@@ -1298,13 +1298,13 @@ def test_forward_version(self):
             self.assertEqual(var.inplace_version, 0)
 
             inplace_var = self.inplace_api_processing(var)
-            self.assertEqual(var.inplace_version, 2)
+            self.assertEqual(var.inplace_version, 1)
 
             inplace_var[0] = 2
-            self.assertEqual(var.inplace_version, 3)
+            self.assertEqual(var.inplace_version, 2)
 
             inplace_var = self.inplace_api_processing(inplace_var)
-            self.assertEqual(var.inplace_version, 5)
+            self.assertEqual(var.inplace_version, 3)
 
     def test_backward_error(self):
         # It raises an error because the inplace operator will result
@@ -1322,7 +1322,7 @@ def test_backward_error(self):
             loss = paddle.nn.functional.relu(var_c)
             with self.assertRaisesRegex(
                 RuntimeError,
-                "received tensor_version:2 != wrapper_version_snapshot:0",
+                "received tensor_version:1 != wrapper_version_snapshot:0",
             ):
                 loss.backward()
diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py
index 8b5d967e32ba1..626d98aabf4f1 100644
--- a/test/legacy_test/test_where_op.py
+++ b/test/legacy_test/test_where_op.py
@@ -276,9 +276,15 @@ def test_api_broadcast(self, use_cuda=False):
         with paddle.static.program_guard(main_program):
             x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32')
             y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32')
-            x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32')
-            y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype(
-                'float32'
+            x_i = (
+                np.array([[0.9383, 0.1983, 3.2, 1.2]])
+                .astype('float32')
+                .reshape([1, 4, 1])
+            )
+            y_i = (
+                np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]])
+                .astype('float32')
+                .reshape([1, 4, 2])
             )
             result = paddle.where((x > 1), x=x, y=y)
             for use_cuda in [False, True]:
@@ -805,6 +811,88 @@ def test_where_condition(self):
         np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
 
 
+class TestWhereDygraphAPIBroadcast(unittest.TestCase):
+    def test_broadcast_scalar(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(4, 5, 6).astype('float64')
+            y_i = -1.0
+            cond_i = np.random.randn(1, 1, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+    def test_broadcast_to_x(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(4, 5, 6).astype('float64')
+            y_i = np.random.randn(1, 5, 6).astype('float64')
+            cond_i = np.random.randn(1, 1, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+    def test_broadcast_to_y(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(1, 5, 6).astype('float64')
+            y_i = np.random.randn(4, 5, 6).astype('float64')
+            cond_i = np.random.randn(1, 1, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+    def test_broadcast_to_cond(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(1, 1, 6).astype('float64')
+            y_i = np.random.randn(1, 5, 1).astype('float64')
+            cond_i = np.random.randn(4, 5, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+    def test_can_not_broadcast(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(1, 1, 6).astype('float64')
+            y_i = np.random.randn(1, 5, 3).astype('float64')
+            cond_i = np.random.randn(4, 5, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+
+            with self.assertRaises(ValueError):
+                _ = paddle.where(cond, x, y)
+
+
+class TestWhereDygraphAPIDtypePromotion(unittest.TestCase):
+    def test_dtype_auto_promotion_float(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(4, 5, 6).astype('float32')
+            y_i = np.random.randn(4, 5, 6).astype('float64')
+            cond_i = np.random.randn(4, 5, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            self.assertEqual(out.dtype, y.dtype)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+
 class TestWhereOpError(unittest.TestCase):
     def test_errors(self):
         with paddle.static.program_guard(
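The new `TestWhereDygraphAPIDtypePromotion` case above asserts that mixing `float32` and `float64` operands yields `float64`, which is exactly NumPy's promotion rule for the reference result it compares against:

```python
import numpy as np

cond = np.array([True, False])
x = np.zeros(2, dtype=np.float32)
y = np.zeros(2, dtype=np.float64)

# NumPy promotes to the wider floating type, so the Paddle result
# must be float64 for the element-wise comparison to hold.
print(np.where(cond, x, y).dtype)  # float64
```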
diff --git a/test/xpu/test_where_op_xpu.py b/test/xpu/test_where_op_xpu.py
index 10dd2fa13a956..71e6c8996fcfd 100644
--- a/test/xpu/test_where_op_xpu.py
+++ b/test/xpu/test_where_op_xpu.py
@@ -172,9 +172,15 @@ def test_api_broadcast(self, use_cuda=False):
         with base.program_guard(train_prog, startup):
             x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32')
             y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32')
-            x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
-            y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype(
-                "float32"
+            x_i = (
+                np.array([[0.9383, 0.1983, 3.2, 1.2]])
+                .astype("float32")
+                .reshape([1, 4, 1])
+            )
+            y_i = (
+                np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]])
+                .astype("float32")
+                .reshape([1, 4, 2])
             )
             result = paddle.where(x > 1, x=x, y=y)
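For reference, the reshape fix in both static-graph tests makes the feeds rank-3, matching the declared `[-1, 4, 1]` and `[-1, 4, 2]` placeholders, while the two shapes still broadcast against each other. A NumPy-only check of the shapes involved (`np.ones([2, 4])` stands in for the all-ones `y_i` array from the tests):

```python
import numpy as np

x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]], dtype='float32').reshape([1, 4, 1])
y_i = np.ones([2, 4], dtype='float32').reshape([1, 4, 2])

out = np.where(x_i > 1, x_i, y_i)  # (1, 4, 1) vs (1, 4, 2) -> (1, 4, 2)
print(out.shape)  # (1, 4, 2)
```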