
【Hackathon 5th No.28】Add slice_scatter API for Paddle -part #59973

Merged: 14 commits, Dec 26, 2023
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
@@ -302,6 +302,7 @@
scatter_nd_add,
shard_index,
slice,
slice_scatter,
split,
squeeze,
squeeze_,
@@ -623,6 +624,7 @@
'amin',
'any',
'slice',
'slice_scatter',
'normal',
'normal_',
'logsumexp',
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -187,6 +187,7 @@
select_scatter,
shard_index,
slice,
slice_scatter,
split,
squeeze,
squeeze_,
@@ -613,6 +614,7 @@
'scatter_nd',
'shard_index',
'slice',
'slice_scatter',
'split',
'tensor_split',
'hsplit',
105 changes: 104 additions & 1 deletion python/paddle/tensor/manipulation.py
@@ -27,7 +27,7 @@
check_variable_and_dtype,
convert_dtype,
)
from ..base.framework import Variable
from ..base.framework import Variable, default_main_program
from ..framework import (
LayerHelper,
_current_expected_place,
@@ -6749,3 +6749,106 @@ def select_scatter(x, values, axis, index, name=None):
)

return output


def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):
Contributor:

The default values here don't seem to match the parameter descriptions below? (start, stop)

Contributor (author):

Good catch! 🤙🤙🤙

They are indeed different; I went back and forth on how to write this.

If the arguments are omitted, e.g. slice_scatter(x, value), then start is None and is converted to 0 internally, while stop is None and is converted to x.shape[axis].

This follows the wording for Python's range:

https://docs.python.org/3/library/stdtypes.html#range

"If the start argument is omitted, it defaults to 0."

Any suggestions on how to phrase it? Thanks! :)

Contributor:

> If the arguments are omitted, e.g. slice_scatter(x, value), then start is None and is converted to 0 internally, while stop is None and is converted to x.shape[axis].

My suggestion: state clearly in the start/stop parameter descriptions what happens when they are None, and document the default value as None. Keep the parameters in the def signature as they are. The goal is to minimize reader misunderstanding.
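The default resolution being discussed can be sketched in plain NumPy (assumed semantics from the thread: start=None resolves to 0, stop=None resolves to x.shape[axis]):

```python
import numpy as np

# Minimal NumPy stand-in for slice_scatter(x, value, axis=0, step=3)
# with start/stop omitted; the resolved defaults mirror Python's range().
x = np.zeros((6, 8))
value = np.ones((2, 8))
axis, step = 0, 3
start, stop = 0, x.shape[axis]   # what None resolves to internally
x[start:stop:step] = value       # rows 0 and 3 are overwritten
```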

"""
Embeds the value tensor into x at the given axis. Returns a new tensor instead of a view.

Args:
x (Tensor) : The input Tensor. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex128`.
value (Tensor) : The tensor to embed into x. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex128`.
axis (int, optional) : the dimension along which to embed the value. Default is 0.
start (int, optional) : the start index of the slice. If None, it is treated as 0. Default is None.
stop (int, optional) : the stop index of the slice. If None, it is treated as x.shape[axis]. Default is None.
step (int, optional) : the step between each slice index. Default is 1.
name (str, optional): Name for the operation (optional, default is None).

Returns:
Tensor, with the same dtype and shape as x.

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.zeros((8, 8))
>>> value = paddle.ones((8, 2))
>>> res = paddle.slice_scatter(x, value, axis=1, start=2, stop=6, step=2)
>>> print(res)
Tensor(shape=[8, 8], dtype=float32, place=Place(cpu), stop_gradient=True,
[[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 1., 0., 0., 0.]])
Contributor:

The example code needs a few more cases to help users understand, and should not be too similar to competitors'.
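One way to vary the docstring examples, as requested, is a higher-rank case. Here is a NumPy sketch of a 3-D input scattered along axis 1 (semantics mirrored from the diff, not executed against Paddle itself):

```python
import numpy as np

# NumPy sketch of paddle.slice_scatter(x, value, axis=1, start=2, stop=6, step=2)
# on a 3-D input; planes 2 and 4 along axis 1 receive the value planes.
x = np.zeros((2, 6, 3), dtype='float32')
value = np.ones((2, 2, 3), dtype='float32')
res = x.copy()
res[:, 2:6:2, :] = value
```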


"""
if x.ndim != value.ndim:
raise ValueError(
f"The input x and value should have the same number of dimensions, but got x of {x.ndim} dimensions and value of {value.ndim}."
)

x_shape = x.shape
value_shape = value.shape

start = 0 if start is None else start
stop = x_shape[axis] if stop is None else stop

index = list(range(start, stop, step))
exp_shape = [*x_shape[:axis], len(index), *x_shape[axis + 1 :]]
if tuple(exp_shape) != tuple(value_shape):
raise ValueError(
"The value.shape should match [*x_shape[:axis], len(index), *x_shape[axis+1:]], "
f"but got value.shape {value.shape} and expected slice shape {exp_shape}."
)
Contributor (@jeff41404, Dec 20, 2023):

We should also support value shapes that can broadcast to exp_shape, not only exact equality. This check is completed in set_value, so there is no need to add it here; it actually limits the functionality.

Contributor (author):

OK! I'll remove the check and add broadcast test cases.
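The broadcasting the reviewer describes can be illustrated in NumPy (shapes here are illustrative; the actual acceptance is delegated to set_value once the strict equality check is dropped):

```python
import numpy as np

# A value of shape (8, 1) broadcasts to the slice shape exp_shape (8, 2):
# the single column is replicated into columns 2 and 4 (step 2).
x = np.zeros((8, 8))
value = np.ones((8, 1))
exp_shape = (8, 2)
x[:, 2:6:2] = np.broadcast_to(value, exp_shape)
```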


starts = [start]
ends = [stop]
steps = [step]
axes = [axis]
none_axes = []
decrease_axes = []
inputs = {'Input': x}
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes,
'none_axes': none_axes,
}

dtype = x.dtype
attrs['dtype'] = dtype

value = value.astype(dtype)
inputs["ValueTensor"] = value
Contributor:

Parameters used only in the static-graph branch, such as inputs / attrs, should be moved into the static-graph branch.


if in_dynamic_or_pir_mode():
return _C_ops.set_value_with_tensor(
x,
value,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
)
else:
helper = LayerHelper('slice_scatter', **locals())
output = helper.create_variable_for_type_inference(dtype=x.dtype)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)

return output
226 changes: 226 additions & 0 deletions test/legacy_test/test_slice_scatter.py
@@ -0,0 +1,226 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.framework import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

RTOL = {'float32': 1e-03, 'float64': 1e-05}
ATOL = {'float32': 1e-03, 'float64': 1e-05}
Contributor:

Aren't these two tolerances rather loose for the data types themselves? Could the tests use the default tolerances?

Contributor (author):

Right, there's no need to set these separately; removed.



def numpy_ref(x, value, axis=0, start=None, stop=None, step=1):
_x = np.copy(x)

start = 0 if start is None else start
stop = _x.shape[axis] if stop is None else stop

index = range(start, stop, step)
exp_shape = [
*([1] * _x.ndim)[:axis],
len(index),
*([1] * _x.ndim)[axis + 1 :],
]

np.put_along_axis(
_x, np.arange(start, stop, step).reshape(exp_shape), value, axis=axis
)

return _x
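As a sanity sketch (not part of the PR), the put_along_axis call used by numpy_ref reproduces plain strided slice assignment, using the same shapes as TestSliceScatterApi below:

```python
import numpy as np

# put_along_axis with a (1, 2) index grid broadcasts along axis 1,
# matching direct slice assignment on columns 2 and 4.
x = np.arange(48, dtype='float64').reshape(8, 6)
value = -np.ones((8, 2))
a = x.copy()
np.put_along_axis(a, np.arange(2, 6, 2).reshape(1, 2), value, axis=1)
b = x.copy()
b[:, 2:6:2] = value
assert np.array_equal(a, b)
```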


class TestSliceScatterApi(unittest.TestCase):
def setUp(self):
np.random.seed(2023)

self.init_dtype()
self.init_shape()

self.x_np = np.random.random(self.x_shape).astype(self.dtype)
self.value_np = np.random.random(self.value_shape).astype(self.dtype)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def init_dtype(self):
self.dtype = 'float64'

def init_shape(self):
self.x_shape = [8, 6]
self.value_shape = [8, 2]
self.axis = 1
self.start = 2
self.stop = 6
self.step = 2

@test_with_pir_api
def test_api_static(self):
paddle.enable_static()

for place in self.place:
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('x', self.x_shape, self.dtype)
value = paddle.static.data(
'value', self.value_shape, self.dtype
)

out = paddle.slice_scatter(
x,
value,
axis=self.axis,
start=self.start,
stop=self.stop,
step=self.step,
)
exe = paddle.static.Executor(place)
res = exe.run(
feed={
'x': self.x_np,
'value': self.value_np,
},
fetch_list=[out],
)[0]

out_ref = numpy_ref(
self.x_np,
self.value_np,
axis=self.axis,
start=self.start,
stop=self.stop,
step=self.step,
)

np.testing.assert_allclose(
res, out_ref, rtol=RTOL[self.dtype], atol=ATOL[self.dtype]
)

def test_api_dygraph(self):
for place in self.place:
paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np)
value_tensor = paddle.to_tensor(self.value_np)
out = paddle.slice_scatter(
x_tensor,
value_tensor,
axis=self.axis,
start=self.start,
stop=self.stop,
step=self.step,
)
out_ref = numpy_ref(
self.x_np,
self.value_np,
axis=self.axis,
start=self.start,
stop=self.stop,
step=self.step,
)

np.testing.assert_allclose(
out.numpy(),
out_ref,
rtol=RTOL[self.dtype],
atol=ATOL[self.dtype],
)

paddle.enable_static()


class TestSliceScatterApiFloat32(TestSliceScatterApi):
def init_dtype(self):
self.dtype = 'float32'
Contributor:

Please also add tests for the other data types; atol/rtol shouldn't need special settings. Keep an eye on whether the fp16 and bf16 unit tests pass.

Contributor (author):

Test cases added.

Also, since the legacy IR does not support some data types, those are listed separately in the test cases.



class TestSliceScatterApiNoneStartStop(TestSliceScatterApi):
def init_shape(self):
self.x_shape = [6, 8]
self.value_shape = [2, 8]
self.axis = 0
self.start = None
self.stop = None
self.step = 3


class TestSliceScatterApi3D(TestSliceScatterApi):
def init_shape(self):
self.x_shape = [8, 6, 3]
self.value_shape = [8, 2, 3]
self.axis = 1
self.start = 2
self.stop = 6
self.step = 2


class TestSliceScatterApi3DFloat32(TestSliceScatterApi3D):
def init_dtype(self):
self.dtype = 'float32'


class TestSliceScatterApi4D(TestSliceScatterApi):
def init_shape(self):
self.x_shape = [8, 6, 3, 5]
self.value_shape = [8, 2, 3, 5]
self.axis = 1
self.start = 2
self.stop = 6
self.step = 2


class TestSliceScatterApi4DFloat32(TestSliceScatterApi4D):
def init_dtype(self):
self.dtype = 'float32'


class TestSliceScatterApi4DAxis3(TestSliceScatterApi):
def init_shape(self):
self.x_shape = [8, 6, 3, 9]
self.value_shape = [8, 6, 3, 2]
self.axis = 3
self.start = 2
self.stop = 6
self.step = 2


class TestSliceScatterApi4DAxis3Float32(TestSliceScatterApi4DAxis3):
def init_dtype(self):
self.dtype = 'float32'


class TestSliceScatterApiError(unittest.TestCase):
def test_error_ndim(self):
with self.assertRaises(ValueError):
x = np.random.rand(8, 6, 3)
value = np.random.rand(8, 3)
_ = paddle.slice_scatter(x, value)

def test_error_index(self):
with self.assertRaises(ValueError):
x = np.random.rand(8, 6)
value = np.random.rand(8, 3)
_ = paddle.slice_scatter(x, value, axis=1, step=1)

with self.assertRaises(ValueError):
x = np.random.rand(8, 6)
value = np.random.rand(2, 6)
_ = paddle.slice_scatter(x, value)


if __name__ == '__main__':
unittest.main()