[Hackathon No.4] Add masked_scatter API to Paddle -part #59383
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4034,6 +4034,91 @@ def get_attr_shape(list_shape): | |
return out | ||
|
||
|
||
def masked_scatter(x, mask, value, name=None): | ||
""" | ||
Copies elements from `value` into `x` tensor at positions where the `mask` is True. | ||
|
||
Elements from `value` are copied into `x` starting at position 0 of `value` and continuing in order one-by-one for | |
each occurrence of `mask` being True. The shape of `mask` must be broadcastable with the shape of `x`. | |
The `value` should have at least as many elements as the number of True elements in `mask`. | |
|
||
Args: | ||
x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int32``, | ||
``int64`` or ``bfloat16``. | ||
mask (Tensor): The boolean tensor indicating the positions to be filled. | |
The data type of mask must be bool. | |
value (Tensor): The value used to fill the target tensor. | |
Supported data types are the same as ``x``. | |
name (str, optional): Name for the operation (optional, default is None). For more information, | ||
please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor, a Tensor with the same data type as ``x``. | |
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> x = paddle.randn([2, 2]) | |
>>> print(x) | |
Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, | ||
[[-0.22560297, -1.48127186], | ||
[-1.12356818, 0.76581174]]) | ||
>>> mask = paddle.randn([2, 2]) | |
>>> mask = mask > 0.6 | |
>>> print(mask) | |
Tensor(shape=[2, 2], dtype=bool, place=Place(gpu:0), stop_gradient=True, | ||
[[False, True ], | ||
[True, False]]) | ||
>>> value = paddle.to_tensor([1, 2, 3, 4, 5], dtype="float32") | |
|
||
>>> out = paddle.masked_scatter(x, mask, value) | ||
>>> print(out) | ||
Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, | ||
[[-0.22560297, 1.], | ||
[2, 0.76581174]]) | ||
|
||
""" | ||
# make sure the dtype of x and value is the same | ||
assert ( | ||
x.dtype == value.dtype | ||
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' | ||
assert mask.dtype == paddle.bool | ||
|
||
zeros_like_x = paddle.zeros_like(x, dtype=int) | ||
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x) | ||
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0) | ||
assert ( | ||
mask.sum() <= value.numel() | |
), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}' | ||
|
||
value = value.flatten()[mask_prefix].reshape(mask.shape) | ||
mask = paddle.logical_not(mask) | ||
return paddle.where(mask, x, value) | ||
|
||
In terms of implementation logic, the two APIs in this pair should in theory differ internally only in a few in-place versus out-of-place API calls. Could you briefly analyze whether the two current implementations differ in performance in dynamic graph mode?
Yes. The main issue is the broadcast_to operation: in static graph mode, when the shape is None or -1, it cannot broadcast correctly, so the later version imitates the broadcast done in paddle.where. I think that in dynamic graph mode the in-place operation could use broadcast_to directly, which avoids extra ops and should be somewhat faster. I can also benchmark it later to see whether there is a noticeable gap.
Changed so that the two APIs share the same logic; the performance difference is only slight.
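The prefix-sum gather that both functions rely on can be modelled without Paddle. The following is a minimal sketch in plain Python on flat lists, not the PR's implementation: `masked_scatter_reference` is a hypothetical name, the mask is assumed to be already broadcast to the length of `x`, and the count check compares the number of True entries directly against the size of `value`.

```python
from itertools import accumulate


def masked_scatter_reference(x, mask, value):
    """Pure-Python model of the prefix-sum approach on flat lists.

    x, mask and value are flat lists; mask is assumed to be already
    broadcast to the length of x (hypothetical helper, for illustration).
    """
    if len(x) != len(mask):
        raise ValueError("mask must have the same number of elements as x")
    true_count = sum(mask)
    if true_count > len(value):
        raise ValueError(
            f"mask has {true_count} True elements but value has only {len(value)}"
        )
    # Running count of True entries, shifted to a 0-based index and clipped
    # at 0, mirroring clip(mask.cumsum() - 1, min=0) in the diff.
    prefix = [max(c - 1, 0) for c in accumulate(int(m) for m in mask)]
    # Take value[prefix[i]] where mask is True, otherwise keep x[i].
    return [value[p] if m else xi for xi, m, p in zip(x, mask, prefix)]


x = [-0.5, -1.5, -1.0, 0.75]
mask = [False, True, True, False]
value = [1.0, 2.0, 3.0, 4.0, 5.0]
result = masked_scatter_reference(x, mask, value)
print(result)  # [-0.5, 1.0, 2.0, 0.75]
```

The k-th True position receives `value[k]`, which is why the running count minus one can serve as the gather index.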
||
|
||
@inplace_apis_in_dygraph_only | ||
def masked_scatter_(x, mask, value, name=None): | ||
""" | ||
Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``. | ||
Please refer to :ref:`api_paddle_masked_scatter`. | ||
""" | ||
assert ( | ||
x.dtype == value.dtype | ||
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' | ||
assert mask.dtype == paddle.bool | ||
zeros_like_x = paddle.zeros_like(x, dtype=int) | ||
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x) | ||
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0) | ||
assert ( | ||
mask.sum() <= value.numel() | |
), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}' | |
|
||
value = value.flatten()[mask_prefix].reshape(mask.shape) | ||
mask = paddle.logical_not(mask) | ||
out = paddle.where_(mask, x, value) | ||
return out | ||
|
||
|
||
@inplace_apis_in_dygraph_only | ||
def reshape_(x, shape, name=None): | ||
""" | ||
|
You probably need to add a `seed` to fix the output, otherwise the example check will not pass. See "Fixed output is better than random".
done
The example code check did not pass; the log is below. It may be because `x = paddle.randn([2, 2])` is not followed by `print(x)`, so no output was captured. Please watch for this throughout the code examples. You could try running it locally.
With the seed fixed, the results on CI and locally still differ, so I revised it once more.
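The point about the missing `print(x)` can be reproduced with a small sketch. It uses the standard-library `doctest` module as a stand-in, which is an assumption: Paddle's CI runs its own sample-code checker, and the names `UNPRINTED`, `PRINTED` and `count_failures` are hypothetical. An assignment produces no output, so any expected output listed after it counts as a mismatch.

```python
import doctest

# Expected output follows an assignment, which prints nothing.
UNPRINTED = """
>>> x = [1, 2]
[1, 2]
"""

# The value is printed explicitly, so the expected output matches.
PRINTED = """
>>> x = [1, 2]
>>> print(x)
[1, 2]
"""


def count_failures(example_text, name):
    # Parse the examples from the text and run them, discarding the report
    # so that only the number of failed examples is returned.
    parser = doctest.DocTestParser()
    test = parser.get_doctest(example_text, {}, name, None, 0)
    runner = doctest.DocTestRunner(verbose=False)
    results = runner.run(test, out=lambda text: None)
    return results.failed


print(count_failures(UNPRINTED, "unprinted"))  # 1
print(count_failures(PRINTED, "printed"))      # 0
```

Literal inputs like the list above also sidestep the seed issue raised in the thread, since they do not depend on a random generator that may behave differently on CI and on a local machine.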