Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix conv with groups when falling in direct backend #468

Merged
merged 10 commits into from
Feb 11, 2023
99 changes: 97 additions & 2 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,7 @@ end
# We always support a fallback, non-accelerated path, where we use the direct, but
# slow, implementations. These should not typically be used, hence the `@warn`,
# but let's go ahead and define them first:
for front_name in (:conv, :∇conv_data, :∇conv_filter,
:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
for front_name in (:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
@eval begin
function $(Symbol("$(front_name)!"))(
y::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
Expand All @@ -290,6 +289,102 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
end
end

# Grouped fallback for the forward pass: when a grouped convolution falls back
# to the slow direct backend, split the tensors along their channel dimensions
# and run the single-group direct kernel on each group in a parallel task.
for (front_name, backend) in (
    # This maps from public, front-facing name, to internal backend name
    :conv => :direct,
    # :∇conv_data => :direct,
    # :∇conv_filter => :direct,
)

    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        # im2col-accelerated function forwarding definition
        function $(Symbol("$(front_name)!"))(
                        out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                        in2::AbstractArray{T2,N}, cdims::C;
                        kwargs...) where {yT, T1, T2, N, C <: ConvDims}
            if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
                @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                             "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end

            # Partition input channels (dim 4 of x) and output channels (dim 5
            # of w) into one contiguous range per convolution group.
            x_cs = Iterators.partition(1:size(in1, 4),
                                       channels_in(cdims) ÷ groupcount(cdims))
            w_cs = Iterators.partition(1:size(in2, 5),
                                       channels_out(cdims) ÷ groupcount(cdims))
            # Per-group dims: a single group (G = 1) carrying the per-group
            # channel counts, passed to the single-group direct kernel.
            cdims2 = basetype(C)(cdims,
                                 G = 1,
                                 C_in = channels_in(cdims) ÷ groupcount(cdims),
                                 C_out = channels_out(cdims) ÷ groupcount(cdims))

            # Groups are independent, so each slice runs in its own task;
            # @sync waits for all spawned kernels before returning.
            Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
                x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
                w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
                # Output channels (dim 4 of y) line up with the weight's
                # output-channel range wc.
                y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
                Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
            end

            return out
        end
    end
end

# direct function forwarding definition
#
# Grouped fallback for the data gradient: splits dy, w, and dx along their
# channel dimensions and runs the single-group direct kernel per group.
# Defined by hand (outside the @eval loop above — note its commented-out
# :∇conv_data entry), so the public name is written out literally.
function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                     in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
    if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
        # Fix: the previous message interpolated `front_name`, which only
        # exists inside the @eval loop above; hitting this branch threw an
        # UndefVarError instead of warning. Use the literal name.
        @warn string("Slow fallback implementation invoked for ∇conv_data! ",
                     "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
    end

    # One contiguous channel range per group: dx uses dim 4 of `out`,
    # w uses dim 5 of `in2`, dy uses dim 4 of `in1`.
    dx_cs = Iterators.partition(1:size(out, 4),
                                channels_in(cdims) ÷ groupcount(cdims))
    w_cs = Iterators.partition(1:size(in2, 5),
                               channels_out(cdims) ÷ groupcount(cdims))
    dy_cs = Iterators.partition(1:size(in1, 4),
                                channels_out(cdims) ÷ groupcount(cdims))
    # Per-group dims: a single group (G = 1) with per-group channel counts.
    cdims2 = basetype(C)(cdims,
                         G = 1,
                         C_in = channels_in(cdims) ÷ groupcount(cdims),
                         C_out = channels_out(cdims) ÷ groupcount(cdims))

    # Groups are independent; compute each per-group gradient in its own task.
    Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
        dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
        dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
        wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
        Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
    end

    return out
end

# Grouped fallback for the filter gradient: splits x, dy, and dw along their
# channel dimensions and runs the single-group direct kernel per group.
# Defined by hand (outside the @eval loop above), so the name is literal.
function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                       in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
    if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
        # Fix: the previous message interpolated `front_name`, which is not in
        # scope here; hitting this branch threw an UndefVarError instead of
        # warning. Use the literal name.
        @warn string("Slow fallback implementation invoked for ∇conv_filter! ",
                     "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
    end
    # One contiguous channel range per group: dw uses dim 5 of `out`,
    # dy uses dim 4 of `in2`, x uses dim 4 of `in1`.
    dw_cs = Iterators.partition(1:size(out, 5),
                                channels_out(cdims) ÷ groupcount(cdims))
    dy_cs = Iterators.partition(1:size(in2, 4),
                                channels_out(cdims) ÷ groupcount(cdims))
    x_cs = Iterators.partition(1:size(in1, 4),
                               channels_in(cdims) ÷ groupcount(cdims))
    # Per-group dims: a single group (G = 1) with per-group channel counts.
    cdims2 = basetype(C)(cdims,
                         G = 1,
                         C_in = channels_in(cdims) ÷ groupcount(cdims),
                         C_out = channels_out(cdims) ÷ groupcount(cdims))

    # Groups are independent; compute each per-group gradient in its own task.
    Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
        x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
        dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
        # Fix: index dw by its own partition `wc` (was `yc`, which only worked
        # because both partitions happen to use the same chunk size).
        dw = @view out[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
        Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
    end

    return out
end
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
@eval @non_differentiable $Dims(::Any...)
end
Expand Down
11 changes: 11 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,17 @@ end
@test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (12, 7, 16, 10)
end

# https://github.com/FluxML/NNlib.jl/issues/369
@testset "conv_wrapper with groups - not equal types that trigger direct backend" begin
    x = rand(Float32, 10, 10, 32, 8)
    w = rand(Float64, 2, 2, 16, 4)
    g = 2
    # The Float32/Float64 mix forces the slow direct fallback; compare it to
    # the fast homogeneous-type path. Fix: three of these asserts previously
    # compared an expression to the identical expression (conv(x, w; …) on
    # both sides), so they could never fail — the right-hand side must use
    # Float32.(w) to exercise the accelerated backend, as the first one does.
    @test conv(x, w; groups=g) ≈ conv(x, Float32.(w); groups=g)
    @test conv(x, w; stride = (2, 2), pad = (2, 2), groups=g) ≈ conv(x, Float32.(w); stride = (2, 2), pad = (2, 2), groups=g)
    @test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g) ≈ conv(x, Float32.(w); stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g)
    @test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g) ≈ conv(x, Float32.(w); stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g)
end

@testset "depthwiseconv_wrapper" begin
x = rand(10, 10, 3, 10)
w = rand(2, 2, 3, 3)
Expand Down