Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] mode parameter for Convolution/cross-convolution #71

Merged
merged 5 commits into from
Oct 8, 2018
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,24 @@ function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractAr
x, w, pad = pad_, stride = stride_, dilation = dilation)
end

function crossconv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
crossconv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
x, w, pad = pad_, stride = stride_, dilation = dilation)
end

∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)

∇crossconv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
∇crossconv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)

∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)

∇crossconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
∇crossconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)

# N-D dispatch

function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
Expand All @@ -45,6 +57,13 @@ function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,
return y
end

function crossconv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
crossconv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's crosscor. It should also be a thin wrapper that doesn't need as much duplication as you've added here; if you just call conv! here with the right mode, for example, you won't need specific 2D and 3D wrappers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But conv! doesn't take the mode parameter in its arguments at the moment, (and you also advised me not to expose mode in conv!).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be ok to have it in conv! (perhaps as flipkernel) but just not document it. That makes it a bit easier to support crosscor! without a lot of duplication. You can also then avoid the gradient wrappers, since we can just support flipkernel in Flux.

return y
end

function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
Expand All @@ -53,6 +72,14 @@ function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
return dw
end

function ∇crossconv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x, w))
∇crossconv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
return dw
end

function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
Expand All @@ -61,31 +88,64 @@ function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
return dx
end

function ∇crossconv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
x::AbstractArray{T,3}, w::AbstractArray{T,3};
pad = 0, stride = 1, dilation = 1) where T
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, x, w))
∇crossconv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1))
return dx
end

conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)

crossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1)

∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)

∇crossconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)

∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)

∇crossconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1, dilation = 1) where T =
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)

conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)

crossconv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=1)

∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)

∇crossconv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)

∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)

# Depthwise Conv
∇crossconv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
pad = 0, stride = 1, dilation = 1) where T =
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=1)


# Depthwise Conv

function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
Expand All @@ -96,24 +156,47 @@ function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
end

function depthwisecrossconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
depthwisecrossconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
end

depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)

depthwisecrossconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode=1)

∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)

∇depthwisecrossconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
∇depthwisecrossconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)

∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)

∇depthwisecrossconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
∇depthwisecrossconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)

∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride)

∇depthwisecrossconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, mode=1)

∇depthwiseconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride)

∇depthwisecrossconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
pad = 0, stride = 1) where T =
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, mode=1)

# Pooling

function pdims(dims::Dims{N}, window, padding, stride) where N
Expand Down