[Hackathon No.4] Add masked_scatter API to Paddle -part #59383
Your PR was submitted successfully. Thank you for contributing to this open-source project!
python/paddle/tensor/manipulation.py
Outdated
Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_masked_scatter`.
"""
# make sure the dtype of x and source is the same
Unnecessary comments like this one can be removed.
done
self.dtype = "uint16"
self.value_shape = (300, 300)
The implementation checks for several error cases; please add unit tests that cover those as well.
done
value = value.flatten()[mask_prefix].reshape(mask.shape)
return paddle.where(paddle.cast(mask, dtype="bool"), value, x)
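The two lines above appear to gather one `value` element per position and then select with `paddle.where`. As a rough illustration only, here is the same idea in NumPy rather than Paddle. It assumes `mask_prefix` is the running count of True entries minus one (its definition is not shown in this excerpt); the function name `masked_scatter_ref` and the size check are hypothetical.

```python
import numpy as np


def masked_scatter_ref(x, mask, value):
    """NumPy sketch: copy elements of `value`, in order, into `x` where `mask` is True."""
    # Broadcast the mask to the shape of x and force a boolean dtype.
    mask = np.broadcast_to(np.asarray(mask, dtype=bool), x.shape)
    needed = int(mask.sum())
    if needed == 0:
        return x.copy()
    if value.size < needed:
        raise ValueError("value has fewer elements than mask has True entries")
    # Assumed meaning of `mask_prefix`: for each position, the index of the
    # value element that would be consumed there (running True count - 1).
    mask_prefix = np.clip(np.cumsum(mask.ravel()) - 1, 0, None)
    gathered = value.ravel()[mask_prefix].reshape(mask.shape)
    # Keep x where the mask is False, take the gathered value where it is True.
    return np.where(mask, gathered, x)


x = np.zeros((2, 3))
mask = np.array([[True, False, True], [False, True, False]])
value = np.arange(1, 7)
print(masked_scatter_ref(x, mask, value).tolist())
# [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]
```

Only the first three elements of `value` are used here, because the mask has three True entries.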
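The two APIs, masked_scatter and masked_scatter_, are implemented with different logic. Is that because some operations are not supported in static graph mode, while you still want masked_scatter to use a single code path?
In principle, the implementations of this pair of APIs should differ only in a few inplace / out-of-place calls. Could you briefly analyze whether the two current approaches differ in performance in dynamic graph mode?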
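Yes. The main issue is the broadcast_to operation: in static graph mode, when the shape contains None or -1, it cannot broadcast correctly, so the out-of-place version imitates the broadcast done in paddle.where instead. I think that in dynamic graph mode, using broadcast_to directly for the inplace operation avoids extra ops, so it should be somewhat faster. I can also benchmark it later to see whether there is a noticeable gap.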
Changed so that both APIs use the same logic; the performance difference is only slight.
lgtm
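The thread above mentions imitating the broadcast in paddle.where instead of calling broadcast_to with an explicit shape, but does not show that code. One way to broadcast without naming the target shape is to add zero tensors, so that ordinary elementwise broadcasting determines the result shape at run time. The sketch below shows the idea in NumPy rather than Paddle; the function name is hypothetical, and that paddle.where works this way is an assumption.

```python
import numpy as np


def broadcast_by_adding_zeros(x, mask):
    """Broadcast x and mask to a common shape without spelling the shape out."""
    # The sum of two zero arrays has the broadcast shape of x and mask.
    zeros = np.zeros_like(x) + np.zeros_like(mask, dtype=x.dtype)
    x_b = x + zeros
    # Cast the mask to x's dtype for the addition, then back to bool.
    mask_b = (mask.astype(x.dtype) + zeros).astype(bool)
    return x_b, mask_b


x = np.ones((2, 3))
mask = np.array([True, False, True])
x_b, mask_b = broadcast_by_adding_zeros(x, mask)
print(x_b.shape, mask_b.tolist())
```

The trade-off matches the one discussed in the thread: this needs a few extra elementwise ops compared with a single broadcast_to call, but it never has to read a concrete shape.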
python/paddle/tensor/manipulation.py
Outdated
>>> value = paddle.ones([2,4], dtype="float32")

>>> out = paddle.masked_scatter(x, mask, value)
>>> print(out.shape)
It's best to give the values of x, mask and out in the example code to make it easier for users to understand.
done
>>> import paddle

>>> x = paddle.randn([2, 2])
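You will probably need to set a seed to make the output deterministic; otherwise the example check will not pass. See "Fixed output is better than random" for reference.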
done
Even with a fixed seed, the results on CI and on my local machine differ, so I have revised the example again.
LGTM
You can now submit the corresponding Chinese documentation.
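The thread does not show what the final docstring example looked like. One way to avoid the seed problem entirely is to build the inputs from literal values, so the printed output is the same on every machine. The sketch below illustrates this with Python's standard doctest module and NumPy; the example text and the name `fixed_example` are hypothetical and are not the example from this PR.

```python
import doctest

import numpy as np

# Hypothetical docstring example: literal inputs, so the printed output
# does not depend on any random number generator or seed.
EXAMPLE = """
>>> x = np.arange(4.0).reshape(2, 2)
>>> mask = np.array([[True, False], [False, True]])
>>> np.where(mask, -1.0, x).tolist()
[[-1.0, 1.0], [2.0, -1.0]]
"""

# Parse the text into a doctest and run it with np available as a global.
test = doctest.DocTestParser().get_doctest(EXAMPLE, {"np": np}, "fixed_example", None, 0)
result = doctest.DocTestRunner(verbose=False).run(test)
print(result.failed, result.attempted)
```

An example built from random draws can pass locally and still fail on CI, as reported above, whereas literal inputs leave nothing to vary.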
c6faf4a to 6403dc4
I'll rerun it and see. I have not modified the test_masked_scatter file, and it passed CI before.
test_masked_scatter may be failing randomly: https://xly.bce.baidu.com/paddlepaddle/paddle/newipipe/builds/6512?module=github/PaddlePaddle/Paddle&pipeline=PR-CI-Coverage&branch=pull/59383(develop) Looking at the history, more than half of the runs are unit test failures.
a47cc43 to 6f4485b
LGTM
Please submit the Chinese documentation.
@yangguohao Hello, the unit tests added in this PR have exposed some issues. Could you spend some more time following up on them?
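For the reproduction steps of issue 2, you can copy from the unit test. If issue 2 turns out to be related to other APIs or to framework mechanisms, please report it and we will look at how to resolve it.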
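Hello, I still have a few questions.
Issue 1: I tried, but found no good way to solve it; in static graph mode the check can only be skipped. Do you have any suggestions for how this could be implemented?
Issue 2: I reproduced the error, but after fixing the seed and trying different shapes I still could not locate the cause. In the end, casting mask to bool dtype before calling logical_not made the problem disappear. I will submit a follow-up PR with this change.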
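The thread does not identify the root cause of issue 2, so the following is an analogy only, in NumPy rather than Paddle. It shows one general reason to normalize a mask to bool before negating it: negating an integer mask bit by bit gives a different result from negating a boolean one. Whether the Paddle logical_not problem has the same cause is not established here.

```python
import numpy as np

mask_u8 = np.array([1, 0, 1], dtype=np.uint8)

# Bitwise negation of an integer mask flips every bit of each byte.
bitwise = ~mask_u8
print(bitwise.tolist())  # [254, 255, 254]

# Casting to bool first makes the negation a plain logical NOT.
logical = ~mask_u8.astype(bool)
print(logical.tolist())  # [False, True, False]
```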
PR types
New features
PR changes
APIs
Description
add masked_scatter for Paddle