diff --git a/src/Flux.jl b/src/Flux.jl
index eccdd6a7e5..a041a69a8d 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -6,7 +6,7 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

-export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        params, mapleaves, cpu, gpu, f32, f64

diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 99a04890ba..76923f3f0b 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -1,4 +1,4 @@
-using NNlib: conv, ∇conv_data, depthwiseconv
+using NNlib: conv, ∇conv_data, depthwiseconv, DenseConvDims

 @generated sub2(::Val{N}) where N = :(Val($(N-2)))

@@ -171,6 +171,64 @@ end
 (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))

+"""
+    CrossCor(size, in=>out)
+    CrossCor(size, in=>out, relu)
+
+Standard cross correlation layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+
+Data should be stored in WHCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+"""
+struct CrossCor{N,F,A,V}
+  σ::F
+  weight::A
+  bias::V
+  stride::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
+end
+
+CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+         stride = 1, pad = 0, dilation = 1) where {T,N} =
+  CrossCor(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
+
+CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+         init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
+  CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
+           stride = stride, pad = pad, dilation = dilation)
+
+@treelike CrossCor
+
+function crosscor(x, w, ddims::DenseConvDims)
+  ddims = DenseConvDims(ddims, F=true)
+  return conv(x, w, ddims)
+end
+
+function (c::CrossCor)(x)
+  # TODO: breaks gpu broadcast :(
+  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  ddims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
+  σ.(crosscor(x, c.weight, ddims) .+ b)
+end
+
+function Base.show(io::IO, l::CrossCor)
+  print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
+(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
+
 """
     MaxPool(k)

@@ -213,4 +271,4 @@ MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =

 function Base.show(io::IO, m::MeanPool)
   print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
-end
+end
\ No newline at end of file
diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl
index 1748ed5e07..2b80c0a144 100644
--- a/test/cuda/cuda.jl
+++ b/test/cuda/cuda.jl
@@ -36,6 +36,10 @@
 c = gpu(Conv((2,2),3=>4))
 l = c(gpu(rand(10,10,3,2)))
 Flux.back!(sum(l))
+c = gpu(CrossCor((2,2),3=>4))
+l = c(gpu(rand(10,10,3,2)))
+Flux.back!(sum(l))
+
 end

 if CuArrays.libcudnn != nothing
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index d28b099a68..247470fc24 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -43,3 +43,16 @@
   @test size(m4(r), 3) == 3
 end
+
+@testset "CrossCor" begin
+  r = rand(28, 28, 1, 5)
+  m = Chain(CrossCor((2, 2), 1=>16, relu),
+            MaxPool((2,2)),
+            CrossCor((2, 2), 16=>8, relu),
+            MaxPool((2,2)),
+            x -> reshape(x, :, size(x, 4)),
+            Dense(288, 10), softmax)
+
+  @test size(m(r)) == (10, 5)
+
+end
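
For reference, a minimal usage sketch of the new layer, mirroring the docstring and the Chain used in the test above. The names `layer`, `x` and `y` are illustrative and not part of the diff:

using Flux

# 2×2 window, 1 input channel, 16 output channels, relu activation
layer = CrossCor((2, 2), 1=>16, relu)

# WHCN order: a batch of five 28×28 single-channel images
x = rand(Float32, 28, 28, 1, 5)

y = layer(x)
size(y)              # (27, 27, 16, 5) with the default pad = 0, stride = 1, dilation = 1

# gradients flow through weight and bias, as in the GPU test above
Flux.back!(sum(y))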
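One way to sanity-check the semantics, based on my reading of the `F=true` flag in `crosscor` above and assuming `Conv` still performs kernel-flipping convolution at this point (which is the motivation for adding the layer): cross-correlation is convolution with the kernel reversed along its spatial dimensions, so a `CrossCor` built from raw weights `w` should match a `Conv` built from a spatially reversed copy of `w`:

using Flux

w = rand(Float32, 2, 2, 1, 1)      # kw × kh × in × out
b = zeros(Float32, 1)
x = rand(Float32, 10, 10, 1, 1)

# expected to hold if CrossCor is Conv with a spatially flipped kernel
CrossCor(w, b)(x) ≈ Conv(w[end:-1:1, end:-1:1, :, :], b)(x)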