Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.4】为 Paddle 新增 masked_scatter API -part #59383

Merged
merged 10 commits into from
Dec 13, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add test file
yangguohao committed Nov 27, 2023

Verified

This commit was signed with the committer’s verified signature.
djhi Gildas Garcia
commit 309e04913252a635ef160512422fcd4d35c9f23a
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
@@ -258,6 +258,8 @@
unfold,
masked_fill,
masked_fill_,
masked_scatter_,
masked_scatter,
index_fill,
index_fill_,
diagonal_scatter,
@@ -926,6 +928,8 @@
'polygamma_',
'masked_fill',
'masked_fill_',
'masked_scatter',
'masked_scatter_',
'hypot',
'hypot_',
'index_fill',
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -171,6 +171,8 @@
from .manipulation import unfold # noqa: F401
from .manipulation import masked_fill # noqa: F401
from .manipulation import masked_fill_ # noqa: F401
from .manipulation import masked_scatter
from .manipulation import masked_scatter_
from .manipulation import index_fill # noqa: F401
from .manipulation import index_fill_ # noqa: F401
from .manipulation import diagonal_scatter # noqa: F401
@@ -743,6 +745,8 @@
'atleast_2d',
'atleast_3d',
'diagonal_scatter',
'masked_scatter',
'masked_scatter_',
]

# this list used in math_op_patch.py for magic_method bind
44 changes: 35 additions & 9 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
@@ -4009,9 +4009,37 @@ def get_attr_shape(list_shape):
return out


def masked_scatter(x, mask, value):
def masked_scatter(x, mask, value, name=None):
"""
利用现有api实现masked_scatter功能
Copies elements from `value` into `x` tensor at positions where the `mask` is True.

Elements from `value` are copied into `x` starting at position 0 of `value` and continuing in order one-by-one for
each occurrence of `mask` being True. The shape of `mask` must be broadcastable with the shape of `x`.
The `value` should have at least as many elements as the number of True elements in `mask`.

Args:
x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``.
mask (Tensor): The boolean Tensor that selects which positions of ``x`` are overwritten. The data type must be ``bool``,
and its shape must be broadcastable with the shape of ``x``.
value (Tensor): The Tensor whose elements are copied into ``x``. It must have the same data type as ``x`` and
at least as many elements as the number of True elements in ``mask``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
Tensor, a Tensor with the same data type as ``x``, in which the elements at positions where ``mask`` is True are replaced by elements of ``value``.

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.randn([3,4])
>>> mask = paddle.randn([3,4])
>>> mask = mask>0.6
>>> value = paddle.ones([2,4], dtype="float32")

>>> out = paddle.masked_scatter(x, mask, value)
>>> print(out.shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's best to give the values of x, mask and out in example code to make it easier for users to understand

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


"""
# make sure the dtype of x and value is the same
assert (
@@ -4031,19 +4059,16 @@ def masked_scatter(x, mask, value):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

实现逻辑上,masked_scatter masked_scatter_两个API有不同的实现逻辑。是某些操作在静态图下不支持,又希望masked_scatter 使用同一套代码的原因吗。

理论上这组API的实现内部只应该有少量inplace / outplace API的差异,可以简单分析下动态图下当前两种方式的实现有性能差异吗。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,主要是 broadcast_to 的这个操作,在静态图下的 shape 如果是 None 或者 -1 之类的时候没办法正确的 broadcast,后面是模仿 paddle.where 中的 broadcast 操作。我感觉在动态图下的 inplace 操作直接使用 broadcast_to 可以避免多余的 Op 操作,性能上应该会快一些。后面我也可以测试一下看看有没有明显的差距

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改为两个 API 的逻辑一致,性能只有较为轻微的差异


@inplace_apis_in_dygraph_only
def masked_scatter_(x, mask, value):
def masked_scatter_(x, mask, value, name=None):
"""
利用现有api实现masked_scatter功能
Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_masked_scatter`.
"""
# make sure the dtype of x and source is the same
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不必要的注释可移除掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

assert (
x.dtype == value.dtype
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
assert mask.dtype == paddle.bool
if not x.stop_gradient:
raise ValueError(
"The inplace operation only support stop_gradient=True inplace tensor"
)
shape = paddle.broadcast_shape(x.shape, mask.shape)
if shape != x.shape:
raise ValueError(
@@ -4059,7 +4084,8 @@ def masked_scatter_(x, mask, value):

value = value.flatten()[mask_prefix].reshape(mask.shape)
mask = paddle.logical_not(mask)
paddle.where_(mask, x, value)
out = paddle.where_(mask, x, value)
return out


@inplace_apis_in_dygraph_only
16 changes: 16 additions & 0 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
@@ -316,6 +316,22 @@ def init_data(self):
self.mask = paddle.to_tensor(self.mask, dtype='bool')


class TestDygraphInplaceMaskedScatter(TestDygraphInplace):
    """Inplace-test case pairing ``paddle.masked_scatter`` with ``paddle.masked_scatter_``.

    The two ``*_api_processing`` hooks and ``init_data`` are presumably driven
    by the ``TestDygraphInplace`` base class -- confirm against its definition.
    """

    def non_inplace_api_processing(self, var):
        """Apply the out-of-place API to ``var`` with the stored mask and value."""
        return paddle.masked_scatter(var, self.mask, self.value)

    def inplace_api_processing(self, var):
        """Apply the inplace API to ``var`` with the stored mask and value."""
        return paddle.masked_scatter_(var, self.mask, self.value)

    def init_data(self):
        """Create the random input array, the value tensor and the bool mask tensor."""
        self.dtype = "float32"
        self.input_var_numpy = np.random.uniform(-5, 5, [30, 3])

        # 30 * 30 source elements for a [30, 3] input.
        value_np = np.random.uniform(size=(30, 30))
        self.value = paddle.to_tensor(value_np, dtype=self.dtype)

        # The mask is [30, 1] while the input is [30, 3].
        mask_np = np.random.randint(0, 2, [30, 1]).astype('bool')
        self.mask = paddle.to_tensor(mask_np, dtype='bool')


class TestDygraphInplaceWithContinuous(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
Loading