Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Support for asymmetrical padding #84

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 61 additions & 35 deletions src/impl/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ function psize(p, x)
end

function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int, height::Int, channels::Int,
kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int, stride_h::Int,
kernel_w::Int, kernel_h::Int, pad_w::Tuple{Int,Int}, pad_h::Tuple{Int,Int}, stride_w::Int, stride_h::Int,
dil_w::Int, dil_h::Int, mode::Int) where T

height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
height_col = div(height + sum(pad_h) - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + sum(pad_w) - (kernel_w - 1) * dil_w - 1, stride_w) + 1
channels_col = channels * kernel_h * kernel_w

#pragma omp parallel for
Expand All @@ -29,11 +29,11 @@ function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int
end
for h = 1:height_col
for w = 1:width_col
h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h
w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w
if h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width
h_pad_before = (h - 1) * stride_h - pad_h[1] + h_offset * dil_h
w_pad_before = (w - 1) * stride_w - pad_w[1] + w_offset * dil_w
if h_pad_before >= 0 && h_pad_before < height && w_pad_before >= 0 && w_pad_before < width
col[((c - 1)*height_col+h-1) * width_col + w] =
img[(c_im * height + h_pad) * width + w_pad + 1]
img[(c_im * height + h_pad_before) * width + w_pad_before + 1]
else
col[((c - 1)*height_col+h - 1) * width_col + w] = 0
end
Expand All @@ -43,11 +43,11 @@ function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int
end

function col2im_2d!(col::AbstractArray{T,2}, img::AbstractArray{T,3}, width::Int, height::Int,
channels::Int, kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int,
channels::Int, kernel_w::Int, kernel_h::Int, pad_w::Tuple{Int,Int}, pad_h::Tuple{Int,Int}, stride_w::Int,
stride_h::Int, dil_w::Int, dil_h::Int, mode::Int) where T

height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
height_col = div(height + sum(pad_h) - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + sum(pad_w) - (kernel_w - 1) * dil_w - 1, stride_w) + 1
channels_col = channels * kernel_h * kernel_w

