From 8d717f368640b4d5c6b9eaf1ecb09e05b230b002 Mon Sep 17 00:00:00 2001
From: yangguohao <70266361+yangguohao@users.noreply.github.com>
Date: Wed, 13 Dec 2023 12:12:32 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=20No.4=E3=80=91=E4=B8=BA=20P?=
 =?UTF-8?q?addle=20=E6=96=B0=E5=A2=9E=20masked=5Fscatter=20API=20-part=20(?=
 =?UTF-8?q?#59383)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/paddle/__init__.py               |   4 +
 python/paddle/tensor/__init__.py        |   4 +
 python/paddle/tensor/manipulation.py    |  91 +++++++
 test/legacy_test/test_inplace.py        |  16 ++
 test/legacy_test/test_masked_scatter.py | 314 ++++++++++++++++++++++++
 5 files changed, 429 insertions(+)
 create mode 100644 test/legacy_test/test_masked_scatter.py

diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 67b25a0e98ff90..dc6beceb0eb6aa 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -286,6 +286,8 @@
     index_put_,
     masked_fill,
     masked_fill_,
+    masked_scatter_,
+    masked_scatter,
     moveaxis,
     put_along_axis,
     select_scatter,
@@ -929,6 +931,8 @@
     'polygamma_',
     'masked_fill',
     'masked_fill_',
+    'masked_scatter',
+    'masked_scatter_',
     'hypot',
     'hypot_',
     'index_fill',
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 267160a7a227d7..eab6b636e85555 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -169,6 +169,8 @@
     index_put_,
     masked_fill,
     masked_fill_,
+    masked_scatter,
+    masked_scatter_,
     moveaxis,
     put_along_axis,
     put_along_axis_,
@@ -778,6 +780,8 @@
     'atleast_2d',
     'atleast_3d',
     'diagonal_scatter',
+    'masked_scatter',
+    'masked_scatter_',
     "combinations",
 ]
 
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 39c196dffbfa76..c6be5c0c946b47 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -4534,6 +4534,97 @@ def get_attr_shape(list_shape):
         return out
 
 
