From ef2505da1451e7e4648883aa6e0a80a6894588d4 Mon Sep 17 00:00:00 2001 From: Nikola Date: Tue, 4 Jul 2023 19:04:49 -0400 Subject: [PATCH 1/3] nonzero beta + flipkernel bugfix --- src/impl/conv_direct.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index f2b6ff60e..9950f5295 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -191,6 +191,9 @@ function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, dy = transpose_swapbatch(predilate(dy, stride(cdims))) ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims), stride=dilation(cdims)) + if beta!=0 && flipkernel(cdims) + dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :] + end conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta) if flipkernel(cdims) dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :] From aceb24d52f1abc7c06dac34e562989d601695fdb Mon Sep 17 00:00:00 2001 From: Nikola Date: Thu, 6 Jul 2023 01:35:24 -0400 Subject: [PATCH 2/3] conv! alpha/beta tests added, conv_filter_direct flipkernel with view --- src/impl/conv_direct.jl | 11 +++-- test/conv.jl | 95 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 6 deletions(-) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 9950f5295..826ef9e35 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -191,12 +191,11 @@ function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, dy = transpose_swapbatch(predilate(dy, stride(cdims))) ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims), stride=dilation(cdims)) - if beta!=0 && flipkernel(cdims) - dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :] - end - conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta) - if flipkernel(cdims) - dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :] + dw_ = if flipkernel(cdims) + view(dw, size(dw,1):-1:1, size(dw,2):-1:1, size(dw,3):-1:1, :, :) + else + dw end + conv_direct!(dw_, dy, x, ctdims; alpha=alpha, beta=beta) return dw end diff --git a/test/conv.jl b/test/conv.jl index 3037930d7..8edc4bf24 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -310,6 +310,49 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) end end + # Test all in-place implementations/interfaces + convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,] + NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack!) + for conv! in convs + if NNlib.is_nnpack_available() + if conv! == NNlib.conv_nnpack! && !NNlib.nnpack_supported_operation(DenseConvDims(x, w)) + continue + end + end + α, β = 2e0, -1e0 + + @testset "$(conv!)" begin + # First, your basic convolution with no parameters + cdims = DenseConvDims(x, w) + y0 = rand(rng, -9e0:9e0, size(y_plain)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7) + + # Next, test convolution on views and alternate datatypes: + @test isapprox(ddims(conv!(copy(y0), view(x, repeat([:], ndims(x))...), w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7) + @test isapprox(ddims(conv!(Float32.(copy(y0)), Float32.(x), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), Float32.(α*y_plain + β*y0), rtol = 1.0e-7) + + # Next, introduce stride: + cdims = DenseConvDims(x, w; stride=2) + y0 = rand(rng, -9e0:9e0, size(y_stride)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_stride + β*y0, rtol = 1.0e-7) + + # Next, introduce dilation: + cdims = DenseConvDims(x, w; dilation=2) + y0 = rand(rng, -9e0:9e0, size(y_dil)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_dil + β*y0, rtol = 1.0e-7) + + # Next, introduce padding: + cdims = DenseConvDims(x, w; padding=1) + y0 = rand(rng, -9e0:9e0, size(y_pad)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_pad + β*y0, rtol = 1.0e-7) + + # Next, test crosscor/conv with a flipped kernel + cdims = DenseConvDims(x, w; flipkernel=true) + y0 = rand(rng, -9e0:9e0, size(y_flip)..., 1, 1) + @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_flip + β*y0, rtol = 1.0e-7) + end + end + # Test all implementations/interfaces for (∇conv_filter, ∇conv_data) in ( (NNlib.∇conv_filter, NNlib.∇conv_data), @@ -355,6 +398,58 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_flip, rtol = 1.0e-7) end end + + # Test all in-place implementations/interfaces + for (∇conv_filter!, ∇conv_data!) in ( + (NNlib.∇conv_filter!, NNlib.∇conv_data!), + (NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!), + (NNlib.∇conv_filter_direct!, NNlib.∇conv_data_direct!), + ) + #α, β = 2*rand(rng) - 1, 2*rand(rng) - 1 + α, β = 2e0, -1e0 + flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!) + + @testset "$(∇conv_filter!)/$(∇conv_data!)" begin + # First, your basic convolution with no parameters + cdims = DenseConvDims(x, w) + dy = NNlib.conv(x, w, cdims) + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag + + # Next, test convolution on views and alternate datatypes: + @test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag + + @test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag + + # Next, introduce stride: + cdims = DenseConvDims(x, w; stride=2) + dy = NNlib.conv(x, w, cdims) + flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3) + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_ + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag + + # Next, introduce dilation: + cdims = DenseConvDims(x, w; dilation=2) + dy = NNlib.conv(x, w, cdims) + flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3 + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_ + + # Next, introduce padding: + cdims = DenseConvDims(x, w; padding=1) + dy = NNlib.conv(x, w, cdims) + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag + + # Next, test crosscor/conv with a flipped kernel + cdims = DenseConvDims(x, w; flipkernel=true) + dy = NNlib.conv(x, w, cdims) + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag + end + end end end end From 8f57824eb4953c5b30e47f2110a33ab14e47346b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nikola=20Janju=C5=A1evi=C4=87?= Date: Sat, 8 Jul 2023 15:53:13 -0400 Subject: [PATCH 3/3] Update src/impl/conv_direct.jl Co-authored-by: Brian Chen --- src/impl/conv_direct.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 826ef9e35..9f12f1dc9 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -192,7 +192,7 @@ function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims), stride=dilation(cdims)) dw_ = if flipkernel(cdims) - view(dw, size(dw,1):-1:1, size(dw,2):-1:1, size(dw,3):-1:1, :, :) + view(dw, reverse(axes(dw, 1)), reverse(axes(dw, 2)), reverse(axes(dw, 3)), :, :) else dw end