Skip to content

Commit

Permalink
Merge pull request #99 from thebhatman/patch-1
Browse files Browse the repository at this point in the history
Changed output size calculation to support kernel size such as (x,y)
  • Loading branch information
DhairyaLGandhi authored Mar 28, 2019
2 parents 2bd7e8a + 7e19660 commit 11f840d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/impl/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ img_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = I
# Calculate the output dimensions of this convolution
function output_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F}
O_w = div(I[1] + P[1] + P[2] - (K[1] - 1) * D[1] - 1, S[1]) + 1
O_h = div(I[2] + P[3] + P[4] - (K[1] - 1) * D[1] - 1, S[1]) + 1
O_h = div(I[2] + P[3] + P[4] - (K[2] - 1) * D[2] - 1, S[2]) + 1
return (O_w, O_h)
end
kernel_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = K
Expand Down
13 changes: 13 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@ using NNlib: conv, crosscor, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool,
@testset "conv2d" begin
x = reshape(Float64[1:20;], 5, 4, 1, 1)
w = reshape(Float64[1:4;], 2, 2, 1, 1)
w1 = reshape(Float64[1:6;], 2, 3, 1, 1)
w2 = reshape(Float64[1:6;], 3, 2, 1, 1)

@test dropdims(conv(x, w1), dims = (3,4)) == [
95.0 200.0;
116.0 221.0;
137.0 242.0;
158.0 263.0]

@test dropdims(conv(x, w2), dims = (3,4)) == [
68.0 173.0 278.0;
89.0 194.0 299.0;
110.0 215.0 320.0]

@test dropdims(conv(x, w), dims = (3,4)) == [
29 79 129;
Expand Down

0 comments on commit 11f840d

Please sign in to comment.