-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.28】为 Paddle 新增 slice_scatter API -part #59973
Changes from 7 commits
e85068e
009fd46
9a1ad5e
53dc749
f729a4e
b5371ac
5d82cbd
1fcf498
e89191a
7c3de98
20e1600
e5fecec
8e03291
6b768da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ | |
check_variable_and_dtype, | ||
convert_dtype, | ||
) | ||
from ..base.framework import Variable | ||
from ..base.framework import Variable, default_main_program | ||
from ..framework import ( | ||
LayerHelper, | ||
_current_expected_place, | ||
|
@@ -6749,3 +6749,106 @@ def select_scatter(x, values, axis, index, name=None): | |
) | ||
|
||
return output | ||
|
||
|
||
def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None): | ||
""" | ||
Embeds the value tensor into x at the given axis. Returns a new tensor instead of a view. | ||
|
||
Args: | ||
x (Tensor) : The input Tensor. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex128`. | ||
value (Tensor) : The tensor to embed into x. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex128`. | ||
axis (int) : the dimension to insert the value. Default is 0. | ||
start (int, optional) : the start index of where to insert. Default is 0. | ||
stop (int, optional) : the stop index of where to insert. Default is x.shape[axis]. | ||
step (int, optional) : the step for each insert. Default is 1. | ||
name (str, optional): Name for the operation (optional, default is None). | ||
|
||
Returns: | ||
Tensor, same dtype and shape with x | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> x = paddle.zeros((8, 8)) | ||
>>> value = paddle.ones((8, 2)) | ||
>>> res = paddle.slice_scatter(x, value, axis=1, start=2, stop=6, step=2) | ||
>>> print(res) | ||
Tensor(shape=[8, 8], dtype=float32, place=Place(cpu), stop_gradient=True, | ||
[[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the example code needs to add some cases to facilitate user understanding and not be too similar to competitors |
||
|
||
""" | ||
if x.ndim != value.ndim: | ||
raise ValueError( | ||
f"The input x and value should have save dimension, but got input of {x.ndim} and value of {value.ndim}." | ||
) | ||
|
||
x_shape = x.shape | ||
value_shape = value.shape | ||
|
||
start = 0 if start is None else start | ||
stop = x_shape[axis] if stop is None else stop | ||
|
||
index = list(range(start, stop, step)) | ||
exp_shape = [*x_shape[:axis], len(index), *x_shape[axis + 1 :]] | ||
if tuple(exp_shape) != tuple(value_shape): | ||
raise ValueError( | ||
"The value.shape should be same of [*x_shape[:axis], len(index), *x_shape[axis+1:]]," | ||
f"but got value.shape of {value.shape} and slice shape {exp_shape}." | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we also should support the shape of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK! remove the check and add |
||
|
||
starts = [start] | ||
ends = [stop] | ||
steps = [step] | ||
axes = [axis] | ||
none_axes = [] | ||
decrease_axes = [] | ||
inputs = {'Input': x} | ||
attrs = { | ||
'axes': axes, | ||
'starts': starts, | ||
'ends': ends, | ||
'steps': steps, | ||
'decrease_axes': decrease_axes, | ||
'none_axes': none_axes, | ||
} | ||
|
||
dtype = x.dtype | ||
attrs['dtype'] = dtype | ||
|
||
value = value.astype(dtype) | ||
inputs["ValueTensor"] = value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 仅静态图分支用到的参数,比如inputs / attrs等,挪到静态图分支下吧 |
||
|
||
if in_dynamic_or_pir_mode(): | ||
return _C_ops.set_value_with_tensor( | ||
x, | ||
value, | ||
starts, | ||
ends, | ||
steps, | ||
axes, | ||
decrease_axes, | ||
none_axes, | ||
) | ||
else: | ||
helper = LayerHelper('slice_scatter', **locals()) | ||
output = helper.create_variable_for_type_inference(dtype=x.dtype) | ||
cur_block = default_main_program().current_block() | ||
cur_block.append_op( | ||
type="set_value", | ||
inputs=inputs, | ||
outputs={'Out': output}, | ||
attrs=attrs, | ||
inplace_map={"Input": "Out"}, | ||
) | ||
|
||
return output |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
|
||
import paddle | ||
from paddle.framework import core | ||
from paddle.pir_utils import test_with_pir_api | ||
|
||
paddle.enable_static() | ||
|
||
RTOL = {'float32': 1e-03, 'float64': 1e-05} | ||
ATOL = {'float32': 1e-03, 'float64': 1e-05} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里设置的两个阈值相对数据类型本身是否宽容度高了些,能否用默认参数测试 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 嗯 这里没必要单独设置,已删掉 ~ |
||
|
||
|
||
def numpy_ref(x, value, axis=0, start=None, stop=None, step=1): | ||
_x = np.copy(x) | ||
|
||
start = 0 if start is None else start | ||
stop = _x.shape[axis] if stop is None else stop | ||
|
||
index = range(start, stop, step) | ||
exp_shape = [ | ||
*([1] * _x.ndim)[:axis], | ||
len(index), | ||
*([1] * _x.ndim)[axis + 1 :], | ||
] | ||
|
||
np.put_along_axis( | ||
_x, np.arange(start, stop, step).reshape(exp_shape), value, axis=axis | ||
) | ||
|
||
return _x | ||
|
||
|
||
class TestSliceScatterApi(unittest.TestCase): | ||
def setUp(self): | ||
np.random.seed(2023) | ||
|
||
self.init_dtype() | ||
self.init_shape() | ||
|
||
self.x_np = np.random.random(self.x_shape).astype(self.dtype) | ||
self.value_np = np.random.random(self.value_shape).astype(self.dtype) | ||
self.place = [paddle.CPUPlace()] | ||
if core.is_compiled_with_cuda(): | ||
self.place.append(paddle.CUDAPlace(0)) | ||
|
||
def init_dtype(self): | ||
self.dtype = 'float64' | ||
|
||
def init_shape(self): | ||
self.x_shape = [8, 6] | ||
self.value_shape = [8, 2] | ||
self.axis = 1 | ||
self.start = 2 | ||
self.stop = 6 | ||
self.step = 2 | ||
|
||
@test_with_pir_api | ||
def test_api_static(self): | ||
paddle.enable_static() | ||
|
||
for place in self.place: | ||
with paddle.static.program_guard(paddle.static.Program()): | ||
x = paddle.static.data('x', self.x_shape, self.dtype) | ||
value = paddle.static.data( | ||
'value', self.value_shape, self.dtype | ||
) | ||
|
||
out = paddle.slice_scatter( | ||
x, | ||
value, | ||
axis=self.axis, | ||
start=self.start, | ||
stop=self.stop, | ||
step=self.step, | ||
) | ||
exe = paddle.static.Executor(place) | ||
res = exe.run( | ||
feed={ | ||
'x': self.x_np, | ||
'value': self.value_np, | ||
}, | ||
fetch_list=[out], | ||
)[0] | ||
|
||
out_ref = numpy_ref( | ||
self.x_np, | ||
self.value_np, | ||
axis=self.axis, | ||
start=self.start, | ||
stop=self.stop, | ||
step=self.step, | ||
) | ||
|
||
np.testing.assert_allclose( | ||
res, out_ref, rtol=RTOL[self.dtype], atol=ATOL[self.dtype] | ||
) | ||
|
||
def test_api_dygraph(self): | ||
for place in self.place: | ||
paddle.disable_static(place) | ||
x_tensor = paddle.to_tensor(self.x_np) | ||
value_tensor = paddle.to_tensor(self.value_np) | ||
out = paddle.slice_scatter( | ||
x_tensor, | ||
value_tensor, | ||
axis=self.axis, | ||
start=self.start, | ||
stop=self.stop, | ||
step=self.step, | ||
) | ||
out_ref = numpy_ref( | ||
self.x_np, | ||
self.value_np, | ||
axis=self.axis, | ||
start=self.start, | ||
stop=self.stop, | ||
step=self.step, | ||
) | ||
|
||
np.testing.assert_allclose( | ||
out.numpy(), | ||
out_ref, | ||
rtol=RTOL[self.dtype], | ||
atol=ATOL[self.dtype], | ||
) | ||
|
||
paddle.enable_static() | ||
|
||
|
||
class TestSliceScatterApiFloat32(TestSliceScatterApi): | ||
def init_dtype(self): | ||
self.dtype = 'float32' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 也补充下其他数据类型的测试吧,应该不用特别设置atol /rtol? 可以关注下fp16和bf16的单测通过情况 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已增加测试用例 ~ 另外,由于旧 ir 有些数据类型不支持,所以在测试用例里面单列出来了 ~ |
||
|
||
|
||
class TestSliceScatterApiNoneStartStop(TestSliceScatterApi): | ||
def init_shape(self): | ||
self.x_shape = [6, 8] | ||
self.value_shape = [2, 8] | ||
self.axis = 0 | ||
self.start = None | ||
self.stop = None | ||
self.step = 3 | ||
|
||
|
||
class TestSliceScatterApi3D(TestSliceScatterApi): | ||
def init_shape(self): | ||
self.x_shape = [8, 6, 3] | ||
self.value_shape = [8, 2, 3] | ||
self.axis = 1 | ||
self.start = 2 | ||
self.stop = 6 | ||
self.step = 2 | ||
|
||
|
||
class TestSliceScatterApi3DFloat32(TestSliceScatterApi3D): | ||
def init_dtype(self): | ||
self.dtype = 'float32' | ||
|
||
|
||
class TestSliceScatterApi4D(TestSliceScatterApi): | ||
def init_shape(self): | ||
self.x_shape = [8, 6, 3, 5] | ||
self.value_shape = [8, 2, 3, 5] | ||
self.axis = 1 | ||
self.start = 2 | ||
self.stop = 6 | ||
self.step = 2 | ||
|
||
|
||
class TestSliceScatterApi4DFloat32(TestSliceScatterApi4D): | ||
def init_dtype(self): | ||
self.dtype = 'float32' | ||
|
||
|
||
class TestSliceScatterApi4DAxis3(TestSliceScatterApi): | ||
def init_shape(self): | ||
self.x_shape = [8, 6, 3, 9] | ||
self.value_shape = [8, 6, 3, 2] | ||
self.axis = 3 | ||
self.start = 2 | ||
self.stop = 6 | ||
self.step = 2 | ||
|
||
|
||
class TestSliceScatterApi4DAxis3Float32(TestSliceScatterApi4DAxis3): | ||
def init_dtype(self): | ||
self.dtype = 'float32' | ||
|
||
|
||
class TestSliceScatterApiError(unittest.TestCase): | ||
def test_error_ndim(self): | ||
with self.assertRaises(ValueError): | ||
x = np.random.rand(8, 6, 3) | ||
value = np.random.rand(8, 3) | ||
_ = paddle.slice_scatter(x, value) | ||
|
||
def test_error_index(self): | ||
with self.assertRaises(ValueError): | ||
x = np.random.rand(8, 6) | ||
value = np.random.rand(8, 3) | ||
_ = paddle.slice_scatter(x, value, axis=1, step=1) | ||
|
||
with self.assertRaises(ValueError): | ||
x = np.random.rand(8, 6) | ||
value = np.random.rand(2, 6) | ||
_ = paddle.slice_scatter(x, value) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的默认值好像和下面的参数描述没对上?(
start
和stop
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
赞!🤙🤙🤙
确实不一样,当时也很纠结要怎么写 ~
因为这里如果不写参数的话,如:slice_scatter(x, value) ,那么 start 为 None,程序里面会转换为 0,stop 为 None,会转换为 x.shape[axis] ~
参考 python 中 range 的写法 :
https://docs.python.org/3/library/stdtypes.html#range
有没有什么建议的写法?谢谢!:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
个人建议,可以直接在
start
和stop
的参数描述里说清楚,当为 None 时,会出现的情况,且默认值为 None。def ...
这边还是保留参数。 宗旨就是减少读者误解