Skip to content

Commit

Permalink
【Hackathon No.4】为 Paddle 新增 masked_scatter API -part (#59383)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangguohao authored Dec 13, 2023
1 parent a8d5117 commit 8d717f3
Show file tree
Hide file tree
Showing 5 changed files with 421 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@
index_put_,
masked_fill,
masked_fill_,
masked_scatter_,
masked_scatter,
moveaxis,
put_along_axis,
select_scatter,
Expand Down Expand Up @@ -929,6 +931,8 @@
'polygamma_',
'masked_fill',
'masked_fill_',
'masked_scatter',
'masked_scatter_',
'hypot',
'hypot_',
'index_fill',
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@
index_put_,
masked_fill,
masked_fill_,
masked_scatter,
masked_scatter_,
moveaxis,
put_along_axis,
put_along_axis_,
Expand Down Expand Up @@ -778,6 +780,8 @@
'atleast_2d',
'atleast_3d',
'diagonal_scatter',
'masked_scatter',
'masked_scatter_',
"combinations",
]

Expand Down
83 changes: 83 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4534,6 +4534,89 @@ def get_attr_shape(list_shape):
return out


def masked_scatter(x, mask, value, name=None):
    """
    Copies elements from `value` into `x` tensor at positions where the `mask` is True.

    Elements from `value` are copied into `x` starting at position 0 of `value` and continuing in order one-by-one for
    each occurrence of `mask` being True. The shape of `mask` must be broadcastable with the shape of `x`.
    The `value` should have at least as many elements as the number of True entries in `mask`
    (counted after `mask` has been broadcast against `x`).

    Args:
        x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int32``,
            ``int64`` or ``bfloat16``.
        mask (Tensor): The boolean tensor indicate the position to be filled.
            The data type of mask must be bool.
        value (Tensor): The value used to fill the target tensor.
            Supported data types are same as x.
        name (str, optional): Name for the operation (optional, default is None). For more information,
            please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, A Tensor with the same data type as ``x`` in which every position selected by ``mask``
        holds the next unused element of the flattened ``value``; all other positions keep the value from ``x``.

    Raises:
        AssertionError: If ``x`` and ``value`` have different dtypes, if ``mask`` is not bool, or if
            ``mask`` selects more positions than ``value`` has elements.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.seed(2048)
            >>> x = paddle.randn([2, 2])
            >>> print(x)
            Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
                [[-1.24725831,  0.03843464],
                [-0.31660911,  0.04793844]])
            >>> mask = paddle.to_tensor([[True, True], [False, False]])
            >>> value = paddle.to_tensor([1, 2, 3, 4, 5,], dtype="float32")
            >>> out = paddle.masked_scatter(x, mask, value)
            >>> print(out)
            Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
                [[1,  2],
                [-0.31660911,  0.04793844]])
    """
    # make sure the dtype of x and value is the same
    assert (
        x.dtype == value.dtype
    ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
    assert mask.dtype == paddle.bool

    # Broadcast the mask against x by adding its integer form to a zero
    # tensor shaped like x; the result holds 1 where mask is True, else 0.
    zeros_like_x = paddle.zeros_like(x, dtype=int)
    mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)

    # The flattened running count of True entries, minus one, is the 0-based
    # index into the flattened `value` that each True position consumes.
    # Positions before the first True would be -1, so they are clipped to 0;
    # their gathered element is discarded by the final `where`.
    mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)

    # Compare the real number of True entries with the size of `value`.
    # `mask_prefix[-1]` is (count - 1), so testing it with `<=` would accept a
    # mask that selects one element more than `value` can supply.
    mask_true_nums = mask.sum()
    assert (
        mask_true_nums <= value.numel()
    ), f'mask true nums must be <= value size, but got mask true nums is {mask_true_nums.item()}, value size is {value.numel().item()}'

    value = value.flatten()[mask_prefix].reshape(mask.shape)
    # `where` picks x where its condition holds, so invert the mask: keep x on
    # unselected positions and take the gathered value on selected ones.
    mask = paddle.logical_not(mask)
    return paddle.where(mask, x, value)


@inplace_apis_in_dygraph_only
def masked_scatter_(x, mask, value, name=None):
    """
    Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_paddle_masked_scatter`.

    Raises:
        AssertionError: If ``x`` and ``value`` have different dtypes, if ``mask`` is not bool, or if
            ``mask`` selects more positions than ``value`` has elements.
    """
    assert (
        x.dtype == value.dtype
    ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
    assert mask.dtype == paddle.bool

    # Broadcast the mask against x by adding its integer form to a zero
    # tensor shaped like x; the result holds 1 where mask is True, else 0.
    zeros_like_x = paddle.zeros_like(x, dtype=int)
    mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)

    # 0-based index into the flattened `value` for every True position;
    # entries before the first True are clipped from -1 to 0 and are ignored
    # by the final `where_`.
    mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)

    # Check against the real number of True entries. `mask_prefix[-1]` is
    # (count - 1): using it with `<=` lets an out-of-range gather through and
    # reports a count that is one too small in the error message.
    mask_true_nums = mask.sum()
    assert (
        mask_true_nums <= value.numel()
    ), f'mask true nums must be <= value size, but got mask true nums is {mask_true_nums.item()}, value size is {value.numel().item()}'

    value = value.flatten()[mask_prefix].reshape(mask.shape)
    # Inverted mask: keep x where the mask was False, take `value` elsewhere.
    mask = paddle.logical_not(mask)
    out = paddle.where_(mask, x, value)
    return out


@inplace_apis_in_dygraph_only
def reshape_(x, shape, name=None):
"""
Expand Down
16 changes: 16 additions & 0 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,22 @@ def init_data(self):
self.mask = paddle.to_tensor(self.mask, dtype='bool')


class TestDygraphInplaceMaskedScatter(TestDygraphInplace):
    """Runs the shared inplace checks against masked_scatter / masked_scatter_."""

    def non_inplace_api_processing(self, var):
        # Out-of-place reference result.
        result = paddle.masked_scatter(var, self.mask, self.value)
        return result

    def inplace_api_processing(self, var):
        # Inplace variant under test.
        result = paddle.masked_scatter_(var, self.mask, self.value)
        return result

    def init_data(self):
        # The random draws are kept in the same order (input, value, mask) so
        # a seeded run produces the same data.
        self.dtype = "float32"
        self.input_var_numpy = np.random.uniform(-5, 5, [30, 3])

        # Source values: a 30x30 tensor of uniform samples.
        value_np = np.random.uniform(size=(30, 30))
        self.value = paddle.to_tensor(value_np, dtype=self.dtype)

        # Random boolean mask of shape [30, 1].
        mask_np = np.random.randint(0, 2, [30, 1]).astype('bool')
        self.mask = paddle.to_tensor(mask_np, dtype='bool')


class TestDygraphInplaceWithContinuous(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
Expand Down
Loading

0 comments on commit 8d717f3

Please sign in to comment.