Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix conv with groups when falling in direct backend #468

Merged
merged 10 commits into from
Feb 11, 2023
200 changes: 121 additions & 79 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,31 +166,41 @@ end

# First, we will define mappings from the generic API names to our accelerated backend
# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
# im2col + GEMM. Do so in a loop, here:
# im2col + GEMM.
# But we always support a fallback, non-accelerated path, where we use the direct, but
# slow, implementations. These should not typically be used, hence the `@warn`.

# These are the GEMM types we will accelerate with `im2col`
const G = Union{[x[2] for x in gemm_datatype_mappings]...}

for (front_name, backend) in (
# This maps from public, front-facing name, to internal backend name
:conv => :im2col,
)

for (front_name, backend, signature) in (
# This maps from the public, front-facing name to the internal backend name, together with the function signature and the where clause:
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (parametric Types)))
(:conv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
(:conv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
@eval begin
# im2col-accelerated function forwarding definition

function $(Symbol("$(front_name)!"))(
out::AbstractArray{T,5}, in1::AbstractArray{T,5},
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
cdims::$(signature[4]);
kwargs...) where {$(signature[5]...)}
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
end

x_cs = Iterators.partition(1:size(in1, 4),
channels_in(cdims) ÷ groupcount(cdims))
channels_in(cdims) ÷ groupcount(cdims))
w_cs = Iterators.partition(1:size(in2, 5),
channels_out(cdims) ÷ groupcount(cdims))
channels_out(cdims) ÷ groupcount(cdims))
cdims2 = basetype(C)(cdims,
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))

Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
Expand All @@ -205,87 +215,119 @@ for (front_name, backend) in (
end

# im2col-accelerated function forwarding definition
function ∇conv_data!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}

dx_cs = Iterators.partition(1:size(out, 4),
channels_in(cdims) ÷ groupcount(cdims))
w_cs = Iterators.partition(1:size(in2, 5),
channels_out(cdims) ÷ groupcount(cdims))
dy_cs = Iterators.partition(1:size(in1, 4),
channels_out(cdims) ÷ groupcount(cdims))
cdims2 = basetype(C)(cdims,
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))

Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
Threads.@spawn ∇conv_data_im2col!(dxv, dyv, wv, cdims2; kwargs...)
end
for (front_name, backend, signature) in (
# This maps from the public, front-facing name to the internal backend name, together with the function signature and the where clause:
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (parametric Types)))
(:∇conv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
(:∇conv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
@eval begin
function $(Symbol("$(front_name)!"))(
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
cdims::$(signature[4]);
kwargs...) where {$(signature[5]...)}
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
end

return out
end

function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}
dw_cs = Iterators.partition(1:size(out, 5),
channels_out(cdims) ÷ groupcount(cdims))
dy_cs = Iterators.partition(1:size(in2, 4),
channels_out(cdims) ÷ groupcount(cdims))
x_cs = Iterators.partition(1:size(in1, 4),
channels_in(cdims) ÷ groupcount(cdims))
cdims2 = basetype(C)(cdims,
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))

Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
Threads.@spawn ∇conv_filter_im2col!(dw, x, dy, cdims2; kwargs...)
end
dx_cs = Iterators.partition(1:size(out, 4),
channels_in(cdims) ÷ groupcount(cdims))
w_cs = Iterators.partition(1:size(in2, 5),
channels_out(cdims) ÷ groupcount(cdims))
dy_cs = Iterators.partition(1:size(in1, 4),
channels_out(cdims) ÷ groupcount(cdims))
cdims2 = basetype(C)(cdims,
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))

Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...)
end

return out
return out
end
end
end


for (front_name, backend) in (
# This maps from public, front-facing name, to internal backend name
:depthwiseconv => :im2col,
:∇depthwiseconv_data => :im2col,
:∇depthwiseconv_filter => :im2col,
)

for (front_name, backend, signature) in (
# This maps from the public, front-facing name to the internal backend name, together with the function signature and the where clause:
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (parametric Types)))
(:∇conv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
(:∇conv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
@eval begin
# im2col-accelerated function forwarding definition
function $(Symbol("$(front_name)!"))(
out::AbstractArray{T,5}, in1::AbstractArray{T,5},
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
cdims::$(signature[4]);
kwargs...) where {$(signature[5]...)}
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
end

dw_cs = Iterators.partition(1:size(out, 5),
channels_out(cdims) ÷ groupcount(cdims))
dy_cs = Iterators.partition(1:size(in2, 4),
channels_out(cdims) ÷ groupcount(cdims))
x_cs = Iterators.partition(1:size(in1, 4),
channels_in(cdims) ÷ groupcount(cdims))
cdims2 = basetype(C)(cdims,
G = 1,
C_in = channels_in(cdims) ÷ groupcount(cdims),
C_out = channels_out(cdims) ÷ groupcount(cdims))

Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...)
end

return out
end
end
end

# We always support a fallback, non-accelerated path, where we use the direct, but
# slow, implementations. These should not typically be used, hence the `@warn`,
# but let's go ahead and define them first:
for front_name in (:conv, :∇conv_data, :∇conv_filter,
:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)

for (front_name, backend, signature) in (
# This maps from the public, front-facing name to the internal backend name, together with the function signature and the where clause:
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (parametric Types)))
(:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
(:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),

(:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
(:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),

(:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
(:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)

# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
@eval begin
# im2col-accelerated function forwarding definition
function $(Symbol("$(front_name)!"))(
y::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
in2::AbstractArray{T2,N}, cdims::ConvDims;
kwargs...) where {yT, T1, T2, N}
if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
cdims::$(signature[4]);
kwargs...) where {$(signature[5]...)}
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
end
$(Symbol("$(front_name)_direct!"))(y, in1, in2, cdims; kwargs...)
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
end
end
end
Expand Down
11 changes: 11 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,17 @@ end
@test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (12, 7, 16, 10)
end

# https://github.com/FluxML/NNlib.jl/issues/369
@testset "conv_wrapper with groups - not equal types that trigger direct backend" begin
    # Regression test for grouped convolution with mismatched element types.
    # `x` is Float32 while `w` is Float64, so the homogeneous-type (im2col) method
    # does not apply and the call is expected to take the direct fallback path.
    x = rand(Float32, 10, 10, 32, 8)
    w = rand(Float64, 2, 2, 16, 4)
    g = 2

    # Reference weights with the same element type as `x`. Every assertion below
    # compares the mixed-type result against this homogeneous reference; comparing
    # `conv(x, w; ...)` with itself would pass trivially and verify nothing.
    w32 = Float32.(w)

    @test conv(x, w; groups=g) ≈ conv(x, w32; groups=g)
    @test conv(x, w; stride = (2, 2), pad = (2, 2), groups=g) ≈ conv(x, w32; stride = (2, 2), pad = (2, 2), groups=g)
    @test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g) ≈ conv(x, w32; stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g)
    @test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g) ≈ conv(x, w32; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g)
end

@testset "depthwiseconv_wrapper" begin
x = rand(10, 10, 3, 10)
w = rand(2, 2, 3, 3)
Expand Down