[Hackathon No.4] Add masked_scatter API to Paddle -part #59383

Merged 10 commits on Dec 13, 2023

Changes from 3 commits
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
@@ -258,6 +258,8 @@
unfold,
masked_fill,
masked_fill_,
masked_scatter_,
masked_scatter,
index_fill,
index_fill_,
diagonal_scatter,
@@ -928,6 +930,8 @@
'polygamma_',
'masked_fill',
'masked_fill_',
'masked_scatter',
'masked_scatter_',
'hypot',
'hypot_',
'index_fill',
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -161,6 +161,8 @@
index_put_,
masked_fill,
masked_fill_,
masked_scatter,
masked_scatter_,
moveaxis,
put_along_axis,
put_along_axis_,
@@ -759,6 +761,8 @@
'atleast_2d',
'atleast_3d',
'diagonal_scatter',
'masked_scatter',
'masked_scatter_',
]

# this list used in math_op_patch.py for magic_method bind
79 changes: 79 additions & 0 deletions python/paddle/tensor/manipulation.py
@@ -4030,6 +4030,85 @@ def get_attr_shape(list_shape):
return out


def masked_scatter(x, mask, value, name=None):
    """
    Copies elements from ``value`` into the ``x`` tensor at positions where ``mask`` is True.

    Elements of ``value`` are copied into ``x`` starting at position 0 of ``value`` and continuing in order,
    one element for each occurrence of True in ``mask``. The shape of ``mask`` must be broadcastable with the
    shape of ``x``, and ``value`` must have at least as many elements as there are True entries in ``mask``.

    Args:
        x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``.
        mask (Tensor): A boolean Tensor whose shape is broadcastable with the shape of ``x``. Positions where ``mask`` is True are overwritten with elements of ``value``.
        value (Tensor): The Tensor to copy elements from. It must have the same data type as ``x`` and contain at least as many elements as the number of True entries in ``mask``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, A Tensor with the same shape and data type as ``x``, in which the positions selected by ``mask`` have been filled with elements of ``value``.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> x = paddle.randn([3, 4])
            >>> mask = paddle.randn([3, 4])
            >>> mask = mask > 0.6
            >>> value = paddle.ones([2, 4], dtype="float32")

            >>> out = paddle.masked_scatter(x, mask, value)
            >>> print(out.shape)
            [3, 4]
Contributor: It's best to give the values of x, mask and out in the example code, to make it easier for users to understand.

Contributor (Author): done

"""
# make sure the dtype of x and source is the same
assert (
x.dtype == value.dtype
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
assert mask.dtype == paddle.bool

zeros_like_x = paddle.zeros_like(x, dtype=int)
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
assert (
mask_prefix[-1] <= value.numel()
), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}'

value = value.flatten()[mask_prefix].reshape(mask.shape)
return paddle.where(paddle.cast(mask, dtype="bool"), value, x)
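
A small worked example of the semantics implemented above (values chosen here for illustration; they are not taken from the PR):

import paddle

# Elements of `value` are consumed in row-major order, one per True in `mask`.
x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = paddle.to_tensor([[True, False, True], [False, True, False]])
value = paddle.to_tensor([10.0, 20.0, 30.0, 40.0])  # needs at least 3 elements (3 Trues)

out = paddle.masked_scatter(x, mask, value)
# per the implementation above: out == [[10., 2., 20.], [4., 30., 6.]], and x itself is unchanged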

Contributor: In terms of implementation, the masked_scatter and masked_scatter_ APIs currently use different logic. Is that because some operations are not supported under the static graph, and you still want masked_scatter to share a single code path? In theory this pair of APIs should only differ internally by a few inplace / outplace API calls. Could you briefly check whether the two current implementations differ in performance under the dynamic graph?

Contributor (Author): Yes. The main issue is the broadcast_to operation: under the static graph it cannot broadcast correctly when the shape contains None or -1, so the code instead imitates the broadcast trick used in paddle.where. I think that in dynamic graph mode the inplace version, which uses broadcast_to directly, avoids the extra ops and should be a bit faster. I can also test later to see whether there is a noticeable gap.

Contributor (Author): Changed so that the two APIs now use the same logic; the performance difference is only slight.
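
A minimal sketch of the two broadcast strategies discussed above (illustration only, not code from this PR; the tensors and shapes are made up):

import paddle

x = paddle.randn([3, 4])
mask = paddle.to_tensor([[True], [False], [True]])  # shape [3, 1], broadcastable to [3, 4]

# Static-graph-friendly trick used by the outplace API in this revision:
# adding an all-zero integer tensor of x's shape broadcasts the mask implicitly.
mask_a = paddle.add(paddle.cast(mask, dtype="int32"), paddle.zeros_like(x, dtype="int32"))

# Direct broadcast, as used by the inplace API in dynamic graph mode.
mask_b = paddle.broadcast_to(mask, paddle.broadcast_shape(x.shape, mask.shape))

print(mask_a.shape, mask_b.shape)  # both are [3, 4]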


@inplace_apis_in_dygraph_only
def masked_scatter_(x, mask, value, name=None):
    """
    Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_paddle_masked_scatter`.
    """
    # make sure the dtype of x and value is the same
Contributor: This unnecessary comment can be removed.

Contributor (Author): done

    assert (
        x.dtype == value.dtype
    ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
    assert mask.dtype == paddle.bool
    # in the inplace case, broadcasting mask against x must not change x's shape
    shape = paddle.broadcast_shape(x.shape, mask.shape)
    if shape != x.shape:
        raise ValueError(
            "The shape of broadcast output should be equal to the inplace tensor shape in the Inplace operation, but got output shape {} and inplace tensor shape {}.".format(
                shape, x.shape
            )
        )
    mask = paddle.broadcast_to(mask, shape)
    # for each position, the index into the flattened `value` of the element to take
    mask_prefix = paddle.clip(mask.astype(int).cumsum() - 1, min=0)
    assert (
        mask_prefix[-1] <= value.numel()
    ), f'mask true nums must be <= value size, but got mask true nums is {mask_prefix[-1].item()}, value size is {value.numel().item()}'

    value = value.flatten()[mask_prefix].reshape(mask.shape)
    mask = paddle.logical_not(mask)
    out = paddle.where_(mask, x, value)
    return out
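
A brief usage sketch of the inplace variant (illustrative values, not from the PR); note that broadcasting `mask` against `x` must not change the shape of `x`:

import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = paddle.to_tensor([[True], [False]])  # [2, 1] broadcasts to x's shape [2, 3]
value = paddle.to_tensor([7.0, 8.0, 9.0, 10.0])

paddle.masked_scatter_(x, mask, value)
# per the implementation above, x is now [[7., 8., 9.], [4., 5., 6.]]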


@inplace_apis_in_dygraph_only
def reshape_(x, shape, name=None):
"""
Expand Down
16 changes: 16 additions & 0 deletions test/legacy_test/test_inplace.py
@@ -316,6 +316,22 @@ def init_data(self):
        self.mask = paddle.to_tensor(self.mask, dtype='bool')


class TestDygraphInplaceMaskedScatter(TestDygraphInplace):
    def non_inplace_api_processing(self, var):
        return paddle.masked_scatter(var, self.mask, self.value)

    def inplace_api_processing(self, var):
        return paddle.masked_scatter_(var, self.mask, self.value)

    def init_data(self):
        self.dtype = "float32"
        self.input_var_numpy = np.random.uniform(-5, 5, [30, 3])
        self.value = np.random.uniform(size=(30, 30))
        self.value = paddle.to_tensor(self.value, dtype=self.dtype)
        self.mask = np.random.randint(0, 2, [30, 1]).astype('bool')
        self.mask = paddle.to_tensor(self.mask, dtype='bool')
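
A standalone sanity check mirroring what the test above verifies — the inplace and outplace APIs should produce the same values (this snippet is an illustration, not part of the test suite):

import numpy as np
import paddle

np.random.seed(0)
x_np = np.random.uniform(-5, 5, [30, 3]).astype('float32')
value = paddle.to_tensor(np.random.uniform(size=(30, 30)), dtype='float32')
mask = paddle.to_tensor(np.random.randint(0, 2, [30, 1]).astype('bool'))

# outplace result vs. a separate copy of x mutated in place
out = paddle.masked_scatter(paddle.to_tensor(x_np), mask, value)
x_inplace = paddle.to_tensor(x_np)
paddle.masked_scatter_(x_inplace, mask, value)
np.testing.assert_allclose(out.numpy(), x_inplace.numpy())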


class TestDygraphInplaceWithContinuous(TestDygraphInplace):
    def init_data(self):
        self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])