-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.28】为 Paddle 新增 slice_scatter API -part #59973
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@zoooo0820 CI 应该没啥问题了 请评审 ~ 谢谢!:) |
paddle.enable_static() | ||
|
||
RTOL = {'float32': 1e-03, 'float64': 1e-05} | ||
ATOL = {'float32': 1e-03, 'float64': 1e-05} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里设置的两个阈值相对数据类型本身是否宽容度高了些,能否用默认参数测试
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯 这里没必要单独设置,已删掉 ~
python/paddle/tensor/manipulation.py
Outdated
attrs['dtype'] = dtype | ||
|
||
value = value.astype(dtype) | ||
inputs["ValueTensor"] = value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
仅静态图分支用到的参数,比如inputs / attrs等,挪到静态图分支下吧
|
||
class TestSliceScatterApiFloat32(TestSliceScatterApi): | ||
def init_dtype(self): | ||
self.dtype = 'float32' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
也补充下其他数据类型的测试吧,应该不用特别设置atol /rtol? 可以关注下fp16和bf16的单测通过情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已增加测试用例 ~
另外,由于旧 ir 有些数据类型不支持,所以在测试用例里面单列出来了 ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
python/paddle/tensor/manipulation.py
Outdated
@@ -6749,3 +6749,108 @@ def select_scatter(x, values, axis, index, name=None): | |||
) | |||
|
|||
return output | |||
|
|||
|
|||
def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的默认值好像和下面的参数描述没对上?(start
和 stop
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
赞!🤙🤙🤙
确实不一样,当时也很纠结要怎么写 ~
因为这里如果不写参数的话,如:slice_scatter(x, value) ,那么 start 为 None,程序里面会转换为 0,stop 为 None,会转换为 x.shape[axis] ~
参考 python 中 range 的写法 :
https://docs.python.org/3/library/stdtypes.html#range
If the start argument is omitted, it defaults to 0.
有没有什么建议的写法?谢谢!:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为这里如果不写参数的话,如:slice_scatter(x, value) ,那么 start 为 None,程序里面会转换为 0,stop 为 None,会转换为 x.shape[axis] ~
个人建议,可以直接在start
和 stop
的参数描述里说清楚,当为 None 时,会出现的情况,且默认值为 None。
def ...
这边还是保留参数。 宗旨就是减少读者误解
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
可以提供中文文档 |
python/paddle/tensor/manipulation.py
Outdated
axis (int) : the dimension to insert the value. Default is 0. | ||
start (int, optional) : the start index of where to insert. Default is `None` which will be converted to `0`. | ||
stop (int, optional) : the stop index of where to insert. Default is `None` which will be converted to `x.shape[axis]`. | ||
step (int, optional) : the step for each insert. Default is 1. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because operator of set_value
can support writing data according to multiple axes
, starts
, ends
and steps
, it is better that the parameters of axis
, start
, stop
, step
here can support int
and list of int
? making this API more user-friendly and powerful
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If modified according to the above suggestions, the parameter naming suggestion are axes
, starts
, ends
and strides
, which is consistent with API of paddle.slice
and paddle.strided_slice
. and rfc should also be modified.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No ~~~
list of int
of axis
, start
, stop
and step
for slice_scatter
NOT make sense ~
For instance:
x
is a tensor of shape [8, 8]
, value
is [8, 2]
, then the axis
can ONLY be 1
. If axis=0
, the values of value
in axis 1 of length 2 can not be fitted into x
of length 8. In other words, list of int
can only be useful of x
and value
the same shape, which is not that useful ~~~ BTW, torch not support list of int
either ~~~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From a functional perspective, the functionality of this API should be dual to paddle.strided_slice
. Because the former is for writing data, while the latter is for reading data, so supporting list of int
is more natural, and the underlying operator of set_value
also supports it. As for the example you gave, if the shape of value
cannot be broadcasted to the shape after the slice, an error should be reported (by set_value
). Additionally, not all of design of torch are reasonable, but how to make it better for users is what we pursue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From a functional perspective, the functionality of this API should be dual to
paddle.strided_slice
. Because the former is for writing data, while the latter is for reading data, so supportinglist of int
is more natural, and the underlying operator ofset_value
also supports it. As for the example you gave, if the shape ofvalue
cannot be broadcasted to the shape after the slice, an error should be reported (byset_value
).
Not agree ~
Especially treat slice_scatter
as a dual to paddle.strided_slice
.
There should be a powerful API set_value
or set_value_with_tensor
, not slice_scatter
~
We should not bother users to map slice_scatter
from torch, mindspore or other framework to ours, and not mess the API up with more options just because we could support it ~
Additionally, not all of design of torch are reasonable, but how to make it better for users is what we pursue
Agree and always !!!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My opinion is that regardless of whether API name is set_value
or slice_scatter
, our goal is to have an API that is functionally dual to paddle.strided_slice
. Here, this PR is just a subset of this feature, and we don't need to create two APIs with the same parameters. The only difference is that one supports int
and the other supports list of int
python/paddle/tensor/manipulation.py
Outdated
exp_shape = [*x_shape[:axis], len(index), *x_shape[axis + 1 :]] | ||
if tuple(exp_shape) != tuple(value_shape): | ||
raise ValueError( | ||
"The value.shape should be same of [*x_shape[:axis], len(index), *x_shape[axis+1:]]," | ||
f"but got value.shape of {value.shape} and slice shape {exp_shape}." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we also should support the shape of value
can broadcast to the shape of exp_shape
, not only exactly equal, and this check will be completed in set_value
, so there is no need to add this check here, which actually limits the functionality.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK! remove the check and add broadcast
test cases ~
python/paddle/tensor/manipulation.py
Outdated
>>> x = paddle.zeros((8, 8)) | ||
>>> value = paddle.ones((8, 2)) | ||
>>> res = paddle.slice_scatter(x, value, axis=1, start=2, stop=6, step=2) | ||
>>> print(res) | ||
Tensor(shape=[8, 8], dtype=float32, place=Place(cpu), stop_gradient=True, | ||
[[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.], | ||
[0., 0., 1., 0., 1., 0., 0., 0.]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the example code needs to add some cases to facilitate user understanding and not be too similar to competitors
Update 20231220
关于
如何? |
关于这个 API 是否支持 jeff 大佬认为
不用我单独翻译成英文再写一遍吧 ~ 😜😜😜 |
Update 20231224
v2.0 RFC PaddlePaddle/community#790 @jeff41404 @zoooo0820 请评审 ~ |
The main issue you mentioned here is a naming issue. This DOES exist, the root cause is not that the name of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* [Add] hack5 28 api * [Change] use set_value op * [Change] values to value * [Fix] resolve conflict * [Update] add test cases * [Fix] code example * [Add] dtype test cases * [Change] start/stop docstring * [Change] fix start/stop docstring * [Change] broadcast value to exp_shape * [Change] axes with list of int * [Add] as tensor test case * [Change] code style
PR types
New features
PR changes
APIs
Description
【Hackathon 5th No.28】为 Paddle 新增 slice_scatter API
RFC:
v1 : PaddlePaddle/community#784
v2 : PaddlePaddle/community#790
p.s. 待 RFC 评审之后完善代码与测试 ~