Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nonzero beta + flipkernel bugfix #519

Merged
merged 4 commits into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/impl/conv_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
dy = transpose_swapbatch(predilate(dy, stride(cdims)))
ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims),
stride=dilation(cdims))
conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta)
if flipkernel(cdims)
dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :]
dw_ = if flipkernel(cdims)
view(dw, reverse(axes(dw, 1)), reverse(axes(dw, 2)), reverse(axes(dw, 3)), :, :)
else
dw
end
conv_direct!(dw_, dy, x, ctdims; alpha=alpha, beta=beta)
return dw
end
95 changes: 95 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,49 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
end
end

# Test all in-place implementations/interfaces
convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,]
NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack!)
for conv! in convs
if NNlib.is_nnpack_available()
if conv! == NNlib.conv_nnpack! && !NNlib.nnpack_supported_operation(DenseConvDims(x, w))
continue
end
end
α, β = 2e0, -1e0

@testset "$(conv!)" begin
# First, your basic convolution with no parameters
cdims = DenseConvDims(x, w)
y0 = rand(rng, -9e0:9e0, size(y_plain)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7)

# Next, test convolution on views and alternate datatypes:
@test isapprox(ddims(conv!(copy(y0), view(x, repeat([:], ndims(x))...), w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7)
@test isapprox(ddims(conv!(Float32.(copy(y0)), Float32.(x), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), Float32.(α*y_plain + β*y0), rtol = 1.0e-7)

# Next, introduce stride:
cdims = DenseConvDims(x, w; stride=2)
y0 = rand(rng, -9e0:9e0, size(y_stride)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_stride + β*y0, rtol = 1.0e-7)

# Next, introduce dilation:
cdims = DenseConvDims(x, w; dilation=2)
y0 = rand(rng, -9e0:9e0, size(y_dil)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_dil + β*y0, rtol = 1.0e-7)

# Next, introduce padding:
cdims = DenseConvDims(x, w; padding=1)
y0 = rand(rng, -9e0:9e0, size(y_pad)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_pad + β*y0, rtol = 1.0e-7)

# Next, test crosscor/conv with a flipped kernel
cdims = DenseConvDims(x, w; flipkernel=true)
y0 = rand(rng, -9e0:9e0, size(y_flip)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_flip + β*y0, rtol = 1.0e-7)
end
end

# Test all implementations/interfaces
for (∇conv_filter, ∇conv_data) in (
(NNlib.∇conv_filter, NNlib.∇conv_data),
Expand Down Expand Up @@ -355,6 +398,58 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
@test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_flip, rtol = 1.0e-7)
end
end

# Test all in-place implementations/interfaces
for (∇conv_filter!, ∇conv_data!) in (
(NNlib.∇conv_filter!, NNlib.∇conv_data!),
(NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!),
(NNlib.∇conv_filter_direct!, NNlib.∇conv_data_direct!),
)
#α, β = 2*rand(rng) - 1, 2*rand(rng) - 1
α, β = 2e0, -1e0
flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!)

@testset "$(∇conv_filter!)/$(∇conv_data!)" begin
# First, your basic convolution with no parameters
cdims = DenseConvDims(x, w)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag

# Next, test convolution on views and alternate datatypes:
@test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag

@test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag

# Next, introduce stride:
cdims = DenseConvDims(x, w; stride=2)
dy = NNlib.conv(x, w, cdims)
flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag

# Next, introduce dilation:
cdims = DenseConvDims(x, w; dilation=2)
dy = NNlib.conv(x, w, cdims)
flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_

# Next, introduce padding:
cdims = DenseConvDims(x, w; padding=1)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag

# Next, test crosscor/conv with a flipped kernel
cdims = DenseConvDims(x, w; flipkernel=true)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag
end
end
end
end
end
Expand Down