[Hackathon No.4] Add masked_scatter API to Paddle -part #59383

Merged: 10 commits, Dec 13, 2023

yangguohao
Contributor

PR types

New features

PR changes

APIs

Description

add masked_scatter for Paddle

paddle-bot bot commented Nov 27, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_masked_scatter`.
"""
# make sure the dtype of x and source is the same
Contributor

Unnecessary comments can be removed.

Contributor Author

done

self.dtype = "uint16"
self.value_shape = (300, 300)


Contributor

The implementation checks for several error cases; please add unit tests for those as well.

Contributor Author

done


value = value.flatten()[mask_prefix].reshape(mask.shape)
return paddle.where(paddle.cast(mask, dtype="bool"), value, x)

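Contributor

The two APIs masked_scatter and masked_scatter_ currently use different implementation logic. Is that because some operations are not supported in static graph mode, while you still want masked_scatter to share one code path?

In principle, the implementations of this pair of APIs should differ only slightly, in the inplace / outplace API calls. Could you briefly check whether the two current implementations differ in performance in dynamic graph mode?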

Contributor Author

Yes. The main issue is the broadcast_to operation: in static graph mode, when the shape contains None or -1, it cannot broadcast correctly, so the later code imitates the broadcast done in paddle.where. My feeling is that using broadcast_to directly for the inplace operation in dynamic graph mode avoids extra ops and should be somewhat faster. I can also benchmark it later to see whether there is a noticeable gap.

Contributor Author

Changed so that both APIs use the same logic; the performance difference is only minor.
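
For readers following the diff above, here is a minimal pure-Python sketch of the prefix-sum gather idea that the `mask.cumsum() - 1` and `paddle.where` lines rely on. It is not Paddle's implementation: the function name `masked_scatter_reference` is hypothetical, it handles only flat lists, and it ignores broadcasting.

```python
from itertools import accumulate


def masked_scatter_reference(x, mask, value):
    """Reference semantics for flat inputs; a sketch, not Paddle's code."""
    if len(x) != len(mask):
        raise ValueError("x and mask must have the same length in this sketch")
    needed = sum(1 for m in mask if m)
    if len(value) < needed:
        raise ValueError("value has fewer elements than mask has True entries")
    # Running count of True entries, minus one, clipped at zero. At every True
    # position this is the index of the value element that position receives.
    prefix = [max(count - 1, 0) for count in accumulate(int(bool(m)) for m in mask)]
    # Select per position: the gathered value where mask is True, else x.
    return [value[i] if m else xi for xi, m, i in zip(x, mask, prefix)]


out = masked_scatter_reference([0, 0, 0, 0], [True, False, True, True], [1, 2, 3, 4])
print(out)
```

Elements of `value` are consumed in order, one per True entry of `mask`, so unused trailing elements of `value` are simply ignored.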

Contributor

@zoooo0820 zoooo0820 left a comment

lgtm

>>> value = paddle.ones([2,4], dtype="float32")

>>> out = paddle.masked_scatter(x, mask, value)
>>> print(out.shape)
Contributor

It would be best to show the values of x, mask, and out in the example code so that users can follow it more easily.

Contributor Author

done


>>> import paddle

>>> x = paddle.randn([2, 2])
Contributor

This probably needs a seed to fix the output; otherwise the sample-code check will not pass. See "Fixed output is better than random".

Contributor Author

done

Contributor

The sample-code check did not pass; the log is below. It may be because x = paddle.randn([2, 2]) is not followed by print(x), so no output was captured. Please watch for this throughout the example code as well. You can try running it locally.
[image]

Contributor Author

Even with a fixed seed, the CI and local results differ, so I have revised the example once more.
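
One way to avoid the seed mismatch described in this thread is to write the docstring example with literal inputs, so the expected output does not depend on any random generator. The sketch below uses Python's standard `doctest` module as a stand-in for Paddle's sample-code checker, which is an assumption. `scatter_where_true` and `run_docstring_check` are hypothetical helpers, not Paddle APIs.

```python
import doctest


def scatter_where_true(x, mask, value):
    """Copy successive elements of ``value`` into ``x`` where ``mask`` is True.

    >>> scatter_where_true([0.0, 0.0, 0.0], [True, False, True], [1.5, 2.5])
    [1.5, 0.0, 2.5]
    """
    source = iter(value)
    return [next(source) if m else xi for xi, m in zip(x, mask)]


def run_docstring_check(func):
    """Run the examples in func's docstring and return doctest's result tuple."""
    parser = doctest.DocTestParser()
    test = parser.get_doctest(
        func.__doc__, {func.__name__: func}, func.__name__, None, 0
    )
    runner = doctest.DocTestRunner(verbose=False)
    return runner.run(test)


result = run_docstring_check(scatter_where_true)
print(result.failed, result.attempted)
```

Because every input is spelled out, the example produces the same text on any machine, which is the property the reviewers were asking for.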

jeff41404 previously approved these changes Dec 6, 2023
Contributor

LGTM

Contributor

luotao1 commented Dec 6, 2023

You can now submit the corresponding Chinese documentation.

Contributor

luotao1 commented Dec 8, 2023

2023-12-06 01:49:14 ++ python sampcd_processor.py --debug --mode cpu
2023-12-06 01:49:14 ----------------Check results--------------------
2023-12-06 01:49:14 >>> Sample code test capacity: {'cpu'}
2023-12-06 01:49:14 >>> 1 sample codes ran failed in env: {'cpu'}
2023-12-06 01:49:14 <DocTest(<modname?> paddle.Tensor.masked_scatter:1:0 ln 1)>, running time: 1.003s
2023-12-06 01:49:14 >>> Mistakes found in sample codes in env: {'cpu'}!
2023-12-06 01:49:14 >>> Please recheck the sample codes.

Contributor

luotao1 commented Dec 11, 2023

2023-12-11 13:13:17 The following tests FAILED:
2023-12-11 13:13:17 	1916 - test_masked_scatter (Failed)

@yangguohao
Contributor Author

2023-12-11 13:13:17 The following tests FAILED:
2023-12-11 13:13:17 	1916 - test_masked_scatter (Failed)

I will rerun it and see. I have not modified the test_masked_scatter file, and it passed CI before.

Contributor

luotao1 commented Dec 11, 2023

test_masked_scatter may be failing intermittently: https://xly.bce.baidu.com/paddlepaddle/paddle/newipipe/builds/6512?module=github/PaddlePaddle/Paddle&pipeline=PR-CI-Coverage&branch=pull/59383(develop) Looking at the build history, more than half of the runs are unit-test failures.

Contributor

@sunzhongkai588 sunzhongkai588 left a comment

LGTM

Contributor

luotao1 commented Dec 13, 2023

Please submit the Chinese documentation.

luotao1 changed the title from "[Hackathon No.4] Add masked_scatter API to Paddle" to "[Hackathon No.4] Add masked_scatter API to Paddle -part" on Dec 13, 2023
luotao1 merged commit 8d717f3 into PaddlePaddle:develop on Dec 13, 2023
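@zoooo0820
Contributor

test_masked_scatter may be failing intermittently: https://xly.bce.baidu.com/paddlepaddle/paddle/newipipe/builds/6512?module=github/PaddlePaddle/Paddle&pipeline=PR-CI-Coverage&branch=pull/59383(develop) Looking at the build history, more than half of the runs are unit-test failures.

@yangguohao Hello. The unit tests added in this PR have exposed some problems; could you spend a bit more time following up?

  • Problem 1: test_numel_error cannot pass in static graph mode.
    As discussed earlier, in static graph mode you cannot assert on a tensor's value (the result is a tensor rather than a concrete true/false), so the existing function logic needs to change.

  • Problem 2: some unit tests fail intermittently, for example TestMaskedScatterAPIBroadcast3.
    The cause is not yet clear; please investigate further. So far the failures appear only on the static graph branch. At first glance, removing the broadcast logic for mask seems to fix it, but that may affect the later logic that processes value.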

    zeros_like_x = paddle.zeros_like(x, dtype=int)
    mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
    mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)

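To reproduce Problem 2, copy the test_static_graph code from the unit test, reuse the TestMaskedScatterAPIBroadcast3 configuration, and run it in a loop 1000 times. The problem appears to be unrelated to the values initialized by np, so you can set specific x and mask values to make it easier to localize.

If Problem 2 turns out to be related to another API or to a framework mechanism, please report it and we will look at how to resolve it.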
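
The reproduction recipe for Problem 2 (run the same check 1000 times and see how often it fails) can be sketched as a small harness. Everything here is illustrative: `count_failures` and `FlakyStub` are hypothetical names, and the stub only mimics an intermittent failure. In practice `run_once` would wrap a copy of the test_static_graph body with fixed x and mask values.

```python
def count_failures(run_once, iterations=1000):
    """Call run_once repeatedly and record which iterations raise AssertionError."""
    failures = []
    for iteration in range(iterations):
        try:
            run_once()
        except AssertionError as exc:
            failures.append((iteration, str(exc)))
    return failures


class FlakyStub:
    """Deterministic stand-in for a flaky test: fails on every fourth call."""

    def __init__(self):
        self.calls = 0

    def __call__(self):
        self.calls += 1
        assert self.calls % 4 != 0, f"mismatch on call {self.calls}"


failures = count_failures(FlakyStub(), iterations=1000)
print(f"{len(failures)} of 1000 runs failed; first failure at iteration {failures[0][0]}")
```

Recording the failing iteration indices, rather than stopping at the first failure, makes it easier to tell a rare race from a failure that depends on the inputs.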

@yangguohao
Contributor Author

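Hello, I still have a few questions:

Problem 1: I tried but found no good way to solve it; in static graph mode the check can only be skipped. Or do you have any suggestions for how to implement it?

Problem 2: I reproduced the error, but after fixing the seed and trying different shapes I still could not pinpoint the cause. In the end, casting mask to the bool dtype before the logical_not call made the problem disappear. I will submit a follow-up PR with this change.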
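
The Problem 1 workaround (skip the numel check when the value is not known until the graph runs) can be illustrated with a toy model. `LazyValue` is a hypothetical stand-in for a static-graph variable and `check_value_numel` is a hypothetical helper; neither is a Paddle API, and real static-graph variables may behave differently.

```python
class LazyValue:
    """Toy stand-in for a value that is only known once a static graph runs."""

    def __init__(self, expr):
        self.expr = expr

    def __lt__(self, other):
        # In this toy class the comparison is recorded, not evaluated, so the
        # result is another LazyValue rather than True or False.
        other_expr = other.expr if isinstance(other, LazyValue) else repr(other)
        return LazyValue(f"({self.expr} < {other_expr})")


def check_value_numel(value_numel, mask_true_count):
    """Return True if the check ran, False if it had to be skipped."""
    if isinstance(value_numel, LazyValue) or isinstance(mask_true_count, LazyValue):
        # Nothing concrete to compare while the graph is being built.
        return False
    if value_numel < mask_true_count:
        raise ValueError("value has fewer elements than mask has True entries")
    return True


print(check_value_numel(8, 3))
print(check_value_numel(LazyValue("numel(value)"), 3))
```

With concrete integers the check runs and can raise; with a lazy operand the function reports that it skipped, which mirrors the "skip in static graph mode" option mentioned in the reply above.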

@yangguohao yangguohao mentioned this pull request Jan 15, 2024