Skip to content

Commit

Permalink
Changes for depthwiseconv
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush1999 committed Oct 4, 2018
1 parent b94257f commit 25dba3d
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ end
# N-D dispatch

function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
pad = 0, stride = 1, dilation = 1, mode=0) where T
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), mode=mode)
return y
end

Expand Down Expand Up @@ -74,8 +74,8 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)

conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
pad = 0, stride = 1, dilation = 1, mode=0) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=mode)

∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
Expand All @@ -91,14 +91,14 @@ function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
end

function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
function depthwiseconv(x::A, w::A; pad = 0, stride = 1, mode=0) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_, mode=mode)
end

depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
pad = 0, stride = 1, mode=0) where T =
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode=mode)

∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
Expand Down

0 comments on commit 25dba3d

Please sign in to comment.