Skip to content

Commit

Permalink
Merge pull request #71 from ayush1999/dev_mode
Browse files Browse the repository at this point in the history
[WIP] mode parameter for Convolution/cross-convolution
  • Loading branch information
MikeInnes authored Oct 8, 2018
2 parents b653dc1 + 56eee8e commit 519b5c2
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractAr
x, w, pad = pad_, stride = stride_, dilation = dilation)
end

function crossconv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
crossconv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
x, w, pad = pad_, stride = stride_, dilation = dilation)
end

∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)

Expand All @@ -39,12 +45,17 @@ end
# N-D dispatch

function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
pad = 0, stride = 1, dilation = 1, flipkernel =0) where T
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel)
return y
end

function crossconv!(y::AbstractArray, x::AbstractArray, w::AbstractArray;
pad = 0, stride = 1, dilation = 1)
conv!(y, x, w, pad=pad, stride=stride, dilation=dilation, flipkernel=1)
end

function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
Expand All @@ -62,8 +73,8 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
end

conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)

∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
Expand All @@ -74,8 +85,8 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)

conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)

∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
Expand All @@ -85,7 +96,7 @@ conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)

# Depthwise Conv
# Depthwise Conv

function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
Expand All @@ -96,9 +107,18 @@ function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
end

function depthwisecrossconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
depthwisecrossconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
end

depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, flipkernel=0) where T =
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode= flipkernel)

depthwisecrossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
depthwiseconv!(y, x, w, pad = pad, stride = stride, flipkernel=1)

∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
Expand Down

0 comments on commit 519b5c2

Please sign in to comment.