-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dilations for conv2d and optimize conv2d code #5472
Conversation
10101ea
to
27805c2
Compare
607e0e8
to
520dec7
Compare
520dec7
to
97e9dd7
Compare
b125570
to
bd73642
Compare
b2d5245
to
caf24f3
Compare
caf24f3
to
93551bd
Compare
a5bbe8d
to
3e60b6b
Compare
a3e15fd
to
7d73b8f
Compare
paddle/operators/conv_op.h
Outdated
filter_1 &= (static_cast<int>(filter_dim[j]) == 1); | ||
strides_1 &= (strides[j] == 1); | ||
padding_0 &= (paddings[j] == 0); | ||
dilation_1 &= (dilations[j] == 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
&=
-> &&=
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有这种表示&&=
,我改成了strides_1 = strides_1 && (strides[j] == 1)
paddle/operators/conv_op.h
Outdated
vol2col(context.device_context(), in_slice, col, strides[0], | ||
strides[1], strides[2], paddings[0], paddings[1], | ||
paddings[2]); | ||
if (!not_expand) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
! not expand = expand, the logic is a little complex,
How about rename NotExpand to IsExpand? Then return True, means that it needs to expand, ortherwise, not expand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/conv_op.h
Outdated
for (int i = 0; i < batch_size; i++) { | ||
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); | ||
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); | ||
for (int g = 0; g < groups; g++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (int i = 0; i < batch_size; i++) {
// ....
for (int g = 0; g < groups; g++) {
if(!IsExpand) {
ShareDataWith();
} else if () {
im2col();
} else if () {
im2vol();
}
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/conv_op.h
Outdated
|
||
math::matmul<Place, T>(context.device_context(), filter_slice, true, | ||
out_grad_slice, false, T(1.0), &col_matrix, | ||
T(0.0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code structure is same as above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/math/im2col.cc
Outdated
1, | ||
col_width, | ||
"col_width and padding(padding_left, padding_right) are " | ||
"inconsistent."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
写functor的时候,我也在考虑functor里面是否还有必要再次check shape的正确性,因为要么是在Op里计算得到的,要么InferShape里也已经check过。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,我也有想过这些,比如在Op里面检测过了,在GradOp中就不用检测了吧
paddle/operators/math/im2col.cc
Outdated
int padding_down, int padding_left, int padding_right) { | ||
int dilation_h, int dilation_w, int stride_height, | ||
int stride_width, int padding_up, int padding_down, | ||
int padding_left, int padding_right) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe std::vector<int>& dilations
, std::vector<int>& strides
, std::vector<int>& paddings
are short. And the op also uses std::vector<int>
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/conv_op.h
Outdated
int output_size = (input_size + padding_up + padding_down - | ||
(dilation * (filter_size - 1) + 1)) / | ||
stride + | ||
1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const int dkernel = dilation * (filter_size - 1) + 1;
const int output_size = (input_size + padding_up + padding_down - dkernel)/stride + 1;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/math/vol2col.cc
Outdated
1, | ||
output_width, | ||
"input_width and output_width are " | ||
"Mismatching."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above, whether it needs to check again?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can write in this way first, and discuss it later. Because other functors also have similar problem.
paddle/operators/math/vol2col.cu
Outdated
output_height + | ||
h_col) * | ||
output_width + | ||
w_col; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data_col_index的计算太长,不容易看清楚。一些计算可以移到各自循环里。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把这个公式分成了两个,现在可能会好一点
1607de4
to
8fffa9e
Compare
8fffa9e
to
31dc019
Compare
fix #5495
fix #5507
fix #5550