diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index f2b6ff60e..9f12f1dc9 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -191,9 +191,11 @@ function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, dy = transpose_swapbatch(predilate(dy, stride(cdims))) ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims), stride=dilation(cdims)) - conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta) - if flipkernel(cdims) - dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :] + dw_ = if flipkernel(cdims) + view(dw, reverse(axes(dw, 1)), reverse(axes(dw, 2)), reverse(axes(dw, 3)), :, :) + else + dw end + conv_direct!(dw_, dy, x, ctdims; alpha=alpha, beta=beta) return dw end diff --git a/test/conv.jl b/test/conv.jl index 3037930d7..8edc4bf24 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -310,6 +310,49 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) end end + # Test all in-place implementations/interfaces + convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,] + NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack!) + for conv! in convs + if NNlib.is_nnpack_available() + if conv! == NNlib.conv_nnpack! && !NNlib.nnpack_supported_operation(DenseConvDims(x, w)) + continue + end + end + α, β = 2e0, -1e0 + + @testset "$(conv!)" begin + # First, your basic convolution with no parameters + cdims = DenseConvDims(x, w) + y0 = rand(rng, -9e0:9e0, size(y_plain)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7) + + # Next, test convolution on views and alternate datatypes: + @test isapprox(ddims(conv!(copy(y0), view(x, repeat([:], ndims(x))...), w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7) + @test isapprox(ddims(conv!(Float32.(copy(y0)), Float32.(x), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), Float32.(α*y_plain + β*y0), rtol = 1.0e-7) + + # Next, introduce stride: + cdims = DenseConvDims(x, w; stride=2) + y0 = rand(rng, -9e0:9e0, size(y_stride)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_stride + β*y0, rtol = 1.0e-7) + + # Next, introduce dilation: + cdims = DenseConvDims(x, w; dilation=2) + y0 = rand(rng, -9e0:9e0, size(y_dil)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_dil + β*y0, rtol = 1.0e-7) + + # Next, introduce padding: + cdims = DenseConvDims(x, w; padding=1) + y0 = rand(rng, -9e0:9e0, size(y_pad)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_pad + β*y0, rtol = 1.0e-7) + + # Next, test crosscor/conv with a flipped kernel + cdims = DenseConvDims(x, w; flipkernel=true) + y0 = rand(rng, -9e0:9e0, size(y_flip)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_flip + β*y0, rtol = 1.0e-7) + end + end + # Test all implementations/interfaces for (∇conv_filter, ∇conv_data) in ( (NNlib.∇conv_filter, NNlib.∇conv_data), @@ -355,6 +398,58 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_flip, rtol = 1.0e-7) end end + + # Test all in-place implementations/interfaces + for (∇conv_filter!, ∇conv_data!) in ( + (NNlib.∇conv_filter!, NNlib.∇conv_data!), + (NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!), + (NNlib.∇conv_filter_direct!, NNlib.∇conv_data_direct!), + ) + #α, β = 2*rand(rng) - 1, 2*rand(rng) - 1 + α, β = 2e0, -1e0 + flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!) + + @testset "$(∇conv_filter!)/$(∇conv_data!)" begin + # First, your basic convolution with no parameters + cdims = DenseConvDims(x, w) + dy = NNlib.conv(x, w, cdims) + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag + + # Next, test convolution on views and alternate datatypes: + @test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag + + @test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag + + # Next, introduce stride: + cdims = DenseConvDims(x, w; stride=2) + dy = NNlib.conv(x, w, cdims) + flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3) + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_ + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag + + # Next, introduce dilation: + cdims = DenseConvDims(x, w; dilation=2) + dy = NNlib.conv(x, w, cdims) + flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3 + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_ + + # Next, introduce padding: + cdims = DenseConvDims(x, w; padding=1) + dy = NNlib.conv(x, w, cdims) + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag + + # Next, test crosscor/conv with a flipped kernel + cdims = DenseConvDims(x, w; flipkernel=true) + dy = NNlib.conv(x, w, cdims) + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag + end + end end end end