From c484f020bd5fc97fdd6bbf08af22970062cbdad9 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Mon, 27 Nov 2023 01:12:13 +0800 Subject: [PATCH 1/7] add masked_scatter --- python/paddle/tensor/manipulation.py | 53 ++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 4afff9d3a9ad7..48d594d192fdd 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4009,6 +4009,59 @@ def get_attr_shape(list_shape): return out +def masked_scatter(x, mask, value): + """ + 利用现有api实现masked_scatter功能 + """ + # make sure the dtype of x and source is the same + assert ( + x.dtype == value.dtype + ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' + assert mask.dtype == paddle.bool + + zeros_like_x = paddle.zeros_like(x, dtype=int) + mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x) + mask_prefix = paddle.clip(mask.cumsum() - 1, min=0) + assert ( + mask_prefix[-1] <= value.numel() + ), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}' + + value = value.flatten()[mask_prefix].reshape(mask.shape) + return paddle.where(paddle.cast(mask, dtype="bool"), value, x) + + +@inplace_apis_in_dygraph_only +def masked_scatter_(x, mask, value): + """ + 利用现有api实现masked_scatter功能 + """ + # make sure the dtype of x and source is the same + assert ( + x.dtype == value.dtype + ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' + assert mask.dtype == paddle.bool + if not x.stop_gradient: + raise ValueError( + "The inplace operation only support stop_gradient=True inplace tensor" + ) + shape = paddle.broadcast_shape(x.shape, mask.shape) + if shape != x.shape: + raise ValueError( + "The shape of broadcast output should be equal to inplace tensor shape in the Inplace operation, but got output shape {} and inplace tensor {}.".format( + shape, x.shape + ) + ) + mask = paddle.broadcast_to(mask, shape) + mask_prefix = paddle.clip(mask.astype(int).cumsum() - 1, min=0) + assert ( + mask_prefix[-1] <= value.numel() + ), f'mask true nums must be <= value size, but got mask true nums is {mask_prefix[-1].item()}, value size is {value.numel().item()}' + + value = value.flatten()[mask_prefix].reshape(mask.shape) + mask = paddle.logical_not(mask) + paddle.where_(mask, x, value) + + @inplace_apis_in_dygraph_only def reshape_(x, shape, name=None): """ From 309e04913252a635ef160512422fcd4d35c9f23a Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Mon, 27 Nov 2023 11:46:48 +0800 Subject: [PATCH 2/7] add test file --- python/paddle/__init__.py | 4 + python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/manipulation.py | 44 +++- test/legacy_test/test_inplace.py | 16 ++ test/legacy_test/test_masked_scatter.py | 274 ++++++++++++++++++++++++ 5 files changed, 333 insertions(+), 9 deletions(-) create mode 100644 test/legacy_test/test_masked_scatter.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5e87e9514c0e9..9ff4ee975f3cb 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -258,6 +258,8 @@ unfold, masked_fill, masked_fill_, + masked_scatter_, + masked_scatter, index_fill, index_fill_, diagonal_scatter, @@ -926,6 +928,8 @@ 'polygamma_', 'masked_fill', 'masked_fill_', + 'masked_scatter', + 'masked_scatter_', 'hypot', 'hypot_', 'index_fill', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b96045d35faf6..a4e03d14c92bf 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -171,6 +171,8 @@ from .manipulation import unfold # noqa: F401 from .manipulation import masked_fill # noqa: F401 from .manipulation import masked_fill_ # noqa: F401 +from .manipulation import masked_scatter +from .manipulation import masked_scatter_ from .manipulation import index_fill # noqa: F401 from .manipulation import index_fill_ # noqa: F401 from .manipulation import diagonal_scatter # noqa: F401 @@ -743,6 +745,8 @@ 'atleast_2d', 'atleast_3d', 'diagonal_scatter', + 'masked_scatter', + 'masked_scatter_', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 48d594d192fdd..0fcc98ae948b2 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4009,9 +4009,37 @@ def get_attr_shape(list_shape): return out -def masked_scatter(x, mask, value): +def masked_scatter(x, mask, value, name=None): """ - 利用现有api实现masked_scatter功能 + Copies elements from `value` into `x` tensor at positions where the `mask` is True. + + Elements from source are copied into `x` starting at position 0 of `value` and continuing in order one-by-one for + each occurrence of `mask` being True. The shape of `mask` must be broadcastable with the shape of the underlying tensor. + The `value` should have at least as many elements as the number of ones in `mask`. + + Args: + x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``. + shape (list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1. + The data type is ``int32`` . If ``shape`` is a list or tuple, each element of it should be integer or Tensor with shape []. + If ``shape`` is an Tensor, it should be an 1-D Tensor . + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, A reshaped Tensor with the same data type as ``x``. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.randn([3,4]) + >>> mask = paddle.randn([3,4]) + >>> mask = mask>0.6 + >>> value = paddle.ones([2,4], dtype="float32") + + >>> out = paddle.masked_scatter(x, mask, value) + >>> print(out.shape) + """ # make sure the dtype of x and source is the same assert ( @@ -4031,19 +4059,16 @@ def masked_scatter(x, mask, value): @inplace_apis_in_dygraph_only -def masked_scatter_(x, mask, value): +def masked_scatter_(x, mask, value, name=None): """ - 利用现有api实现masked_scatter功能 + Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_masked_scatter`. """ # make sure the dtype of x and source is the same assert ( x.dtype == value.dtype ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' assert mask.dtype == paddle.bool - if not x.stop_gradient: - raise ValueError( - "The inplace operation only support stop_gradient=True inplace tensor" - ) shape = paddle.broadcast_shape(x.shape, mask.shape) if shape != x.shape: raise ValueError( @@ -4059,7 +4084,8 @@ def masked_scatter_(x, mask, value): value = value.flatten()[mask_prefix].reshape(mask.shape) mask = paddle.logical_not(mask) - paddle.where_(mask, x, value) + out = paddle.where_(mask, x, value) + return out @inplace_apis_in_dygraph_only diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index 34b50d48d76ac..c2ced2f6db704 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -316,6 +316,22 @@ def init_data(self): self.mask = paddle.to_tensor(self.mask, dtype='bool') +class TestDygraphInplaceMaskedScatter(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.masked_scatter(var, self.mask, self.value) + + def inplace_api_processing(self, var): + return paddle.masked_scatter_(var, self.mask, self.value) + + def init_data(self): + self.dtype = "float32" + self.input_var_numpy = np.random.uniform(-5, 5, [30, 3]) + self.value = np.random.uniform(size=(30, 30)) + self.value = paddle.to_tensor(self.value, dtype=self.dtype) + self.mask = np.random.randint(0, 2, [30, 1]).astype('bool') + self.mask = paddle.to_tensor(self.mask, dtype='bool') + + class TestDygraphInplaceWithContinuous(TestDygraphInplace): def init_data(self): self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1]) diff --git a/test/legacy_test/test_masked_scatter.py b/test/legacy_test/test_masked_scatter.py new file mode 100644 index 0000000000000..199a7bd38bb40 --- /dev/null +++ b/test/legacy_test/test_masked_scatter.py @@ -0,0 +1,274 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import convert_float_to_uint16 + +import paddle +from paddle import base +from paddle.base import core + + +def np_masked_scatter(x, mask, value): + x, mask = np.broadcast_arrays(x, mask) + mask_prefix_sum = np.clip(mask.cumsum() - 1, a_min=0, a_max=None) + print(mask_prefix_sum, value.size) + value = value.flatten()[mask_prefix_sum].reshape(x.shape) + return np.where(mask, value, x) + + +paddle.enable_static() + + +class TestMaskedScatterAPI(unittest.TestCase): + def setUp(self): + self.init() + + self.x_np = np.random.random(self.x_shape).astype(self.dtype) + self.mask_np = np.array( + np.random.randint(2, size=self.mask_shape), dtype="bool" + ) + + self.value_np = np.random.randn(*self.value_shape).astype(self.dtype) + self.out_np = np_masked_scatter(self.x_np, self.mask_np, self.value_np) + + def init(self): + self.x_shape = (50, 3) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.value_shape = (300, 300) + + def test_static_graph(self): + paddle.enable_static() + startup_program = base.Program() + train_program = base.Program() + with base.program_guard(startup_program, train_program): + x = paddle.static.data( + name='x', dtype=self.dtype, shape=self.x_shape + ) + mask = paddle.static.data( + name='mask', dtype='bool', shape=self.mask_shape + ) + value = paddle.static.data( + name='value', dtype=self.dtype, shape=self.value_np.shape + ) + out = paddle.masked_scatter(x, mask, value) + + place = ( + base.CUDAPlace(0) + if core.is_compiled_with_cuda() + else base.CPUPlace() + ) + exe = base.Executor(place) + res = exe.run( + base.default_main_program(), + feed={ + 'x': self.x_np, + 'mask': self.mask_np, + 'value': self.value_np, + }, + fetch_list=[out], + ) + np.testing.assert_allclose( + res[0], self.out_np, atol=1e-5, rtol=1e-5 + ) + paddle.disable_static() + + def test_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np, dtype=self.dtype) + mask = paddle.to_tensor(self.mask_np).astype('bool') + value = paddle.to_tensor(self.value_np, dtype=self.dtype) + result = paddle.masked_scatter(x, mask, value) + np.testing.assert_allclose(self.out_np, result.numpy(), rtol=1e-05) + + paddle.enable_static() + + +class TestMaskedScatterAPI1(TestMaskedScatterAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.value_shape = (300, 300) + + +class TestMaskedScatterAPI2(TestMaskedScatterAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.value_shape = (300, 300) + + +class TestMaskedScatterAPI3(TestMaskedScatterAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.value_shape = (300, 300) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedScatterFP16API1(TestMaskedScatterAPI): + def init(self): + self.x_shape = (6, 8, 9, 18) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.value_shape = (300, 300) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedScatterFP16API2(TestMaskedScatterAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.value_shape = (300, 300) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedScatterFP16API3(TestMaskedScatterAPI): + def init(self): + self.x_shape = (168,) + self.mask_shape = self.x_shape + self.dtype = "float16" + self.value_shape = (300, 300) + + +class TestMaskedScatterAPIBroadcast(TestMaskedScatterAPI): + def init(self): + self.x_shape = (3, 40) + self.mask_shape = (3, 1) + self.dtype = "float32" + self.value_shape = (300, 300) + + +class TestMaskedScatterAPIBroadcast2(TestMaskedScatterAPI): + def init(self): + self.x_shape = (3, 3) + self.mask_shape = (1, 3) + self.dtype = "float32" + self.value_shape = (300, 300) + + +class TestMaskedScatterAPIBroadcast3(TestMaskedScatterAPI): + def init(self): + self.x_shape = (120,) + self.mask_shape = (300, 120) + self.dtype = "float32" + self.value_shape = (300, 300) + + +class TestMaskedScatterAPIBroadcast4(TestMaskedScatterAPI): + def init(self): + self.x_shape = (300, 40) + self.mask_shape = (40,) + self.dtype = "float32" + self.value_shape = (300, 300) + + +class TestMaskedScatterAPIBroadcast5(TestMaskedScatterAPI): + def init(self): + self.x_shape = (300, 40) + self.mask_shape = (40,) + self.dtype = "float32" + self.value_shape = (300, 300) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedScatterFP16APIBroadcast(TestMaskedScatterAPI): + def init(self): + self.x_shape = (3, 40) + self.mask_shape = (3, 1) + self.dtype = "float16" + self.value_shape = (300, 300) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedScatterFP16APIBroadcast2(TestMaskedScatterAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 40) + self.dtype = "float16" + self.value_shape = (300, 300) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMaskedScatterFP16APIBroadcast3(TestMaskedScatterAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 40) + self.dtype = "float16" + self.value_shape = (300, 300) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestMaskedScatterBF16(TestMaskedScatterAPI): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 1) + self.dtype = "uint16" + self.value_shape = (300, 300) + + def setUp(self): + self.init() + + self.x_np = convert_float_to_uint16( + np.random.random(self.x_shape).astype("float32") + ) + self.mask_np = np.array( + np.random.randint(2, size=self.mask_shape), dtype="bool" + ) + + self.value_np = convert_float_to_uint16( + np.random.randn(*self.value_shape).astype("float32") + ) + self.out_np = np_masked_scatter(self.x_np, self.mask_np, self.value_np) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestMaskedScatterBF16APIBroadcast2(TestMaskedScatterBF16): + def init(self): + self.x_shape = (300, 1) + self.mask_shape = (300, 3) + self.dtype = "uint16" + self.value_shape = (300, 300) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() From f1dcb9cd591b8f3204c1dfa001ea4384e3fc6546 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Fri, 1 Dec 2023 13:05:40 +0800 Subject: [PATCH 3/7] fix --- python/paddle/tensor/manipulation.py | 31 +++++++++--------- test/legacy_test/test_masked_scatter.py | 42 ++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 18 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index a2d274a36490f..759e1985fb031 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4039,11 +4039,14 @@ def masked_scatter(x, mask, value, name=None): The `value` should have at least as many elements as the number of ones in `mask`. Args: - x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``. - shape (list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1. - The data type is ``int32`` . If ``shape`` is a list or tuple, each element of it should be integer or Tensor with shape []. - If ``shape`` is an Tensor, it should be an 1-D Tensor . - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int32``, + ``int64`` or ``bfloat16``. + mask (Tensor): The boolean tensor indicate the position to be filled. + The data type of mask must be bool. + value (Tensor): The value used to fill the target tensor. + Supported data types are same as x. + name (str, optional): Name for the operation (optional, default is None). For more information, + please refer to :ref:`api_guide_Name`. Returns: Tensor, A reshaped Tensor with the same data type as ``x``. @@ -4062,7 +4065,7 @@ def masked_scatter(x, mask, value, name=None): >>> print(out.shape) """ - # make sure the dtype of x and source is the same + # make sure the dtype of x and value is the same assert ( x.dtype == value.dtype ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' @@ -4076,7 +4079,8 @@ def masked_scatter(x, mask, value, name=None): ), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}' value = value.flatten()[mask_prefix].reshape(mask.shape) - return paddle.where(paddle.cast(mask, dtype="bool"), value, x) + mask = paddle.logical_not(mask) + return paddle.where(mask, x, value) @inplace_apis_in_dygraph_only @@ -4085,20 +4089,13 @@ def masked_scatter_(x, mask, value, name=None): Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_masked_scatter`. """ - # make sure the dtype of x and source is the same assert ( x.dtype == value.dtype ), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}' assert mask.dtype == paddle.bool - shape = paddle.broadcast_shape(x.shape, mask.shape) - if shape != x.shape: - raise ValueError( - "The shape of broadcast output should be equal to inplace tensor shape in the Inplace operation, but got output shape {} and inplace tensor {}.".format( - shape, x.shape - ) - ) - mask = paddle.broadcast_to(mask, shape) - mask_prefix = paddle.clip(mask.astype(int).cumsum() - 1, min=0) + zeros_like_x = paddle.zeros_like(x, dtype=int) + mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x) + mask_prefix = paddle.clip(mask.cumsum() - 1, min=0) assert ( mask_prefix[-1] <= value.numel() ), f'mask true nums must be <= value size, but got mask true nums is {mask_prefix[-1].item()}, value size is {value.numel().item()}' diff --git a/test/legacy_test/test_masked_scatter.py b/test/legacy_test/test_masked_scatter.py index 199a7bd38bb40..482f756884f0d 100644 --- a/test/legacy_test/test_masked_scatter.py +++ b/test/legacy_test/test_masked_scatter.py @@ -25,7 +25,6 @@ def np_masked_scatter(x, mask, value): x, mask = np.broadcast_arrays(x, mask) mask_prefix_sum = np.clip(mask.cumsum() - 1, a_min=0, a_max=None) - print(mask_prefix_sum, value.size) value = value.flatten()[mask_prefix_sum].reshape(x.shape) return np.where(mask, value, x) @@ -33,6 +32,47 @@ def np_masked_scatter(x, mask, value): paddle.enable_static() +class TestMaskedScatterError(unittest.TestCase): + def setUp(self): + self.init() + + self.x_np = np.random.random(self.x_shape).astype(self.dtype) + self.mask_np = np.array( + np.random.randint(2, size=self.mask_shape), dtype='bool' + ) + + self.value_np = np.random.randn(*self.value_shape).astype(self.dtype) + + def init(self): + self.x_shape = (50, 3) + self.mask_shape = self.x_shape + self.dtype = "float32" + self.value_shape = (300, 300) + + def test_mask_error(self): + x = paddle.to_tensor(self.x_np, dtype=self.dtype) + mask = paddle.to_tensor(self.mask_np).astype('int32') + value = paddle.to_tensor(self.value_np, dtype=self.dtype) + + with np.testing.assert_raises(AssertionError): + paddle.masked_scatter(x, mask, value) + + def test_dtype_error(self): + x = paddle.to_tensor(self.x_np, dtype=self.dtype) + mask = paddle.to_tensor(self.mask_np).astype('bool') + value = paddle.to_tensor(self.value_np, dtype='float64') + with np.testing.assert_raises(AssertionError): + paddle.masked_scatter(x, mask, value) + + def test_numel_error(self): + self.value_np = np.random.randn(5, 5).astype(self.dtype) + x = paddle.to_tensor(self.x_np, dtype=self.dtype) + mask = paddle.to_tensor(self.mask_np).astype('bool') + value = paddle.to_tensor(self.value_np, dtype=self.dtype) + with np.testing.assert_raises(AssertionError): + paddle.masked_scatter(x, mask, value) + + class TestMaskedScatterAPI(unittest.TestCase): def setUp(self): self.init() From ef21a94aa4e87feb7590e8c9b021975fbd74b30a Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Tue, 5 Dec 2023 14:03:44 +0800 Subject: [PATCH 4/7] fix doc --- python/paddle/tensor/manipulation.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 76f2c635f58c0..d2e009b958d6c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4060,13 +4060,22 @@ def masked_scatter(x, mask, value, name=None): >>> import paddle - >>> x = paddle.randn([3,4]) - >>> mask = paddle.randn([3,4]) + >>> x = paddle.randn([2, 2]) + Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[-0.22560297, -1.48127186], + [-1.12356818, 0.76581174]]) + >>> mask = paddle.randn([2, 2]) >>> mask = mask>0.6 - >>> value = paddle.ones([2,4], dtype="float32") + Tensor(shape=[2, 2], dtype=bool, place=Place(gpu:0), stop_gradient=True, + [[False, True ], + [True, False]]) + >>> value = paddle.to_tensor([1, 2, 3, 4, 5,], dtype="float32") >>> out = paddle.masked_scatter(x, mask, value) - >>> print(out.shape) + >>> print(out) + Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[-0.22560297, 1.], + [2, 0.76581174]]) """ # make sure the dtype of x and value is the same From 43b65f3a26f0d7c06ba6b7469cd088b2b12d1eed Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Tue, 5 Dec 2023 17:31:43 +0800 Subject: [PATCH 5/7] fix doc --- python/paddle/tensor/manipulation.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index d2e009b958d6c..94f02b601569f 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4059,23 +4059,24 @@ def masked_scatter(x, mask, value, name=None): .. code-block:: python >>> import paddle - + >>> paddle.seed(2048) >>> x = paddle.randn([2, 2]) Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, - [[-0.22560297, -1.48127186], - [-1.12356818, 0.76581174]]) + [[ 0.74132639, -1.79502666], + [-0.01776697, -0.93422651]]) + >>> mask = paddle.randn([2, 2]) >>> mask = mask>0.6 Tensor(shape=[2, 2], dtype=bool, place=Place(gpu:0), stop_gradient=True, - [[False, True ], - [True, False]]) + [[True , True ], + [False, False]]) >>> value = paddle.to_tensor([1, 2, 3, 4, 5,], dtype="float32") >>> out = paddle.masked_scatter(x, mask, value) >>> print(out) Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, - [[-0.22560297, 1.], - [2, 0.76581174]]) + [[ 1, 2], + [-0.01776697, -0.93422651]]) """ # make sure the dtype of x and value is the same From 6403dc48e97d24de46a127861f52e0e920b02498 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Fri, 8 Dec 2023 15:44:30 +0800 Subject: [PATCH 6/7] fix doc --- python/paddle/tensor/manipulation.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 94f02b601569f..510b4e2f7bdf0 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4061,22 +4061,19 @@ def masked_scatter(x, mask, value, name=None): >>> import paddle >>> paddle.seed(2048) >>> x = paddle.randn([2, 2]) - Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, - [[ 0.74132639, -1.79502666], - [-0.01776697, -0.93422651]]) - - >>> mask = paddle.randn([2, 2]) - >>> mask = mask>0.6 - Tensor(shape=[2, 2], dtype=bool, place=Place(gpu:0), stop_gradient=True, - [[True , True ], - [False, False]]) + >>> print(x) + Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[-1.24725831, 0.03843464], + [-0.31660911, 0.04793844]]) + + >>> mask = paddle.to_tensor([[True, True], [False, False]]) >>> value = paddle.to_tensor([1, 2, 3, 4, 5,], dtype="float32") >>> out = paddle.masked_scatter(x, mask, value) >>> print(out) - Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, - [[ 1, 2], - [-0.01776697, -0.93422651]]) + Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[1, 2], + [-0.31660911, 0.04793844]]) """ # make sure the dtype of x and value is the same From 6f4485b607947a0162f1c6913ef90dffe06a2f8f Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Tue, 12 Dec 2023 13:00:02 +0800 Subject: [PATCH 7/7] fix --- python/paddle/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 0d50f64134c7e..a1f56c6b09c2a 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -284,8 +284,8 @@ index_put_, masked_fill, masked_fill_, - asked_scatter_, - masked_scatte + masked_scatter_, + masked_scatter, moveaxis, put_along_axis, select_scatter, @@ -319,7 +319,7 @@ unstack, view, view_as, - vsplit + vsplit, ) from .tensor.math import ( # noqa: F401 abs,