[API] Optimize paddle.where and paddle.where_ in eager mode #69556

Merged · 6 commits · Nov 28, 2024
python/paddle/tensor/search.py (82 additions, 58 deletions)
@@ -770,58 +770,83 @@ def where(
     if x is None or y is None:
         raise ValueError("either both or neither of x and y should be given")

+    # NOTE: We might need to adapt the broadcast_shape and broadcast_to for dynamic shape
+    # so dynamic and pir branch can be merged into one code block
     condition_shape = list(condition.shape)
     x_shape = list(x.shape)
     y_shape = list(y.shape)

-    if x_shape == y_shape and condition_shape == x_shape:
-        broadcast_condition = condition
+    if in_dynamic_mode():
+        broadcast_shape = paddle.broadcast_shape(x_shape, y_shape)
+        broadcast_shape = paddle.broadcast_shape(
+            broadcast_shape, condition_shape
+        )
+
         broadcast_x = x
         broadcast_y = y
-    else:
-        zeros_like_x = paddle.zeros_like(x)
-        zeros_like_y = paddle.zeros_like(y)
-        zeros_like_condition = paddle.zeros_like(condition)
-        zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
-        cast_cond = paddle.cast(condition, x.dtype)
-
-        broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
-        broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
-        broadcast_x = paddle.add(x, broadcast_zeros)
-        broadcast_y = paddle.add(y, broadcast_zeros)
-        broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
-        broadcast_condition = paddle.cast(broadcast_condition, 'bool')
+        broadcast_condition = condition

+        if condition_shape != broadcast_shape:
+            broadcast_condition = paddle.broadcast_to(
+                broadcast_condition, broadcast_shape
+            )
+        if x_shape != broadcast_shape:
+            broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape)
+        if y_shape != broadcast_shape:
+            broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape)

-    if in_dynamic_or_pir_mode():
         return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)

     else:
-        check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
-        check_variable_and_dtype(
-            x,
-            'x',
-            ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
-            'where',
-        )
-        check_variable_and_dtype(
-            y,
-            'y',
-            ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
-            'where',
-        )
-        helper = LayerHelper("where", **locals())
-        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        # for PIR and old IR
+        if x_shape == y_shape and condition_shape == x_shape:
+            broadcast_condition = condition
+            broadcast_x = x
+            broadcast_y = y
+        else:
+            zeros_like_x = paddle.zeros_like(x)
+            zeros_like_y = paddle.zeros_like(y)
+            zeros_like_condition = paddle.zeros_like(condition)
+            zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
+            cast_cond = paddle.cast(condition, x.dtype)

-        helper.append_op(
-            type='where',
-            inputs={
-                'Condition': broadcast_condition,
-                'X': broadcast_x,
-                'Y': broadcast_y,
-            },
-            outputs={'Out': [out]},
-        )
+            broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
+            broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
+            broadcast_x = paddle.add(x, broadcast_zeros)
+            broadcast_y = paddle.add(y, broadcast_zeros)
+            broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
+            broadcast_condition = paddle.cast(broadcast_condition, 'bool')

-        return out
+        if in_pir_mode():
+            return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
+        else:
+            check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
+            check_variable_and_dtype(
+                x,
+                'x',
+                ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'where',
+            )
+            check_variable_and_dtype(
+                y,
+                'y',
+                ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'where',
+            )
+            helper = LayerHelper("where", **locals())
+            out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+            helper.append_op(
+                type='where',
+                inputs={
+                    'Condition': broadcast_condition,
+                    'X': broadcast_x,
+                    'Y': broadcast_y,
+                },
+                outputs={'Out': [out]},
+            )
+
+            return out
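For context, a minimal sketch of what the optimized eager branch now does: resolve the common shape once with paddle.broadcast_shape, then call broadcast_to only on inputs whose shape actually differs. This is a reviewer's illustration, not code from the PR; the tensor values are made up.

import paddle

cond = paddle.to_tensor([[True], [False]])      # shape [2, 1]
x = paddle.to_tensor([[1.0, 2.0]])              # shape [1, 2]
y = paddle.to_tensor([[9.0, 8.0], [7.0, 6.0]])  # shape [2, 2]

# Resolve the common shape once, folding in all three operands.
shape = paddle.broadcast_shape(list(x.shape), list(y.shape))
shape = paddle.broadcast_shape(shape, list(cond.shape))  # -> [2, 2]

# Materialize a broadcast copy only when a shape actually differs.
bx = paddle.broadcast_to(x, shape) if list(x.shape) != shape else x
by = paddle.broadcast_to(y, shape) if list(y.shape) != shape else y
bc = paddle.broadcast_to(cond, shape) if list(cond.shape) != shape else cond

print(paddle.where(bc, bx, by))  # [[1., 2.], [7., 6.]]

Compared with the removed zeros_like/add trick, this avoids allocating three zero tensors and five elementwise adds per call just to force a common shape.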


@inplace_apis_in_dygraph_only
@@ -844,23 +869,22 @@ def where_(
     condition_shape = list(condition.shape)
     x_shape = list(x.shape)
     y_shape = list(y.shape)
-    if x_shape == y_shape and condition_shape == x_shape:
-        broadcast_condition = condition
-        broadcast_x = x
-        broadcast_y = y
-    else:
-        zeros_like_x = paddle.zeros_like(x)
-        zeros_like_y = paddle.zeros_like(y)
-        zeros_like_condition = paddle.zeros_like(condition)
-        zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
-        cast_cond = paddle.cast(condition, x.dtype)
-
-        broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
-        broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
-        broadcast_x = x.add_(broadcast_zeros)
-        broadcast_y = paddle.add(y, broadcast_zeros)
-        broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
-        broadcast_condition = paddle.cast(broadcast_condition, 'bool')
+
+    broadcast_shape = paddle.broadcast_shape(x_shape, y_shape)
+    broadcast_shape = paddle.broadcast_shape(broadcast_shape, condition_shape)
+
+    broadcast_x = x
+    broadcast_y = y
+    broadcast_condition = condition
+
+    if condition_shape != broadcast_shape:
+        broadcast_condition = paddle.broadcast_to(
+            broadcast_condition, broadcast_shape
+        )
+    if x_shape != broadcast_shape:
+        broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape)
+    if y_shape != broadcast_shape:
+        broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape)

     if in_dynamic_mode():
         return _C_ops.where_(broadcast_condition, broadcast_x, broadcast_y)
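A short usage sketch for the in-place variant (illustrative values, not from the PR). Since paddle.where_ writes its result into x, x should already carry the full broadcast shape for the call to stay genuinely in-place; broadcasting x itself would produce a fresh tensor.

import paddle

x = paddle.ones([2, 2])
y = paddle.zeros([1, 2])  # broadcast along dim 0
cond = paddle.to_tensor([[True, False], [False, True]])

out = paddle.where_(cond, x, y)  # returns x, modified in place
print(out)                       # [[1., 0.], [0., 1.]]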
test/legacy_test/test_inplace.py (8 additions, 8 deletions)
@@ -271,13 +271,13 @@ def test_forward_version(self):
         self.assertEqual(var.inplace_version, 0)

         inplace_var = self.inplace_api_processing(var)
-        self.assertEqual(var.inplace_version, 2)
+        self.assertEqual(var.inplace_version, 1)

         inplace_var[0] = 2
-        self.assertEqual(var.inplace_version, 3)
+        self.assertEqual(var.inplace_version, 2)

         inplace_var = self.inplace_api_processing(inplace_var)
-        self.assertEqual(var.inplace_version, 5)
+        self.assertEqual(var.inplace_version, 3)

     def test_backward_error(self):
         # It raises an error because the inplace operator will result
@@ -295,7 +295,7 @@ def test_backward_error(self):
             loss = paddle.nn.functional.relu(var_c)
             with self.assertRaisesRegex(
                 RuntimeError,
-                f"received tensor_version:{2} != wrapper_version_snapshot:{0}",
+                f"received tensor_version:{1} != wrapper_version_snapshot:{0}",
             ):
                 loss.backward()
@@ -1298,13 +1298,13 @@ def test_forward_version(self):
         self.assertEqual(var.inplace_version, 0)

         inplace_var = self.inplace_api_processing(var)
-        self.assertEqual(var.inplace_version, 2)
+        self.assertEqual(var.inplace_version, 1)

         inplace_var[0] = 2
-        self.assertEqual(var.inplace_version, 3)
+        self.assertEqual(var.inplace_version, 2)

         inplace_var = self.inplace_api_processing(inplace_var)
-        self.assertEqual(var.inplace_version, 5)
+        self.assertEqual(var.inplace_version, 3)

     def test_backward_error(self):
         # It raises an error because the inplace operator will result
@@ -1322,7 +1322,7 @@ def test_backward_error(self):
             loss = paddle.nn.functional.relu(var_c)
             with self.assertRaisesRegex(
                 RuntimeError,
-                "received tensor_version:2 != wrapper_version_snapshot:0",
+                "received tensor_version:1 != wrapper_version_snapshot:0",
             ):
                 loss.backward()
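The halved counters line up with the rewrite above: the old eager path ran an extra in-place x.add_(broadcast_zeros) before the where_ kernel, so each call bumped inplace_version twice, while the new path issues a single in-place write. A minimal check, assuming dygraph mode:

import paddle

x = paddle.ones([2, 2])
y = paddle.zeros([2, 2])
cond = paddle.to_tensor([[True, False], [False, True]])

assert x.inplace_version == 0
paddle.where_(cond, x, y)
assert x.inplace_version == 1  # was 2 before this PR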
test/legacy_test/test_where_op.py (91 additions, 3 deletions)
@@ -276,9 +276,15 @@ def test_api_broadcast(self, use_cuda=False):
         with paddle.static.program_guard(main_program):
             x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32')
             y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32')
-            x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32')
-            y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype(
-                'float32'
+            x_i = (
+                np.array([[0.9383, 0.1983, 3.2, 1.2]])
+                .astype('float32')
+                .reshape([1, 4, 1])
+            )
+            y_i = (
+                np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]])
+                .astype('float32')
+                .reshape([1, 4, 2])
             )
             result = paddle.where((x > 1), x=x, y=y)
             for use_cuda in [False, True]:
@@ -805,6 +811,88 @@ def test_where_condition(self):
         np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)


+class TestWhereDygraphAPIBroadcast(unittest.TestCase):
+    def test_broadcast_scalar(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(4, 5, 6).astype('float64')
+            y_i = -1.0
+            cond_i = np.random.randn(1, 1, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+    def test_broadcast_to_x(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(4, 5, 6).astype('float64')
+            y_i = np.random.randn(1, 5, 6).astype('float64')
+            cond_i = np.random.randn(1, 1, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+    def test_broadcast_to_y(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(1, 5, 6).astype('float64')
+            y_i = np.random.randn(4, 5, 6).astype('float64')
+            cond_i = np.random.randn(1, 1, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+    def test_broadcast_to_cond(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(1, 1, 6).astype('float64')
+            y_i = np.random.randn(1, 5, 1).astype('float64')
+            cond_i = np.random.randn(4, 5, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+    def test_can_not_broadcast(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(1, 1, 6).astype('float64')
+            y_i = np.random.randn(1, 5, 3).astype('float64')
+            cond_i = np.random.randn(4, 5, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+
+            with self.assertRaises(ValueError):
+                _ = paddle.where(cond, x, y)
+
+
+class TestWhereDygraphAPIDtypePromotion(unittest.TestCase):
+    def test_dtype_auto_promotion_float(self):
+        with base.dygraph.guard():
+            x_i = np.random.randn(4, 5, 6).astype('float32')
+            y_i = np.random.randn(4, 5, 6).astype('float64')
+            cond_i = np.random.randn(4, 5, 6).astype('bool')
+            x = paddle.to_tensor(x_i)
+            y = paddle.to_tensor(y_i)
+            cond = paddle.to_tensor(cond_i)
+            out = paddle.where(cond, x, y)
+            self.assertEqual(out.dtype, y.dtype)
+            np.testing.assert_array_equal(
+                out.numpy(), np.where(cond_i, x_i, y_i)
+            )
+
+
 class TestWhereOpError(unittest.TestCase):
     def test_errors(self):
         with paddle.static.program_guard(
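The new broadcast tests track NumPy semantics, so a shape combination NumPy rejects should raise ValueError from paddle.where as well. A quick cross-check sketch (shapes taken from test_can_not_broadcast; values illustrative):

import numpy as np
import paddle

# (1, 5, 3) and (4, 5, 6) disagree on the trailing axis (3 vs 6).
try:
    np.broadcast_shapes((1, 1, 6), (1, 5, 3), (4, 5, 6))
except ValueError as exc:
    print("numpy rejects it:", exc)

x = paddle.randn([1, 1, 6], dtype='float64')
y = paddle.randn([1, 5, 3], dtype='float64')
cond = paddle.randn([4, 5, 6]) > 0  # bool condition tensor
try:
    paddle.where(cond, x, y)
except ValueError as exc:
    print("paddle rejects it:", exc)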
test/xpu/test_where_op_xpu.py (9 additions, 3 deletions)
@@ -172,9 +172,15 @@ def test_api_broadcast(self, use_cuda=False):
         with base.program_guard(train_prog, startup):
             x = paddle.static.data(name='x', shape=[-1, 4, 1], dtype='float32')
             y = paddle.static.data(name='y', shape=[-1, 4, 2], dtype='float32')
-            x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
-            y_i = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype(
-                "float32"
+            x_i = (
+                np.array([[0.9383, 0.1983, 3.2, 1.2]])
+                .astype("float32")
+                .reshape([1, 4, 1])
+            )
+            y_i = (
+                np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]])
+                .astype("float32")
+                .reshape([1, 4, 2])
             )
             result = paddle.where(x > 1, x=x, y=y)
Expand Down