[Hackathon No.4] Add masked_scatter API to Paddle -part #59383

Merged · 10 commits · Dec 13, 2023
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
@@ -284,6 +284,8 @@
index_put_,
masked_fill,
masked_fill_,
masked_scatter_,
masked_scatter,
moveaxis,
put_along_axis,
select_scatter,
@@ -913,6 +915,8 @@
'polygamma_',
'masked_fill',
'masked_fill_',
'masked_scatter',
'masked_scatter_',
'hypot',
'hypot_',
'index_fill',
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -164,6 +164,8 @@
index_put_,
masked_fill,
masked_fill_,
masked_scatter,
masked_scatter_,
moveaxis,
put_along_axis,
put_along_axis_,
@@ -767,6 +769,8 @@
'atleast_2d',
'atleast_3d',
'diagonal_scatter',
'masked_scatter',
'masked_scatter_',
"combinations",
]

83 changes: 83 additions & 0 deletions python/paddle/tensor/manipulation.py
@@ -4033,6 +4033,89 @@ def get_attr_shape(list_shape):
return out


def masked_scatter(x, mask, value, name=None):
"""
Copies elements from `value` into `x` tensor at positions where the `mask` is True.

Elements of `value` are copied into `x` in order, starting from position 0 of `value`, one element for
each occurrence of True in `mask`. The shape of `mask` must be broadcastable with the shape of `x`,
and `value` must have at least as many elements as there are True entries in `mask`.

Args:
x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int32``,
``int64`` or ``bfloat16``.
mask (Tensor): The boolean tensor that indicates the positions to be filled.
The data type of mask must be bool.
value (Tensor): The values used to fill in the target tensor.
The supported data types are the same as those of x.
name (str, optional): Name for the operation (optional, default is None). For more information,
please refer to :ref:`api_guide_Name`.

Returns:
Tensor, a Tensor with the same shape and data type as ``x``, where the positions selected by ``mask`` have been filled with elements of ``value``.

Examples:
.. code-block:: python

>>> import paddle
>>> paddle.seed(2048)
>>> x = paddle.randn([2, 2])
Review thread on this line:

Contributor: You probably need to set a random seed to fix the output, otherwise the example check will not pass; see "fixed output is preferable to random output".

Author: done

Contributor: The example-code check still failed (log below). It is probably because `x = paddle.randn([2, 2])` is not followed by `print(x)`, so no output is captured. Please double-check this in all the code examples; you can run them locally to verify.

Author: Even with the seed fixed, the CI and local results differ, so I have revised this once more.

>>> print(x)
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
[[-1.24725831, 0.03843464],
[-0.31660911, 0.04793844]])

>>> mask = paddle.to_tensor([[True, True], [False, False]])
>>> value = paddle.to_tensor([1, 2, 3, 4, 5,], dtype="float32")

>>> out = paddle.masked_scatter(x, mask, value)
>>> print(out)
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
[[1., 2.],
[-0.31660911, 0.04793844]])

"""
# make sure the dtype of x and value is the same
assert (
x.dtype == value.dtype
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
assert mask.dtype == paddle.bool

# broadcast mask to the shape of x by adding it to an int zeros tensor
zeros_like_x = paddle.zeros_like(x, dtype=int)
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
# for each position, the index of the value element to scatter there
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
assert (
mask_prefix[-1] <= value.numel()
), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}'

value = value.flatten()[mask_prefix].reshape(mask.shape)
mask = paddle.logical_not(mask)
return paddle.where(mask, x, value)
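The out-of-place implementation above is built on a prefix-sum gather. A minimal sketch of the same steps, written in NumPy purely for illustration (the function name `masked_scatter_ref` is hypothetical, not part of the PR):

```python
import numpy as np

def masked_scatter_ref(x, mask, value):
    # broadcast the boolean mask to x's shape by adding it, as ints,
    # to a zeros array of x's shape (same trick as the Paddle code)
    m = mask.astype(int) + np.zeros_like(x, dtype=int)
    # running count of True entries so far, minus one, clipped at 0:
    # for each position, the index into the flattened value array
    prefix = np.clip(m.cumsum() - 1, 0, None)
    # gather a candidate value for every position, then keep x where mask is False
    v = value.ravel()[prefix].reshape(m.shape)
    return np.where(m.astype(bool), v, x)

x = np.array([[-1.0, 0.5], [2.0, 3.0]])
mask = np.array([[True, True], [False, False]])
value = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(masked_scatter_ref(x, mask, value))  # [[1. 2.] [2. 3.]]
```

Note that the gather produces a well-defined (if arbitrary) candidate even at positions where the mask is False; `np.where` simply discards those, mirroring the final `paddle.where` call.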

Review thread:

Contributor: In terms of implementation, masked_scatter and masked_scatter_ currently use different logic. Is that because some operations are unsupported in static-graph mode and you want masked_scatter to share a single code path? In principle these two APIs should differ only in a few inplace/outplace calls; could you briefly analyze whether the two current implementations differ in performance in dynamic-graph mode?

Author: Yes, it is mainly because of broadcast_to: in static-graph mode, when the shape is None or -1 it cannot broadcast correctly, so I imitated the broadcast used inside paddle.where. I think the inplace path could use broadcast_to directly in dynamic-graph mode to avoid the extra ops, which should be somewhat faster; I can benchmark later to see whether the gap is significant.

Author: Updated so that the two APIs share the same logic; the performance difference is only slight.
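The broadcast workaround discussed in this thread can be sketched in NumPy (illustration only): adding the integer-cast mask to a zeros array of `x`'s shape performs the broadcast implicitly, with no explicit `broadcast_to` whose target shape might be unknown (None / -1) at graph-construction time:

```python
import numpy as np

# x has shape (3, 4); mask has shape (4,) and must be broadcast to x's shape
x = np.ones((3, 4))
mask = np.array([True, False, True, False])

# ordinary elementwise addition broadcasts mask up to x's shape,
# so no target shape ever has to be spelled out explicitly
m = np.zeros_like(x, dtype=int) + mask.astype(int)
print(m.shape)  # (3, 4)
```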


@inplace_apis_in_dygraph_only
def masked_scatter_(x, mask, value, name=None):
"""
Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_masked_scatter`.
"""
assert (
x.dtype == value.dtype
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
assert mask.dtype == paddle.bool
zeros_like_x = paddle.zeros_like(x, dtype=int)
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
assert (
mask_prefix[-1] <= value.numel()
), f'mask true nums must be <= value size, but got mask true nums is {mask_prefix[-1].item()}, value size is {value.numel().item()}'

value = value.flatten()[mask_prefix].reshape(mask.shape)
mask = paddle.logical_not(mask)
# paddle.where_ writes the result into x in place, unlike paddle.where above
out = paddle.where_(mask, x, value)
return out
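For comparison, a hypothetical in-place analogue of the same gather, again sketched in NumPy (`np.copyto` stands in for `paddle.where_`; the function name is illustrative, not from the PR):

```python
import numpy as np

def masked_scatter_inplace_ref(x, mask, value):
    # same prefix-sum gather as the out-of-place version
    m = mask.astype(int) + np.zeros_like(x, dtype=int)
    prefix = np.clip(m.cumsum() - 1, 0, None)
    v = value.ravel()[prefix].reshape(m.shape)
    # write through x's buffer instead of allocating a new result
    np.copyto(x, v, where=m.astype(bool))
    return x

x = np.zeros((2, 2))
out = masked_scatter_inplace_ref(
    x,
    np.array([[False, True], [True, False]]),
    np.array([7.0, 8.0, 9.0]),
)
assert out is x  # the input buffer itself was modified
```

As the reviewer notes, the only essential difference from the out-of-place version is the final write: everything up to the gather is identical, so the two code paths can share their logic.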


@inplace_apis_in_dygraph_only
def reshape_(x, shape, name=None):
"""
16 changes: 16 additions & 0 deletions test/legacy_test/test_inplace.py
@@ -316,6 +316,22 @@ def init_data(self):
self.mask = paddle.to_tensor(self.mask, dtype='bool')


class TestDygraphInplaceMaskedScatter(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.masked_scatter(var, self.mask, self.value)

def inplace_api_processing(self, var):
return paddle.masked_scatter_(var, self.mask, self.value)

def init_data(self):
self.dtype = "float32"
self.input_var_numpy = np.random.uniform(-5, 5, [30, 3])
self.value = np.random.uniform(size=(30, 30))
self.value = paddle.to_tensor(self.value, dtype=self.dtype)
self.mask = np.random.randint(0, 2, [30, 1]).astype('bool')
self.mask = paddle.to_tensor(self.mask, dtype='bool')


class TestDygraphInplaceWithContinuous(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])