Skip to content

Commit

Permalink
split code into dynamic and pir mode
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Nov 22, 2024
1 parent b157596 commit b4396d9
Showing 1 changed file with 46 additions and 17 deletions.
63 changes: 46 additions & 17 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,27 +770,56 @@ def where(
if x is None or y is None:
raise ValueError("either both or neither of x and y should be given")

condition_shape = list(condition.shape)
x_shape = list(x.shape)
y_shape = list(y.shape)
# NOTE: We might need to adapt the broadcast_shape and broadcast_to for dynamic shape
# so dynamic and pir branch can be merged into one code block
if in_dynamic_mode():
condition_shape = list(condition.shape)
x_shape = list(x.shape)
y_shape = list(y.shape)

broadcast_shape = paddle.broadcast_shape(x_shape, y_shape)
broadcast_shape = paddle.broadcast_shape(broadcast_shape, condition_shape)
broadcast_shape = paddle.broadcast_shape(x_shape, y_shape)
broadcast_shape = paddle.broadcast_shape(
broadcast_shape, condition_shape
)

broadcast_x = x
broadcast_y = y
broadcast_condition = condition
broadcast_x = x
broadcast_y = y
broadcast_condition = condition

if condition_shape != broadcast_shape:
broadcast_condition = paddle.broadcast_to(
broadcast_condition, broadcast_shape
)
if x_shape != broadcast_shape:
broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape)
if y_shape != broadcast_shape:
broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape)
if condition_shape != broadcast_shape:
broadcast_condition = paddle.broadcast_to(
broadcast_condition, broadcast_shape
)
if x_shape != broadcast_shape:
broadcast_x = paddle.broadcast_to(broadcast_x, broadcast_shape)
if y_shape != broadcast_shape:
broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape)

return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)

elif in_pir_mode():
condition_shape = list(condition.shape)
x_shape = list(x.shape)
y_shape = list(y.shape)

if x_shape == y_shape and condition_shape == x_shape:
broadcast_condition = condition
broadcast_x = x
broadcast_y = y
else:
zeros_like_x = paddle.zeros_like(x)
zeros_like_y = paddle.zeros_like(y)
zeros_like_condition = paddle.zeros_like(condition)
zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
cast_cond = paddle.cast(condition, x.dtype)

broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
broadcast_x = paddle.add(x, broadcast_zeros)
broadcast_y = paddle.add(y, broadcast_zeros)
broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
broadcast_condition = paddle.cast(broadcast_condition, 'bool')

if in_dynamic_or_pir_mode():
return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
else:
check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
Expand Down

0 comments on commit b4396d9

Please sign in to comment.