diff --git a/src/conv.jl b/src/conv.jl index 0682c3074..9579a571b 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -30,6 +30,12 @@ function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractAr x, w, pad = pad_, stride = stride_, dilation = dilation) end +function crossconv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray + pad_, stride_ = padtuple(x, pad), padtuple(x, stride) + crossconv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), + x, w, pad = pad_, stride = stride_, dilation = dilation) +end + ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray = ∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation) @@ -39,12 +45,17 @@ end # N-D dispatch function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3}; - pad = 0, stride = 1, dilation = 1) where T + pad = 0, stride = 1, dilation = 1, flipkernel =0) where T args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w)) - conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1)) + conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel) return y end +function crossconv!(y::AbstractArray, x::AbstractArray, w::AbstractArray; + pad = 0, stride = 1, dilation = 1) + conv!(y, x, w, pad=pad, stride=stride, dilation=dilation, flipkernel=1) +end + function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3}; pad = 0, stride = 1, dilation = 1) where T @@ -62,8 +73,8 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3}, end conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - pad = 0, stride = 1, dilation = 1) where T = - conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation) + pad = 0, stride = 1, dilation = 1, flipkernel=0) where T = + conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel) ∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; pad = 0, stride = 1, dilation = 1) where T = @@ -74,8 +85,8 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation) conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}; - pad = 0, stride = 1, dilation = 1) where T = - conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation) + pad = 0, stride = 1, dilation = 1, flipkernel=0) where T = + conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel) ∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}; pad = 0, stride = 1, dilation = 1) where T = @@ -85,7 +96,7 @@ conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}; pad = 0, stride = 1, dilation = 1) where T = conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation) -# Depthwise Conv + # Depthwise Conv function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride) ((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4]) @@ -96,9 +107,18 @@ function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_) end +function depthwisecrossconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray + pad_, stride_ = padtuple(x, pad), padtuple(x, stride) + depthwisecrossconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_) +end + depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; + pad = 0, stride = 1, flipkernel=0) where T = + depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode= flipkernel) + +depthwisecrossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; pad = 0, stride = 1) where T = - depthwiseconv2d!(y, x, w, padding = pad, stride = stride) + depthwiseconv!(y, x, w, pad = pad, stride = stride, flipkernel=1) ∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray = ∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)