+def masked_scatter(x, mask, value, name=None):
+    """
+    Copies elements from `value` into `x` tensor at positions where the `mask` is True.
+
+    Elements from source are copied into `x` starting at position 0 of `value` and continuing in order one-by-one for
+    each occurrence of `mask` being True. The shape of `mask` must be broadcastable with the shape of the underlying tensor.
+    The `value` should have at least as many elements as the number of True elements in `mask`.
+
+    Args:
+        x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int32``,
+            ``int64`` or ``bfloat16``.
+        mask (Tensor): The boolean tensor indicating the positions to be filled.
+            The data type of mask must be bool.
+        value (Tensor): The value used to fill the target tensor.
+            Supported data types are same as x.
+        name (str, optional): Name for the operation (optional, default is None). For more information,
+            please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, A Tensor with the same data type as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> paddle.seed(2048)
+            >>> x = paddle.randn([2, 2])
+            >>> print(x)
+            Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
+                [[-1.24725831,  0.03843464],
+                [-0.31660911,  0.04793844]])
+
+            >>> mask = paddle.to_tensor([[True, True], [False, False]])
+            >>> value = paddle.to_tensor([1, 2, 3, 4, 5,], dtype="float32")
+
+            >>> out = paddle.masked_scatter(x, mask, value)
+            >>> print(out)
+            Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
+                [[1,  2],
+                [-0.31660911,  0.04793844]])
+
+    """
+    # make sure the dtype of x and value is the same
+    assert (
+        x.dtype == value.dtype
+    ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
+    assert mask.dtype == paddle.bool
+
+    # Broadcast mask against x as an int tensor: 1 where True, 0 elsewhere.
+    zeros_like_x = paddle.zeros_like(x, dtype=int)
+    mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
+    # In flattened order the k-th True element reads value[k - 1]; False
+    # positions get a placeholder index and are dropped by `where` below.
+    mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
+    # Compare the True count itself: mask_prefix[-1] is the count minus one,
+    # so testing it would accept a `value` that is one element too short.
+    assert (
+        mask.sum() <= value.numel()
+    ), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}'
+
+    value = value.flatten()[mask_prefix].reshape(mask.shape)
+    mask = paddle.logical_not(mask)
+    return paddle.where(mask, x, value)
+
+
+@inplace_apis_in_dygraph_only
+def masked_scatter_(x, mask, value, name=None):
+    """
+    Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``.
+    Please refer to :ref:`api_paddle_masked_scatter`.
+    """
+    assert (
+        x.dtype == value.dtype
+    ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
+    assert mask.dtype == paddle.bool
+    # Same steps as ``masked_scatter``; only the final ``where_`` is inplace.
+    zeros_like_x = paddle.zeros_like(x, dtype=int)
+    mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
+    mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
+    # Check the True count, not mask_prefix[-1] (count - 1), so that a
+    # `value` that is one element too short is rejected.
+    assert (
+        mask.sum() <= value.numel()
+    ), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}'
+
+    value = value.flatten()[mask_prefix].reshape(mask.shape)
+    mask = paddle.logical_not(mask)
+    out = paddle.where_(mask, x, value)
+    return out
+
+
 @inplace_apis_in_dygraph_only
 def reshape_(x, shape, name=None):
     """
diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py
index 34b50d48d76ac4..c2ced2f6db704d 100644
--- a/test/legacy_test/test_inplace.py
+++ b/test/legacy_test/test_inplace.py
@@ -316,6 +316,22 @@ def init_data(self):
         self.mask = paddle.to_tensor(self.mask, dtype='bool')
 
 
+class TestDygraphInplaceMaskedScatter(TestDygraphInplace):
+    def non_inplace_api_processing(self, var):
+        return paddle.masked_scatter(var, self.mask, self.value)
+
+    def inplace_api_processing(self, var):
+        return paddle.masked_scatter_(var, self.mask, self.value)
+
+    def init_data(self):
+        self.dtype = "float32"
+        self.input_var_numpy = np.random.uniform(-5, 5, [30, 3])
+        self.value = np.random.uniform(size=(30, 30))
+        self.value = paddle.to_tensor(self.value, dtype=self.dtype)
+        self.mask = np.random.randint(0, 2, [30, 1]).astype('bool')
+        self.mask = paddle.to_tensor(self.mask, dtype='bool')
+
+
 class TestDygraphInplaceWithContinuous(TestDygraphInplace):
     def init_data(self):
         self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
diff --git a/test/legacy_test/test_masked_scatter.py b/test/legacy_test/test_masked_scatter.py
new file mode 100644
index 00000000000000..482f756884f0dd
--- /dev/null
+++ b/test/legacy_test/test_masked_scatter.py
@@ -0,0 +1,314 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from op_test import convert_float_to_uint16
+
+import paddle
+from paddle import base
+from paddle.base import core
+
+
+def np_masked_scatter(x, mask, value):
+    x, mask = np.broadcast_arrays(x, mask)
+    mask_prefix_sum = np.clip(mask.cumsum() - 1, a_min=0, a_max=None)
+    value = value.flatten()[mask_prefix_sum].reshape(x.shape)
+    return np.where(mask, value, x)
+
+
+paddle.enable_static()
+
+
+class TestMaskedScatterError(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()  # error checks must not depend on test order
+        self.init()
+        self.x_np = np.random.random(self.x_shape).astype(self.dtype)
+        self.mask_np = np.array(
+            np.random.randint(2, size=self.mask_shape), dtype='bool'
+        )
+
+        self.value_np = np.random.randn(*self.value_shape).astype(self.dtype)
+
+    def init(self):
+        self.x_shape = (50, 3)
+        self.mask_shape = self.x_shape
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+    def test_mask_error(self):
+        x = paddle.to_tensor(self.x_np, dtype=self.dtype)
+        mask = paddle.to_tensor(self.mask_np).astype('int32')
+        value = paddle.to_tensor(self.value_np, dtype=self.dtype)
+
+        with np.testing.assert_raises(AssertionError):
+            paddle.masked_scatter(x, mask, value)
+
+    def test_dtype_error(self):
+        x = paddle.to_tensor(self.x_np, dtype=self.dtype)
+        mask = paddle.to_tensor(self.mask_np).astype('bool')
+        value = paddle.to_tensor(self.value_np, dtype='float64')
+        with np.testing.assert_raises(AssertionError):
+            paddle.masked_scatter(x, mask, value)
+
+    def test_numel_error(self):
+        # boundary case: `mask` has exactly one more True than `value` has elements
+        x = paddle.to_tensor(self.x_np, dtype=self.dtype)
+        mask = paddle.to_tensor(np.ones(self.mask_shape, dtype='bool'))
+        value = paddle.to_tensor(self.value_np.flatten()[: self.x_np.size - 1])
+        with np.testing.assert_raises(AssertionError):
+            paddle.masked_scatter(x, mask, value)
+
+
+class TestMaskedScatterAPI(unittest.TestCase):
+    def setUp(self):
+        self.init()
+
+        self.x_np = np.random.random(self.x_shape).astype(self.dtype)
+        self.mask_np = np.array(
+            np.random.randint(2, size=self.mask_shape), dtype="bool"
+        )
+
+        self.value_np = np.random.randn(*self.value_shape).astype(self.dtype)
+        self.out_np = np_masked_scatter(self.x_np, self.mask_np, self.value_np)
+
+    def init(self):
+        self.x_shape = (50, 3)
+        self.mask_shape = self.x_shape
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+    def test_static_graph(self):
+        paddle.enable_static()
+        startup_program = base.Program()
+        train_program = base.Program()
+        with base.program_guard(train_program, startup_program):
+            x = paddle.static.data(
+                name='x', dtype=self.dtype, shape=self.x_shape
+            )
+            mask = paddle.static.data(
+                name='mask', dtype='bool', shape=self.mask_shape
+            )
+            value = paddle.static.data(
+                name='value', dtype=self.dtype, shape=self.value_np.shape
+            )
+            out = paddle.masked_scatter(x, mask, value)
+
+            place = (
+                base.CUDAPlace(0)
+                if core.is_compiled_with_cuda()
+                else base.CPUPlace()
+            )
+            exe = base.Executor(place)
+            res = exe.run(
+                base.default_main_program(),
+                feed={
+                    'x': self.x_np,
+                    'mask': self.mask_np,
+                    'value': self.value_np,
+                },
+                fetch_list=[out],
+            )
+            np.testing.assert_allclose(
+                res[0], self.out_np, atol=1e-5, rtol=1e-5
+            )
+        paddle.disable_static()
+
+    def test_dygraph(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x_np, dtype=self.dtype)
+        mask = paddle.to_tensor(self.mask_np).astype('bool')
+        value = paddle.to_tensor(self.value_np, dtype=self.dtype)
+        result = paddle.masked_scatter(x, mask, value)
+        np.testing.assert_allclose(self.out_np, result.numpy(), rtol=1e-05)
+
+        paddle.enable_static()
+
+
+class TestMaskedScatterAPI1(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (6, 8, 9, 18)
+        self.mask_shape = self.x_shape
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+
+class TestMaskedScatterAPI2(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (168,)
+        self.mask_shape = self.x_shape
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+
+class TestMaskedScatterAPI3(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (6, 8, 9, 18)
+        self.mask_shape = self.x_shape
+        self.dtype = "float64"
+        self.value_shape = (300, 300)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestMaskedScatterFP16API1(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (6, 8, 9, 18)
+        self.mask_shape = self.x_shape
+        self.dtype = "float16"
+        self.value_shape = (300, 300)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestMaskedScatterFP16API2(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (168,)
+        self.mask_shape = self.x_shape
+        self.dtype = "float16"
+        self.value_shape = (300, 300)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestMaskedScatterFP16API3(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (50, 3)
+        self.mask_shape = self.x_shape
+        self.dtype = "float16"
+        self.value_shape = (300, 300)
+
+
+class TestMaskedScatterAPIBroadcast(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (3, 40)
+        self.mask_shape = (3, 1)
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+
+class TestMaskedScatterAPIBroadcast2(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (3, 3)
+        self.mask_shape = (1, 3)
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+
+class TestMaskedScatterAPIBroadcast3(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (120,)
+        self.mask_shape = (300, 120)
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+
+class TestMaskedScatterAPIBroadcast4(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (300, 40)
+        self.mask_shape = (40,)
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+
+class TestMaskedScatterAPIBroadcast5(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (2, 300, 40)
+        self.mask_shape = (40,)
+        self.dtype = "float32"
+        self.value_shape = (300, 300)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestMaskedScatterFP16APIBroadcast(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (3, 40)
+        self.mask_shape = (3, 1)
+        self.dtype = "float16"
+        self.value_shape = (300, 300)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestMaskedScatterFP16APIBroadcast2(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (300, 1)
+        self.mask_shape = (300, 40)
+        self.dtype = "float16"
+        self.value_shape = (300, 300)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestMaskedScatterFP16APIBroadcast3(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (120,)
+        self.mask_shape = (300, 120)
+        self.dtype = "float16"
+        self.value_shape = (300, 300)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestMaskedScatterBF16(TestMaskedScatterAPI):
+    def init(self):
+        self.x_shape = (300, 1)
+        self.mask_shape = (300, 1)
+        self.dtype = "uint16"
+        self.value_shape = (300, 300)
+
+    def setUp(self):
+        self.init()
+
+        self.x_np = convert_float_to_uint16(
+            np.random.random(self.x_shape).astype("float32")
+        )
+        self.mask_np = np.array(
+            np.random.randint(2, size=self.mask_shape), dtype="bool"
+        )
+
+        self.value_np = convert_float_to_uint16(
+            np.random.randn(*self.value_shape).astype("float32")
+        )
+        self.out_np = np_masked_scatter(self.x_np, self.mask_np, self.value_np)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestMaskedScatterBF16APIBroadcast2(TestMaskedScatterBF16):
+    def init(self):
+        self.x_shape = (300, 1)
+        self.mask_shape = (300, 3)
+        self.dtype = "uint16"
+        self.value_shape = (300, 300)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()