diff --git a/src/conv.jl b/src/conv.jl index 9c6c51ae3..32db7bd99 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -24,10 +24,16 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x))) padtuple(x::Tuple,p::Tuple) = p padtuple(x::AbstractArray,p) = padtuple(size(x),p) -function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, mode = 0) where A<:AbstractArray +function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray pad_, stride_ = padtuple(x, pad), padtuple(x, stride) conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), - x, w, pad = pad_, stride = stride_, dilation = dilation, mode = mode) + x, w, pad = pad_, stride = stride_, dilation = dilation) +end + +function crossconv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray + pad_, stride_ = padtuple(x, pad), padtuple(x, stride) + crossconv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), + x, w, pad = pad_, stride = stride_, dilation = dilation) end ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray = @@ -39,12 +45,19 @@ end # N-D dispatch function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3}; - pad = 0, stride = 1, dilation = 1, mode=0) where T + pad = 0, stride = 1, dilation = 1) where T args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w)) - conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), mode=mode) + conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1)) return y end +function crossconv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3}; + pad = 0, stride = 1, dilation = 1) where T + args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w)) + crossconv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1)) +return y +end + function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3}; pad = 0, stride = 1, dilation = 1) where T @@ -62,8 +75,12 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3}, end conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; - pad = 0, stride = 1, dilation = 1, mode=0) where T = - conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=mode) + pad = 0, stride = 1, dilation = 1) where T = + conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation) + +crossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; + pad = 0, stride = 1, dilation = 1) where T = + conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1) ∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; pad = 0, stride = 1, dilation = 1) where T = @@ -74,8 +91,12 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation) conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}; - pad = 0, stride = 1, dilation = 1, mode=0) where T = - conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=mode) + pad = 0, stride = 1, dilation = 1) where T = + conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation) + +crossconv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}; + pad = 0, stride = 1, dilation = 1) where T = + conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1) ∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}; pad = 0, stride = 1, dilation = 1) where T =