From 092d80378bed0c832bf83795ecbf51cf6275107f Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 20 Nov 2024 20:06:55 +0800 Subject: [PATCH 1/4] optimize where --- python/paddle/tensor/search.py | 75 ++++++++++++------------ test/legacy_test/test_where_op.py | 94 ++++++++++++++++++++++++++++++- test/xpu/test_where_op_xpu.py | 12 +++- 3 files changed, 139 insertions(+), 42 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 58a20c37661cf9..549b16fb675732 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -774,23 +774,21 @@ def where( x_shape = list(x.shape) y_shape = list(y.shape) - if x_shape == y_shape and condition_shape == x_shape: - broadcast_condition = condition - broadcast_x = x - broadcast_y = y - else: - zeros_like_x = paddle.zeros_like(x) - zeros_like_y = paddle.zeros_like(y) - zeros_like_condition = paddle.zeros_like(condition) - zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype) - cast_cond = paddle.cast(condition, x.dtype) - - broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y) - broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition) - broadcast_x = paddle.add(x, broadcast_zeros) - broadcast_y = paddle.add(y, broadcast_zeros) - broadcast_condition = paddle.add(cast_cond, broadcast_zeros) - broadcast_condition = paddle.cast(broadcast_condition, 'bool') + broadcast_shape = paddle.broadcast_shape(x_shape, y_shape) + broadcast_shape = paddle.broadcast_shape(broadcast_shape, condition_shape) + + broadcast_x = x + broadcast_y = y + broadcast_condition = condition + + if condition_shape != broadcast_shape: + broadcast_condition = paddle.broadcast_to( + broadcast_condition, broadcast_shape + ) + if x_shape != broadcast_shape: + broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape) + if y_shape != broadcast_shape: + broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape) if in_dynamic_or_pir_mode(): return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y) @@ -835,8 +833,14 @@ def where_( Inplace version of ``where`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_where`. """ - if np.isscalar(x) or np.isscalar(y): - raise ValueError("either both or neither of x and y should be given") + if np.isscalar(x): + x = paddle.full([1], x, np.array([x]).dtype.name) + + if np.isscalar(y): + y = paddle.full([1], y, np.array([y]).dtype.name) + + if x is None and y is None: + return nonzero(condition, as_tuple=True) if x is None or y is None: raise ValueError("either both or neither of x and y should be given") @@ -844,23 +848,22 @@ def where_( condition_shape = list(condition.shape) x_shape = list(x.shape) y_shape = list(y.shape) - if x_shape == y_shape and condition_shape == x_shape: - broadcast_condition = condition - broadcast_x = x - broadcast_y = y - else: - zeros_like_x = paddle.zeros_like(x) - zeros_like_y = paddle.zeros_like(y) - zeros_like_condition = paddle.zeros_like(condition) - zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype) - cast_cond = paddle.cast(condition, x.dtype) - - broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y) - broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition) - broadcast_x = x.add_(broadcast_zeros) - broadcast_y = paddle.add(y, broadcast_zeros) - broadcast_condition = paddle.add(cast_cond, broadcast_zeros) - broadcast_condition = paddle.cast(broadcast_condition, 'bool') + + broadcast_shape = paddle.broadcast_shape(x_shape, y_shape) + broadcast_shape = paddle.broadcast_shape(broadcast_shape, condition_shape) + + broadcast_x = x + broadcast_y = y + broadcast_condition = condition + + if condition_shape != broadcast_shape: + broadcast_condition = paddle.broadcast_to( + broadcast_condition, broadcast_shape + ) + if x_shape != broadcast_shape: + broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape) + if y_shape != broadcast_shape: + broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape) if in_dynamic_mode(): return _C_ops.where_(broadcast_condition, broadcast_x, broadcast_y) diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py index 8b5d967e32ba13..626d98aabf4f1c 100644 --- a/test/legacy_test/test_where_op.py +++ b/test/legacy_test/test_where_op.py @@ -276,9 +276,15 @@ def test_api_broadcast(self, use_cuda=False): with paddle.static.program_guard(main_program): x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32') y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32') - x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32') - y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype( - 'float32' + x_i = ( + np.array([[0.9383, 0.1983, 3.2, 1.2]]) + .astype('float32') + .reshape([1, 4, 1]) + ) + y_i = ( + np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]) + .astype('float32') + .reshape([1, 4, 2]) ) result = paddle.where((x > 1), x=x, y=y) for use_cuda in [False, True]: @@ -805,6 +811,88 @@ def test_where_condition(self): np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) +class TestWhereDygraphAPIBroadcast(unittest.TestCase): + def test_broadcast_scalar(self): + with base.dygraph.guard(): + x_i = np.random.randn(4, 5, 6).astype('float64') + y_i = -1.0 + cond_i = np.random.randn(1, 1, 6).astype('bool') + x = paddle.to_tensor(x_i) + y = paddle.to_tensor(y_i) + cond = paddle.to_tensor(cond_i) + out = paddle.where(cond, x, y) + np.testing.assert_array_equal( + out.numpy(), np.where(cond_i, x_i, y_i) + ) + + def test_broadcast_to_x(self): + with base.dygraph.guard(): + x_i = np.random.randn(4, 5, 6).astype('float64') + y_i = np.random.randn(1, 5, 6).astype('float64') + cond_i = np.random.randn(1, 1, 6).astype('bool') + x = paddle.to_tensor(x_i) + y = paddle.to_tensor(y_i) + cond = paddle.to_tensor(cond_i) + out = paddle.where(cond, x, y) + np.testing.assert_array_equal( + out.numpy(), np.where(cond_i, x_i, y_i) + ) + + def test_broadcast_to_y(self): + with base.dygraph.guard(): + x_i = np.random.randn(1, 5, 6).astype('float64') + y_i = np.random.randn(4, 5, 6).astype('float64') + cond_i = np.random.randn(1, 1, 6).astype('bool') + x = paddle.to_tensor(x_i) + y = paddle.to_tensor(y_i) + cond = paddle.to_tensor(cond_i) + out = paddle.where(cond, x, y) + np.testing.assert_array_equal( + out.numpy(), np.where(cond_i, x_i, y_i) + ) + + def test_broadcast_to_cond(self): + with base.dygraph.guard(): + x_i = np.random.randn(1, 1, 6).astype('float64') + y_i = np.random.randn(1, 5, 1).astype('float64') + cond_i = np.random.randn(4, 5, 6).astype('bool') + x = paddle.to_tensor(x_i) + y = paddle.to_tensor(y_i) + cond = paddle.to_tensor(cond_i) + out = paddle.where(cond, x, y) + np.testing.assert_array_equal( + out.numpy(), np.where(cond_i, x_i, y_i) + ) + + def test_can_not_broadcast(self): + with base.dygraph.guard(): + x_i = np.random.randn(1, 1, 6).astype('float64') + y_i = np.random.randn(1, 5, 3).astype('float64') + cond_i = np.random.randn(4, 5, 6).astype('bool') + x = paddle.to_tensor(x_i) + y = paddle.to_tensor(y_i) + cond = paddle.to_tensor(cond_i) + + with self.assertRaises(ValueError): + _ = paddle.where(cond, x, y) + + +class TestWhereDygraphAPIDtypePromotion(unittest.TestCase): + def test_dtype_auto_promotion_float(self): + with base.dygraph.guard(): + x_i = np.random.randn(4, 5, 6).astype('float32') + y_i = np.random.randn(4, 5, 6).astype('float64') + cond_i = np.random.randn(4, 5, 6).astype('bool') + x = paddle.to_tensor(x_i) + y = paddle.to_tensor(y_i) + cond = paddle.to_tensor(cond_i) + out = paddle.where(cond, x, y) + self.assertEqual(out.dtype, y.dtype) + np.testing.assert_array_equal( + out.numpy(), np.where(cond_i, x_i, y_i) + ) + + class TestWhereOpError(unittest.TestCase): def test_errors(self): with paddle.static.program_guard( diff --git a/test/xpu/test_where_op_xpu.py b/test/xpu/test_where_op_xpu.py index 10dd2fa13a9566..71e6c8996fcfd5 100644 --- a/test/xpu/test_where_op_xpu.py +++ b/test/xpu/test_where_op_xpu.py @@ -172,9 +172,15 @@ def test_api_broadcast(self, use_cuda=False): with base.program_guard(train_prog, startup): x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32') y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32') - x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32") - y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype( - "float32" + x_i = ( + np.array([[0.9383, 0.1983, 3.2, 1.2]]) + .astype("float32") + .reshape([1, 4, 1]) + ) + y_i = ( + np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]) + .astype("float32") + .reshape([1, 4, 2]) ) result = paddle.where(x > 1, x=x, y=y) From b157596b69b981d4b697e11403b1c1aa3d5458a6 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 21 Nov 2024 20:23:56 +0800 Subject: [PATCH 2/4] fix code --- python/paddle/tensor/search.py | 10 ++-------- test/legacy_test/test_inplace.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 549b16fb675732..29177d9f8aeb62 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -833,14 +833,8 @@ def where_( Inplace version of ``where`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_where`. """ - if np.isscalar(x): - x = paddle.full([1], x, np.array([x]).dtype.name) - - if np.isscalar(y): - y = paddle.full([1], y, np.array([y]).dtype.name) - - if x is None and y is None: - return nonzero(condition, as_tuple=True) + if np.isscalar(x) or np.isscalar(y): + raise ValueError("either both or neither of x and y should be given") if x is None or y is None: raise ValueError("either both or neither of x and y should be given") diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index bb704f600857ae..8095110678f9a9 100755 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -271,13 +271,13 @@ def test_forward_version(self): self.assertEqual(var.inplace_version, 0) inplace_var = self.inplace_api_processing(var) - self.assertEqual(var.inplace_version, 2) + self.assertEqual(var.inplace_version, 1) inplace_var[0] = 2 - self.assertEqual(var.inplace_version, 3) + self.assertEqual(var.inplace_version, 2) inplace_var = self.inplace_api_processing(inplace_var) - self.assertEqual(var.inplace_version, 5) + self.assertEqual(var.inplace_version, 3) def test_backward_error(self): # It raises an error because the inplace operator will result @@ -295,7 +295,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - f"received tensor_version:{2} != wrapper_version_snapshot:{0}", + f"received tensor_version:{1} != wrapper_version_snapshot:{0}", ): loss.backward() @@ -1298,13 +1298,13 @@ def test_forward_version(self): self.assertEqual(var.inplace_version, 0) inplace_var = self.inplace_api_processing(var) - self.assertEqual(var.inplace_version, 2) + self.assertEqual(var.inplace_version, 1) inplace_var[0] = 2 - self.assertEqual(var.inplace_version, 3) + self.assertEqual(var.inplace_version, 2) inplace_var = self.inplace_api_processing(inplace_var) - self.assertEqual(var.inplace_version, 5) + self.assertEqual(var.inplace_version, 3) def test_backward_error(self): # It raises an error because the inplace operator will result @@ -1322,7 +1322,7 @@ def test_backward_error(self): loss = paddle.nn.functional.relu(var_c) with self.assertRaisesRegex( RuntimeError, - "received tensor_version:2 != wrapper_version_snapshot:0", + "received tensor_version:1 != wrapper_version_snapshot:0", ): loss.backward() From b4396d95cad2715ec94468db280275538a029af4 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 22 Nov 2024 14:55:36 +0800 Subject: [PATCH 3/4] split code into dynamic and pir mode --- python/paddle/tensor/search.py | 63 +++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 29177d9f8aeb62..a97aec3a8d9aca 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -770,27 +770,56 @@ def where( if x is None or y is None: raise ValueError("either both or neither of x and y should be given") - condition_shape = list(condition.shape) - x_shape = list(x.shape) - y_shape = list(y.shape) + # NOTE: We might need to adapt the broadcast_shape and broadcast_to for dynamic shape + # so dynamic and pir branch can be merged into one code block + if in_dynamic_mode(): + condition_shape = list(condition.shape) + x_shape = list(x.shape) + y_shape = list(y.shape) - broadcast_shape = paddle.broadcast_shape(x_shape, y_shape) - broadcast_shape = paddle.broadcast_shape(broadcast_shape, condition_shape) + broadcast_shape = paddle.broadcast_shape(x_shape, y_shape) + broadcast_shape = paddle.broadcast_shape( + broadcast_shape, condition_shape + ) - broadcast_x = x - broadcast_y = y - broadcast_condition = condition + broadcast_x = x + broadcast_y = y + broadcast_condition = condition - if condition_shape != broadcast_shape: - broadcast_condition = paddle.broadcast_to( - broadcast_condition, broadcast_shape - ) - if x_shape != broadcast_shape: - broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape) - if y_shape != broadcast_shape: - broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape) + if condition_shape != broadcast_shape: + broadcast_condition = paddle.broadcast_to( + broadcast_condition, broadcast_shape + ) + if x_shape != broadcast_shape: + broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape) + if y_shape != broadcast_shape: + broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape) + + return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y) + + elif in_pir_mode(): + condition_shape = list(condition.shape) + x_shape = list(x.shape) + y_shape = list(y.shape) + + if x_shape == y_shape and condition_shape == x_shape: + broadcast_condition = condition + broadcast_x = x + broadcast_y = y + else: + zeros_like_x = paddle.zeros_like(x) + zeros_like_y = paddle.zeros_like(y) + zeros_like_condition = paddle.zeros_like(condition) + zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype) + cast_cond = paddle.cast(condition, x.dtype) + + broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y) + broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition) + broadcast_x = paddle.add(x, broadcast_zeros) + broadcast_y = paddle.add(y, broadcast_zeros) + broadcast_condition = paddle.add(cast_cond, broadcast_zeros) + broadcast_condition = paddle.cast(broadcast_condition, 'bool') - if in_dynamic_or_pir_mode(): return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y) else: check_variable_and_dtype(condition, 'condition', ['bool'], 'where') From 23f1da3423fcebd2e151756129e3f035ab25369c Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 27 Nov 2024 15:35:23 +0800 Subject: [PATCH 4/4] fix where for pir/old ir bug --- python/paddle/tensor/search.py | 72 +++++++++++++++++----------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index a97aec3a8d9aca..0d75bb92a38130 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -772,11 +772,11 @@ def where( # NOTE: We might need to adapt the broadcast_shape and broadcast_to for dynamic shape # so dynamic and pir branch can be merged into one code block - if in_dynamic_mode(): - condition_shape = list(condition.shape) - x_shape = list(x.shape) - y_shape = list(y.shape) + condition_shape = list(condition.shape) + x_shape = list(x.shape) + y_shape = list(y.shape) + if in_dynamic_mode(): broadcast_shape = paddle.broadcast_shape(x_shape, y_shape) broadcast_shape = paddle.broadcast_shape( broadcast_shape, condition_shape @@ -797,11 +797,8 @@ def where( return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y) - elif in_pir_mode(): - condition_shape = list(condition.shape) - x_shape = list(x.shape) - y_shape = list(y.shape) - + else: + # for PIR and old IR if x_shape == y_shape and condition_shape == x_shape: broadcast_condition = condition broadcast_x = x @@ -820,35 +817,36 @@ def where( broadcast_condition = paddle.add(cast_cond, broadcast_zeros) broadcast_condition = paddle.cast(broadcast_condition, 'bool') - return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y) - else: - check_variable_and_dtype(condition, 'condition', ['bool'], 'where') - check_variable_and_dtype( - x, - 'x', - ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'], - 'where', - ) - check_variable_and_dtype( - y, - 'y', - ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'], - 'where', - ) - helper = LayerHelper("where", **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op( - type='where', - inputs={ - 'Condition': broadcast_condition, - 'X': broadcast_x, - 'Y': broadcast_y, - }, - outputs={'Out': [out]}, - ) + if in_pir_mode(): + return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y) + else: + check_variable_and_dtype(condition, 'condition', ['bool'], 'where') + check_variable_and_dtype( + x, + 'x', + ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'where', + ) + check_variable_and_dtype( + y, + 'y', + ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'where', + ) + helper = LayerHelper("where", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='where', + inputs={ + 'Condition': broadcast_condition, + 'X': broadcast_x, + 'Y': broadcast_y, + }, + outputs={'Out': [out]}, + ) - return out + return out @inplace_apis_in_dygraph_only