-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon No.4】为 Paddle 新增 masked_scatter API -part #59383
Changes from 3 commits
c484f02
309e049
2c060b1
f1dcb9c
881af9b
ef21a94
43b65f3
6403dc4
53a5b32
6f4485b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4030,6 +4030,85 @@ def get_attr_shape(list_shape): | |
return out | ||
|
||
|
||
def masked_scatter(x, mask, value, name=None): | ||
""" | ||
Copies elements from `value` into `x` tensor at positions where the `mask` is True. | ||
|
||
Elements from source are copied into `x` starting at position 0 of `value` and continuing in order one-by-one for | ||
each occurrence of `mask` being True. The shape of `mask` must be broadcastable with the shape of the underlying tensor. | ||
The `value` should have at least as many elements as the number of ones in `mask`. | ||
|
||
Args: | ||
x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``. | ||
shape (list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1. | ||
The data type is ``int32`` . If ``shape`` is a list or tuple, each element of it should be integer or Tensor with shape []. | ||
If ``shape`` is an Tensor, it should be an 1-D Tensor . | ||
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor, A reshaped Tensor with the same data type as ``x``. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> x = paddle.randn([3,4]) | ||
>>> mask = paddle.randn([3,4]) | ||
>>> mask = mask>0.6 | ||
>>> value = paddle.ones([2,4], dtype="float32") | ||
|
||
>>> out = paddle.masked_scatter(x, mask, value) | ||
>>> print(out.shape) | ||
|
||
""" | ||
# make sure the dtype of x and source is the same | ||
assert ( | ||
x.dtype == value.dtype | ||
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' | ||
assert mask.dtype == paddle.bool | ||
|
||
zeros_like_x = paddle.zeros_like(x, dtype=int) | ||
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x) | ||
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0) | ||
assert ( | ||
mask_prefix[-1] <= value.numel() | ||
), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}' | ||
|
||
value = value.flatten()[mask_prefix].reshape(mask.shape) | ||
return paddle.where(paddle.cast(mask, dtype="bool"), value, x) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 实现逻辑上, 理论上这组API的实现内部只应该有少量inplace / outplace API的差异,可以简单分析下动态图下当前两种方式的实现有性能差异吗。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,主要是 broadcast_to 的这个操作,在静态图下的 shape 如果是 None 或者 -1 之类的时候没办法正确的 broadcast,后面是模仿 paddle.where 中的 broadcast 操作。我感觉在动态图下的 inplace 操作直接使用 broadcast_to 可以避免多余的 Op 操作,性能上应该会快一些。后面我也可以测试一下看看有没有明显的差距 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 修改为两个 API 的逻辑一致,性能只有较为轻微的差异 |
||
|
||
@inplace_apis_in_dygraph_only | ||
def masked_scatter_(x, mask, value, name=None): | ||
""" | ||
Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``. | ||
Please refer to :ref:`api_paddle_masked_scatter`. | ||
""" | ||
# make sure the dtype of x and source is the same | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不必要的注释可移除掉 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
assert ( | ||
x.dtype == value.dtype | ||
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' | ||
assert mask.dtype == paddle.bool | ||
shape = paddle.broadcast_shape(x.shape, mask.shape) | ||
if shape != x.shape: | ||
raise ValueError( | ||
"The shape of broadcast output should be equal to inplace tensor shape in the Inplace operation, but got output shape {} and inplace tensor {}.".format( | ||
shape, x.shape | ||
) | ||
) | ||
mask = paddle.broadcast_to(mask, shape) | ||
mask_prefix = paddle.clip(mask.astype(int).cumsum() - 1, min=0) | ||
assert ( | ||
mask_prefix[-1] <= value.numel() | ||
), f'mask true nums must be <= value size, but got mask true nums is {mask_prefix[-1].item()}, value size is {value.numel().item()}' | ||
|
||
value = value.flatten()[mask_prefix].reshape(mask.shape) | ||
mask = paddle.logical_not(mask) | ||
out = paddle.where_(mask, x, value) | ||
return out | ||
|
||
|
||
@inplace_apis_in_dygraph_only | ||
def reshape_(x, shape, name=None): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's best to give the values of
x
,mask
andout
in example code to make it easier for users to understandThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done