
【Hackathon No.24】add RFC for ChannelShuffle #40

Merged · 4 commits · Mar 23, 2022

Conversation

BrilliantYuKaimin (Contributor):

Add the design document for `paddle.nn.ChannelShuffle`.

@paddle-bot-old:

Thanks for your contribution!

@dingjiaweiww (Contributor):

Your PR was submitted successfully. Thank you for contributing to the open-source project. Please check that the PR's format and content are complete; refer to the sample template for details.

@dingjiaweiww (Contributor):

The PR format check passed. Your PR will now be reviewed by Paddle experts and the open-source community; please keep an eye on the PR's activity.

@shiyutang (Contributor) left a comment:

The proposal is complete overall, but the implementation details need to be laid out step by step. Please revise the design before submitting code. For an example of such a revision, see: https://github.com/PaddlePaddle/community/blob/master/rfcs/APIs/20200301_api_design_for_quantile.md

Comment on lines 25 to 26
Paddle does not currently support this operation, but it can be implemented by composing existing APIs.

Contributor:

Please elaborate on how the APIs would be composed.

Contributor (Author):

Done.
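For reference, the composite-API route in Paddle could look like the following minimal sketch (NCHW layout assumed; `channel_shuffle_composite` is a hypothetical name, not a Paddle API):

```python
import paddle

def channel_shuffle_composite(x, groups):
    # Hypothetical helper: channel shuffle via reshape + transpose (NCHW input).
    n, c, h, w = x.shape
    x = paddle.reshape(x, [n, groups, c // groups, h, w])  # split C into (groups, C/groups)
    x = paddle.transpose(x, perm=[0, 2, 1, 3, 4])          # swap the two channel sub-axes
    return paddle.reshape(x, [n, c, h, w])                 # merge back into C channels
```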

Comment on lines +46 to +74
```cpp
Tensor math_channel_shuffle(const Tensor& self, int64_t groups) {
  int64_t b = self.size(0);
  int64_t c = self.size(1);
  int64_t oc = c / groups;

  auto input_reshaped = self.view({b, groups, oc, -1});
  Tensor output_tensor =
      input_reshaped.permute({0 /* b */, 2 /* oc */, 1 /* groups */, 3})
          .contiguous()
          .reshape(self.sizes());
  return namedinference::propagate_names_if_nonempty(
      output_tensor,
      self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
}
```

## TensorFlow

TensorFlow does not provide a `ChannelShuffle` API directly, but a [community blog post](https://blog.csdn.net/baidu_23388287/article/details/94456951) implements the operation by composing APIs:

```python
def shuffle_unit(self, x, groups):
    with tf.variable_scope('shuffle_unit'):
        n, h, w, c = x.get_shape().as_list()
        # NHWC layout: split the channel axis into (groups, c // groups),
        # swap the two sub-axes, then merge them back.
        x = tf.reshape(x, shape=tf.convert_to_tensor([tf.shape(x)[0], h, w, groups, c // groups]))
        x = tf.transpose(x, tf.convert_to_tensor([0, 1, 2, 4, 3]))
        x = tf.reshape(x, shape=tf.convert_to_tensor([tf.shape(x)[0], h, w, c]))
    return x  # the quoted snippet omits this return; added for completeness
```

Contributor:

Please explain, step by step, the overall logic of each of these two implementation approaches.

Contributor (Author):

> Please explain, step by step, the overall logic of each of these two implementation approaches.

Done.

Comment on lines 76 to 77
Whether implemented in C++ or by composing APIs, the logic is quite simple, so we plan to write a new operator in C++ to achieve higher efficiency.

Contributor:

Please compare and evaluate the torch/tensorflow implementation approaches, analyzing their strengths and weaknesses.

Contributor (Author):

> Please compare and evaluate the torch/tensorflow implementation approaches, analyzing their strengths and weaknesses.

Done.

Comment on lines +89 to +92
## API Implementation Plan

Implement `paddle.nn.ChannelShuffle` by referring to `paddle.nn.PixelShuffle`, and implement `paddle.nn.functional.channel_shuffle` along the way.

Contributor:

Please describe the implementation plan step by step, including where the implementation files will be located and the internal logic of the code.

Contributor (Author):

> Please describe the implementation plan step by step, including where the implementation files will be located and the internal logic of the code.

Done.
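For illustration, the layer could wrap the functional form the same way `paddle.nn.PixelShuffle` does; a minimal sketch, assuming the `paddle.nn.functional.channel_shuffle` signature proposed in this RFC and hypothetical attribute names:

```python
import paddle.nn as nn
import paddle.nn.functional as F

class ChannelShuffle(nn.Layer):
    def __init__(self, groups, data_format='NCHW'):
        super().__init__()
        self._groups = groups
        self._data_format = data_format

    def forward(self, x):
        # Delegate to the functional API proposed in this RFC.
        return F.channel_shuffle(x, self._groups, self._data_format)
```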

Comment on lines 95 to 98
Test cases to consider:
- Consistency with PyTorch's results;
- Correctness of backpropagation;
- Error checking: an exception is raised correctly when `groups` is invalid or does not evenly divide the number of channels.

@paddle-bot-old:

Your PR has new feedback; please revise it promptly.

Comment on lines 95 to 97
- PyTorch implements ChannelShuffle as a low-level operator in C++, which is efficient, but this makes it harder for developers not involved in the framework's design to understand the API's behavior from the source code.
- Implementing ChannelShuffle by composing APIs is concise and easy to understand, and developers can learn the API's behavior directly from the Python source, though its execution efficiency may fall short of a native-operator implementation.

Contributor:

The comparison here should contrast the strengths and weaknesses of the PyTorch and TF approaches. Since the two methods differ little, you can simply state that the channel-splitting design is consistent with both approaches.

Contributor (Author):

Done.

- Whether both CPU and GPU platforms are supported;
- Behavior under different tensor dtypes;
- Validity and boundary-value tests for all input arguments, confirming that each argument takes effect correctly;
- Correctness of the forward computation (compared against the composite-API implementation and against PyTorch);
Contributor:

You can compare against the PyTorch implementation locally; in the tests, since the CI environment has no PyTorch, the comparison needs to use a NumPy implementation.

Contributor (Author):

Done.
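Such a NumPy reference for the unit tests might look like this minimal sketch (NCHW input assumed; `channel_shuffle_np` is a hypothetical name):

```python
import numpy as np

def channel_shuffle_np(x, groups):
    # Hypothetical reference implementation for the unit tests (NCHW input).
    n, c, h, w = x.shape
    assert c % groups == 0, "groups must evenly divide the channel count"
    x = x.reshape(n, groups, c // groups, h, w)  # split channels into groups
    x = x.transpose(0, 2, 1, 3, 4)               # swap group and sub-channel axes
    return x.reshape(n, c, h, w)                 # merge the channel axes back
```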

Comment on lines +170 to +171
```python
def extra_repr(self):
    pass
```
Contributor:

Is this method needed?

Contributor (Author):

All the other APIs have this method.
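For context, `extra_repr` supplies the string shown when the layer is printed. A sketch of what it might return here, assuming the layer stores `_groups` and `_data_format`:

```python
def extra_repr(self):
    # Shown when the layer is printed, e.g. ChannelShuffle(groups=2, data_format=NCHW)
    return 'groups={}, data_format={}'.format(self._groups, self._data_format)
```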

- Correctness of the forward computation (compared against the composite-API implementation and against PyTorch);
- Correctness of the backward computation;
- When the given `groups` is invalid (not a positive integer, or not a divisor of the channel count), an exception is raised with a friendly message.

Contributor:

Also add error checks: invalid `data_format`, `x` not a Tensor, checks on the dimensionality of `x`, etc.

Contributor (Author):

Done.
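A minimal sketch of the requested checks (the helper name and messages are hypothetical; the real checks would live in the Python API and the operator's shape inference):

```python
import paddle

def _check_channel_shuffle_args(x, groups, data_format):
    # Hypothetical helper illustrating the requested error checks.
    if not isinstance(x, paddle.Tensor):
        raise TypeError("x must be a paddle.Tensor, got {}".format(type(x).__name__))
    if x.ndim != 4:
        raise ValueError("x must be a 4-D tensor, got {} dims".format(x.ndim))
    if data_format not in ('NCHW', 'NHWC'):
        raise ValueError("data_format must be 'NCHW' or 'NHWC', got {!r}".format(data_format))
    if not isinstance(groups, int) or groups <= 0:
        raise ValueError("groups must be a positive integer, got {!r}".format(groups))
    channels = x.shape[1] if data_format == 'NCHW' else x.shape[3]
    if channels % groups != 0:
        raise ValueError(
            "groups ({}) must evenly divide the number of channels ({})".format(groups, channels))
```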

Revised the document according to the review comments
@shiyutang (Contributor) left a comment:

LGTM

Comment on lines 135 to 140
Their implementations are defined in `paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h` and `paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h`, and they are registered in:

- `paddle/phi/kernels/cpu/channel_shuffle_kernel.cc`
- `paddle/phi/kernels/cpu/channel_shuffle_grad_kernel.cc`
- `paddle/phi/kernels/gpu/channel_shuffle_kernel.cu`
- `paddle/phi/kernels/gpu/channel_shuffle_grad_kernel.cu`
Contributor:

The implementation can live under `paddle/phi/kernels/cpu/` (or `gpu/`), while the registration goes under `fluid/operators/`; see TraceOp for an example (screenshot omitted here). For how to develop a C++ op, refer to the development guide: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/new_cpp_op_cn.html

@shiyutang merged commit f07eb3a into PaddlePaddle:master on Mar 23, 2022
@paddle-bot-old:

Your PR has been merged into the community repository. Please proceed with the coding work and submit your code to the corresponding Paddle repo.