fill!(img, 0)
Expand All @@ -61,23 +61,23 @@ function col2im_2d!(col::AbstractArray{T,2}, img::AbstractArray{T,3}, width::Int
h_offset = kernel_h - 1 - h_offset
end
for h = 1:height_col, w = 1:width_col
h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h
w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w
if h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width
h_pad_before = (h - 1) * stride_h - pad_h[1] + h_offset * dil_h
w_pad_before = (w - 1) * stride_w - pad_w[1] + w_offset * dil_w
if h_pad_before >= 0 && h_pad_before < height && w_pad_before >= 0 && w_pad_before < width
cval::T = col[((c - 1) * height_col + h - 1) * width_col + w]
img[(c_im * height + h_pad) * width + w_pad + 1] += cval
img[(c_im * height + h_pad_before) * width + w_pad_before + 1] += cval
end
end
end
end

function im2col_3d!(img::AbstractArray{T,4}, col::AbstractArray{T,2}, width::Int, height::Int, depth::Int,
channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, pad_w::Int, pad_h::Int, pad_d::Int,
stride_w::Int, stride_h::Int, stride_d::Int, dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T
channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, pad_w::Tuple{Int,Int}, pad_h::Tuple{Int,Int}, pad_d::Tuple{Int,Int},
stride_w::Int, stride_h::Int, stride_d::Int, dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T

height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
depth_col = div(depth + 2pad_d - (kernel_d - 1) * dil_d - 1, stride_d) + 1
height_col = div(height + sum(pad_h) - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + sum(pad_w) - (kernel_w - 1) * dil_w - 1, stride_w) + 1
depth_col = div(depth + sum(pad_d) - (kernel_d - 1) * dil_d - 1, stride_d) + 1
channels_col = channels * kernel_h * kernel_w * kernel_d


Expand All @@ -93,13 +93,13 @@ function im2col_3d!(img::AbstractArray{T,4}, col::AbstractArray{T,2}, width::Int
d_offset = kernel_d - 1 - d_offset
end
for d = 1:depth_col, h = 1:height_col, w = 1:width_col
d_pad = (d - 1) * stride_d - pad_d + d_offset * dil_d
h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h
w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w
if d_pad >= 0 && d_pad < depth && h_pad >= 0 && h_pad < height &&
w_pad >= 0 && w_pad < width
d_pad_before = (d - 1) * stride_d - pad_d[1] + d_offset * dil_d
h_pad_before = (h - 1) * stride_h - pad_h[1] + h_offset * dil_h
w_pad_before = (w - 1) * stride_w - pad_w[1] + w_offset * dil_w
if d_pad_before >= 0 && d_pad_before < depth && h_pad_before >= 0 && h_pad_before < height &&
w_pad_before >= 0 && w_pad_before < width
col[(((c - 1) * depth_col + d - 1) * height_col + h - 1) * width_col + w] =
img[((c_im * depth + d_pad) * height + h_pad) * width + w_pad + 1]
img[((c_im * depth + d_pad_before) * height + h_pad_before) * width + w_pad_before + 1]
else
col[(((c - 1) * depth_col + d - 1) * height_col + h - 1) * width_col + w] = 0
end
Expand All @@ -109,12 +109,12 @@ end

function col2im_3d!(col::AbstractArray{T,2}, img::AbstractArray{T,4}, width::Int, height::Int,
depth::Int, channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int,
pad_w::Int, pad_h::Int, pad_d::Int, stride_w::Int, stride_h::Int, stride_d::Int,
pad_w::Tuple{Int,Int}, pad_h::Tuple{Int,Int}, pad_d::Tuple{Int,Int}, stride_w::Int, stride_h::Int, stride_d::Int,
dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T

height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
depth_col = div(depth + 2pad_d - (kernel_d - 1) * dil_d - 1, stride_d) + 1
height_col = div(height + sum(pad_h) - (kernel_h - 1) * dil_h - 1, stride_h) + 1
width_col = div(width + sum(pad_w) - (kernel_w - 1) * dil_w - 1, stride_w) + 1
depth_col = div(depth + sum(pad_d) - (kernel_d - 1) * dil_d - 1, stride_d) + 1
channels_col = channels * kernel_h * kernel_w * kernel_d

fill!(img, 0)
Expand All @@ -132,13 +132,13 @@ function col2im_3d!(col::AbstractArray{T,2}, img::AbstractArray{T,4}, width::Int
end

for d = 1:depth_col, h = 1:height_col, w = 1:width_col
d_pad = (d - 1) * stride_d - pad_d + d_offset * dil_d
h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h
w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w
if h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width &&
d_pad >= 0 && d_pad < depth
d_pad_before = (d - 1) * stride_d - pad_d[1] + d_offset * dil_d
h_pad_before = (h - 1) * stride_h - pad_h[1] + h_offset * dil_h
w_pad_before = (w - 1) * stride_w - pad_w[1] + w_offset * dil_w
if h_pad_before >= 0 && h_pad_before < height && w_pad_before >= 0 && w_pad_before < width &&
d_pad_before >= 0 && d_pad_before < depth
cval::T = col[(((c - 1) * depth_col + d - 1) * height_col + h - 1) * width_col + w]
iidx = ((c_im * depth + d_pad) * height + h_pad) * width + w_pad + 1
iidx = ((c_im * depth + d_pad_before) * height + h_pad_before) * width + w_pad_before + 1
#pragma omp atomic
img[iidx] += cval
end
Expand Down Expand Up @@ -328,6 +328,11 @@ function conv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abstra
return dx
end

function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int, height::Int, channels::Int,
kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int, stride_h::Int,dil_w::Int, dil_h::Int, mode::Int) where T
im2col_2d!(img, col, width, height, channels,kernel_w, kernel_h, (pad_w,pad_w), (pad_h,pad_h), stride_w, stride_h,dil_w, dil_h, mode)
end

function im2col2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
Wx,Hx,Cx,Nx = size(x)
Expand All @@ -346,6 +351,12 @@ function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArr
return x2
end

function col2im_2d!(col::AbstractArray{T,2}, img::AbstractArray{T,3}, width::Int, height::Int,
channels::Int, kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int,
stride_h::Int, dil_w::Int, dil_h::Int, mode::Int) where T
col2im_2d!(col, img, width, height, channels, kernel_w, kernel_h, (pad_w,pad_w), (pad_h,pad_h), stride_w, stride_h, dil_w, dil_h, mode)
end

function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
Wx,Hx,Cx,Nx = size(x)
Expand Down Expand Up @@ -440,6 +451,12 @@ function conv3d_grad_x!(dx::AbstractArray{T,5}, x::AbstractArray{T,5}, w::Abstra
return dx
end

function im2col_3d!(img::AbstractArray{T,4}, col::AbstractArray{T,2}, width::Int, height::Int, depth::Int,
channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int, pad_w::Int, pad_h::Int, pad_d::Int,
stride_w::Int, stride_h::Int, stride_d::Int, dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T
im2col_3d!(img, col, width, height, depth, channels, kernel_w, kernel_h, kernel_d, (pad_w,pad_w), (pad_h,pad_h), (pad_d,pad_d), stride_w, stride_h, stride_d, dil_w, dil_h, dil_d, mode)
end

function im2col3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, p3::Int, s1::Int, s2::Int,
s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T
Expand All @@ -450,6 +467,15 @@ function im2col3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr
return x2
end

function col2im_3d!(col::AbstractArray{T,2}, img::AbstractArray{T,4}, width::Int, height::Int,
depth::Int, channels::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int,
pad_w::Int, pad_h::Int, pad_d::Int, stride_w::Int, stride_h::Int, stride_d::Int,
dil_w::Int, dil_h::Int, dil_d::Int, mode::Int) where T

col2im_3d!(col, img, width, height, depth, channels, kernel_w, kernel_h, kernel_d, (pad_w,pad_w), (pad_h,pad_h), (pad_d,pad_d), stride_w, stride_h, stride_d,
dil_w, dil_h, dil_d, mode)
end

function col2im3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArray{T,2},
n::Int, p1::Int, p2::Int, p3::Int, s1::Int, s2::Int,
s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T
Expand Down