【Hackathon 5th No.28】Add slice_scatter API to Paddle -part #59973

Merged
merged 14 commits into from
Dec 26, 2023

@megemini megemini commented Dec 13, 2023

PR types

New features

PR changes

APIs

Description

【Hackathon 5th No.28】Add slice_scatter API to Paddle

RFC:

v1 : PaddlePaddle/community#784
v2 : PaddlePaddle/community#790

P.S. The code and tests will be completed after the RFC has been reviewed ~

paddle-bot bot commented Dec 13, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@megemini
Contributor Author

@zoooo0820 CI should be fine now. Please review ~ Thanks! :)

paddle.enable_static()

RTOL = {'float32': 1e-03, 'float64': 1e-05}
ATOL = {'float32': 1e-03, 'float64': 1e-05}
Contributor

Aren't these two thresholds rather loose for the data types themselves? Can the tests use the default tolerances instead?

Contributor Author

Right, there is no need to set these separately here. Removed ~

attrs['dtype'] = dtype

value = value.astype(dtype)
inputs["ValueTensor"] = value
Contributor

Parameters that are only used in the static-graph branch, such as inputs / attrs, should be moved under the static-graph branch.


class TestSliceScatterApiFloat32(TestSliceScatterApi):
    def init_dtype(self):
        self.dtype = 'float32'
Contributor

Please also add tests for the other data types. There should be no need to set atol / rtol specially, right? Keep an eye on whether the fp16 and bf16 unit tests pass.

Contributor Author

Test cases added ~

Also, because the old IR does not support some data types, those are listed separately in the test cases ~

@zoooo0820 zoooo0820 left a comment

LGTM

@@ -6749,3 +6749,108 @@ def select_scatter(x, values, axis, index, name=None):
    )

    return output


def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):
Contributor

The default values here don't seem to match the parameter descriptions below? (start, stop)

Contributor Author

Nice! 🤙🤙🤙

They are indeed different. I was torn about how to write this at the time ~

If the arguments are omitted here, e.g. slice_scatter(x, value), then start is None and is converted to 0 internally, and stop is None and is converted to x.shape[axis] ~

This follows the way Python documents range:

https://docs.python.org/3/library/stdtypes.html#range

If the start argument is omitted, it defaults to 0.

Is there a recommended way to write this? Thanks! :)
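
A minimal sketch of the defaulting rule described in this comment, in plain Python. `normalize_bounds` is a hypothetical helper used only for illustration; it is not Paddle's implementation, and `dim_size` stands in for x.shape[axis].

```python
def normalize_bounds(dim_size, start=None, stop=None, step=1):
    """Resolve omitted (None) bounds as described in the comment above."""
    if start is None:
        start = 0          # omitted start -> beginning of the axis
    if stop is None:
        stop = dim_size    # omitted stop -> x.shape[axis]
    return start, stop, step


# Python's built-in slice resolves None bounds the same way for a positive step.
assert normalize_bounds(8) == slice(None, None, 1).indices(8)

print(normalize_bounds(8))           # (0, 8, 1)
print(normalize_bounds(8, start=2))  # (2, 8, 1)
```

Keeping `None` in the signature and documenting the conversion, as proposed here, lets a caller pass only the bound they care about.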

Contributor

If the arguments are omitted here, e.g. slice_scatter(x, value), then start is None and is converted to 0 internally, and stop is None and is converted to x.shape[axis] ~

My personal suggestion: state clearly in the descriptions of the start and stop parameters what happens when they are None, and that the default value is None.

Keep the parameters as they are in the def ... signature. The goal is to minimize misunderstanding for readers.

sunzhongkai588 previously approved these changes Dec 20, 2023

@sunzhongkai588 sunzhongkai588 left a comment

LGTM

@luotao1
Contributor

luotao1 commented Dec 20, 2023

You can go ahead and provide the Chinese documentation.

Comment on lines 6761 to 6764
axis (int) : the dimension to insert the value. Default is 0.
start (int, optional) : the start index of where to insert. Default is `None` which will be converted to `0`.
stop (int, optional) : the stop index of where to insert. Default is `None` which will be converted to `x.shape[axis]`.
step (int, optional) : the step for each insert. Default is 1.
Contributor

Because the set_value operator supports writing data along multiple axes, starts, ends and steps, would it be better for the axis, start, stop and step parameters here to accept both int and list of int? That would make this API more user-friendly and powerful.

Contributor

@jeff41404 jeff41404 Dec 20, 2023

If this is modified as suggested above, the suggested parameter names are axes, starts, ends and strides, consistent with the paddle.slice and paddle.strided_slice APIs. The RFC should also be updated.

Contributor Author

No ~~~

A list of int for axis, start, stop and step does NOT make sense for slice_scatter ~

For instance:

x is a tensor of shape [8, 8] and value has shape [8, 2], so axis can ONLY be 1. If axis=0, axis 1 of value (length 2) cannot fit into axis 1 of x (length 8). In other words, a list of int is only useful when x and value have the same shape, which is not that useful ~~~ BTW, torch does not support list of int either ~~~
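
The shape argument in this comment can be checked with a small sketch. `slice_shape` is a hypothetical helper that mirrors the `exp_shape` expression quoted later in this review; the start/stop/step values are an assumption borrowed from the docstring example, since the comment itself does not state them.

```python
def slice_shape(x_shape, axis, start, stop, step):
    """Shape of the region of x selected by a strided slice along one axis."""
    length = len(range(start, stop, step))   # number of selected positions
    return [*x_shape[:axis], length, *x_shape[axis + 1:]]


x_shape, value_shape = [8, 8], [8, 2]

print(slice_shape(x_shape, axis=1, start=2, stop=6, step=2))  # [8, 2]
print(slice_shape(x_shape, axis=0, start=2, stop=6, step=2))  # [2, 8]
```

With axis=1 the selected region has the same shape as value; with axis=0 it is [2, 8], which a value of shape [8, 2] does not match.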

Contributor

@jeff41404 jeff41404 Dec 21, 2023

From a functional perspective, this API should be the dual of paddle.strided_slice: the former writes data while the latter reads data, so supporting list of int is more natural, and the underlying set_value operator also supports it. As for the example you gave, if the shape of value cannot be broadcast to the shape of the slice, an error should be reported (by set_value). Additionally, not every design decision in torch is reasonable; what we pursue is making things better for users.

Contributor Author

From a functional perspective, this API should be the dual of paddle.strided_slice: the former writes data while the latter reads data, so supporting list of int is more natural, and the underlying set_value operator also supports it. As for the example you gave, if the shape of value cannot be broadcast to the shape of the slice, an error should be reported (by set_value).

I disagree ~

Especially with treating slice_scatter as the dual of paddle.strided_slice.

The powerful API should be set_value or set_value_with_tensor, not slice_scatter ~

We should not make users work to map slice_scatter from torch, mindspore or other frameworks onto ours, nor clutter the API with more options just because we can support them ~

Additionally, not every design decision in torch is reasonable; what we pursue is making things better for users.

Agreed, always !!!

Contributor

My opinion is that, regardless of whether the API is named set_value or slice_scatter, our goal is an API that is functionally dual to paddle.strided_slice. This PR currently covers only a subset of that feature, and we don't need to create two APIs with the same parameters whose only difference is that one accepts int and the other accepts list of int.
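
The read/write duality discussed in this thread can be illustrated in one dimension with plain Python lists. This is only an analogy under the assumption that a strided slice reads a region and a slice-scatter writes the same region; it does not use Paddle.

```python
x = list(range(8))                 # [0, 1, 2, 3, 4, 5, 6, 7]
start, stop, step = 2, 6, 2

# Read side: what a strided slice returns.
picked = x[start:stop:step]
print(picked)                      # [2, 4]

# Write side: what a slice-scatter does, on a copy so x stays unchanged.
out = x[:]
out[start:stop:step] = [-1, -1]
print(out)                         # [0, 1, -1, 3, -1, 5, 6, 7]

# Reading the written region back returns the scattered value.
assert out[start:stop:step] == [-1, -1]
```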

Comment on lines 6802 to 6807
exp_shape = [*x_shape[:axis], len(index), *x_shape[axis + 1 :]]
if tuple(exp_shape) != tuple(value_shape):
    raise ValueError(
        "The value.shape should be same of [*x_shape[:axis], len(index), *x_shape[axis+1:]],"
        f"but got value.shape of {value.shape} and slice shape {exp_shape}."
    )
Contributor

@jeff41404 jeff41404 Dec 20, 2023

We should also support a value whose shape can be broadcast to exp_shape, not only one that is exactly equal. That check is already performed in set_value, so there is no need to add it here; it actually limits the functionality.

Contributor Author

OK! I'll remove the check and add broadcast test cases ~
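
For reference, a sketch of what "value can be broadcast to exp_shape" means under the usual NumPy-style rules. `can_broadcast_to` is a hypothetical helper; the assumption that set_value follows exactly these rules is not confirmed by this thread.

```python
def can_broadcast_to(value_shape, target_shape):
    """Return True if value_shape broadcasts to target_shape (NumPy-style rules)."""
    if len(value_shape) > len(target_shape):
        return False
    # Align trailing dimensions; each value dim must be 1 or equal the target dim.
    for v, t in zip(reversed(value_shape), reversed(target_shape)):
        if v != 1 and v != t:
            return False
    return True


exp_shape = [8, 2]
print(can_broadcast_to([8, 2], exp_shape))  # True  (exact match)
print(can_broadcast_to([8, 1], exp_shape))  # True  (size-1 dim is stretched)
print(can_broadcast_to([2], exp_shape))     # True  (missing leading dim)
print(can_broadcast_to([8], exp_shape))     # False (8 does not match trailing 2)
```

The removed equality check would have accepted only the first of these cases.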

Comment on lines 6775 to 6787
>>> x = paddle.zeros((8, 8))
>>> value = paddle.ones((8, 2))
>>> res = paddle.slice_scatter(x, value, axis=1, start=2, stop=6, step=2)
>>> print(res)
Tensor(shape=[8, 8], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0.]])
Contributor

The example code should add a few more cases to help users understand, and should not be too similar to competitors' examples.
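
To make the semantics of the docstring example concrete, here is a pure-Python emulation for 2-D nested lists. `slice_scatter_2d` is a hypothetical reference function written for illustration; it is not the Paddle implementation and does not handle broadcasting.

```python
def slice_scatter_2d(x, value, axis, start, stop, step):
    """Return a copy of 2-D list x with value written into the strided slice."""
    out = [row[:] for row in x]            # leave the input untouched
    picked = range(start, stop, step)      # positions selected along `axis`
    for i, row in enumerate(out):
        for j in range(len(row)):
            if axis == 0 and i in picked:
                row[j] = value[picked.index(i)][j]
            elif axis == 1 and j in picked:
                row[j] = value[i][picked.index(j)]
    return out


x = [[0.0] * 8 for _ in range(8)]
value = [[1.0] * 2 for _ in range(8)]
res = slice_scatter_2d(x, value, axis=1, start=2, stop=6, step=2)
print(res[0])  # [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]
```

Every row of `res` matches the corresponding row printed in the docstring example: columns 2 and 4 are overwritten and the rest keep the values from x.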

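@megemini
Contributor Author

Update 20231220

  • Removed the shape check
  • Added broadcast unit tests
  • Added example code

As for supporting list of int for axis and the other parameters, I personally still feel it is unnecessary, because:

  • There are very few use cases, so it is not needed, and it would confuse users ~
  • Neither torch nor mindspore supports it, which arguably makes it an industry convention ~

What do you think?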

@megemini megemini requested a review from jeff41404 December 21, 2023 03:56
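@megemini
Contributor Author

As for whether this API should support list of int, here is my final take ~

jeff considers slice_scatter and strided_slice to be duals, and therefore thinks this must be supported ~ That argument really does not convince me ~

  • First, I have never thought that torch's design is necessarily right. Many of my PRs show where I stand, so I won't repeat it here.

  • Second, whether two APIs are "dual" should, in my view, be judged from the user's perspective rather than decided by us.
    For example, hsplit and hstack take different parameters, but I think they can be called functionally dual, because that is intuitive from the names alone.
    With slice_scatter and strided_slice, I cannot see any duality however I look at them; at least for someone like me who is new to these two APIs, duality would hardly come to mind ... ... 🤣🤣🤣
    If strided_slice must be given a dual name, I think strided_scatter fits better: strided says both APIs are stepped, slice means read, and scatter means write.
    Confusing users and raising exceptions for the sake of duality and a rarely used case is something I really cannot agree with ~~~

  • Finally, if the internal R&D team unanimously supports the list of int usage, changing it here would be all too easy ... ... 😆😆😆

I don't need to translate this into English and write it all again, do I ~ 😜😜😜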

@zoooo0820 @jeff41404

@megemini
Contributor Author

Update 20231224

  • Renamed axis, start, stop, step to axes, starts, ends, strides, consistent with the strided_slice parameters
  • Removed the default parameter values after switching to list of int
  • Modified and added test cases

v2.0 RFC PaddlePaddle/community#790

@jeff41404 @zoooo0820 Please review ~
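
A sketch of how per-axis lists such as axes, starts, ends and strides can describe a multi-axis write: each listed axis gets its own slice, and every other axis is left whole. `build_index` is a hypothetical helper for illustration and is not taken from the PR.

```python
def build_index(ndim, axes, starts, ends, strides):
    """Turn per-axis lists into one slice per dimension of x."""
    if not (len(axes) == len(starts) == len(ends) == len(strides)):
        raise ValueError("axes, starts, ends and strides must have equal length")
    index = [slice(None)] * ndim           # untouched axes select everything
    for axis, start, end, stride in zip(axes, starts, ends, strides):
        index[axis] = slice(start, end, stride)
    return tuple(index)


print(build_index(2, axes=[1], starts=[2], ends=[6], strides=[2]))
# (slice(None, None, None), slice(2, 6, 2))
print(build_index(2, axes=[0, 1], starts=[0, 2], ends=[8, 6], strides=[4, 2]))
# (slice(0, 8, 4), slice(2, 6, 2))
```

With an array library that accepts a tuple of slices as an index (NumPy, for example), `x[index] = value` performs the corresponding write. Under the renamed parameters, the earlier single-axis docstring example would presumably be expressed with axes=[1], starts=[2], ends=[6], strides=[2].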

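@jeff41404
Contributor

As for whether this API should support list of int, here is my final take ~
jeff considers slice_scatter and strided_slice to be duals, and therefore thinks this must be supported ~ That argument really does not convince me ~

  • Second, whether two APIs are "dual" should, in my view, be judged from the user's perspective rather than decided by us.
    For example, hsplit and hstack take different parameters, but I think they can be called functionally dual, because that is intuitive from the names alone.
    With slice_scatter and strided_slice, I cannot see any duality however I look at them; at least for someone like me who is new to these two APIs, duality would hardly come to mind ... ... 🤣🤣🤣
    If strided_slice must be given a dual name, I think strided_scatter fits better: strided says both APIs are stepped, slice means read, and scatter means write.
    Confusing users and raising exceptions for the sake of duality and a rarely used case is something I really cannot agree with ~~~

The main issue you raise here is a naming issue. It DOES exist, but the root cause is not that the name slice_scatter is bad (although changing it to slice_assign might be better). Rather, paddle.strided_slice should really be paddle.slice (replacing the existing paddle.slice). That will be resolved, so the problem you mention will no longer exist and does not affect the design here.
As for the "rarely used cases", that judgment cannot be made without sufficient data showing that slicing along multiple axes is rarely used. Otherwise, we should provide functionality that is as rich as possible.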

@jeff41404 jeff41404 left a comment

LGTM

@luotao1 luotao1 merged commit c52aec7 into PaddlePaddle:develop Dec 26, 2023
@luotao1 luotao1 changed the title from "【Hackathon 5th No.28】Add slice_scatter API to Paddle" to "【Hackathon 5th No.28】Add slice_scatter API to Paddle -part" Dec 26, 2023
Wanglongzhi2001 pushed a commit to Wanglongzhi2001/Paddle that referenced this pull request Jan 7, 2024
* [Add] hack5 28 api

* [Change] use set_value op

* [Change] values to value

* [Fix] resolve conflict

* [Update] add test cases

* [Fix] code example

* [Add] dtype test cases

* [Change] start/stop docstring

* [Change] fix start/stop docstring

* [Change] broadcast value to exp_shape

* [Change] axes with list of int

* [Add] as tensor test case

* [Change] code style