Skip to content

Commit

Permalink
Adapted to previous changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush1999 committed Oct 5, 2018
1 parent 25dba3d commit ddb7326
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@ padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
padtuple(x::Tuple,p::Tuple) = p
padtuple(x::AbstractArray,p) = padtuple(size(x),p)

function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, mode = 0) where A<:AbstractArray
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
x, w, pad = pad_, stride = stride_, dilation = dilation, mode = mode)
x, w, pad = pad_, stride = stride_, dilation = dilation)
end

function crossconv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
crossconv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
x, w, pad = pad_, stride = stride_, dilation = dilation)
end

∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
Expand All @@ -39,12 +45,19 @@ end
# N-D dispatch

function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1, mode=0) where T
pad = 0, stride = 1, dilation = 1) where T
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), mode=mode)
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
return y
end

function crossconv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
crossconv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
return y
end

function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
Expand All @@ -62,8 +75,12 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
end

conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1, mode=0) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=mode)
pad = 0, stride = 1, dilation = 1) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)

crossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1)

∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
Expand All @@ -74,8 +91,12 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)

conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1, mode=0) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=mode)
pad = 0, stride = 1, dilation = 1) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)

crossconv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1)

∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
Expand Down

0 comments on commit ddb7326

Please sign in to comment.