Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix conv with groups when falling back to the direct backend #468

Merged
merged 10 commits into from
Feb 11, 2023
60 changes: 58 additions & 2 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ end
for (front_name, backend) in (
# This maps from public, front-facing name, to internal backend name
:conv => :direct,
:∇conv_data => :direct,
:∇conv_filter => :direct,
# :∇conv_data => :direct,
# :∇conv_filter => :direct,
)

# We only define 3d conv primitives; we reshape lower down to get 1d and 2d convolution
Expand Down Expand Up @@ -329,6 +329,62 @@ for (front_name, backend) in (
end
end

# Forwarding definitions that route the gradient entry points to the direct backend, one group at a time
# ∇conv_data!(out, in1, in2, cdims; kwargs...)
#
# Fallback that forwards to `∇conv_data_direct!`, one group at a time.
#
# - `out` is written through views sliced along dimension 4 (passed as the first
#   argument, `dx`, of the direct kernel).
# - `in1` is sliced along dimension 4 (passed as `dy`).
# - `in2` is sliced along dimension 5 (passed as the weights `w`).
# - `cdims` supplies `channels_in`, `channels_out` and `groupcount`; a copy with
#   `G = 1` and per-group channel counts is handed to each per-group call.
# - `kwargs` are forwarded unchanged to `∇conv_data_direct!`.
#
# Returns `out`.
#
# NOTE(review): the 5-element index tuples assume `N == 5` — confirm that callers
# always reshape to 5-d before reaching this method.
function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                     in2::AbstractArray{T2,N}, cdims::C;
                     kwargs...) where {yT, T1, T2, N, C <: ConvDims}
    if yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
        # The function name is spelled out literally: this method is defined outside
        # the `front_name` loop, so that loop variable is not in scope here.
        @warn string("Slow fallback implementation invoked for ∇conv_data! ",
                     "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
    end

    # Channel counts handled by a single group.
    groups = groupcount(cdims)
    cin_per_group = channels_in(cdims) ÷ groups
    cout_per_group = channels_out(cdims) ÷ groups

    # Per-group index ranges along the sliced dimension of each array.
    dx_cs = Iterators.partition(1:size(out, 4), cin_per_group)
    w_cs = Iterators.partition(1:size(in2, 5), cout_per_group)
    dy_cs = Iterators.partition(1:size(in1, 4), cout_per_group)

    # Dims describing one group in isolation (no grouping left).
    cdims2 = basetype(C)(cdims,
                         G = 1,
                         C_in = cin_per_group,
                         C_out = cout_per_group)

    # One task per group; `@sync` waits for all of them before returning.
    Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
        dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
        dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
        wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
        Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
    end

    return out
end

# ∇conv_filter!(out, in1, in2, cdims; kwargs...)
#
# Fallback that forwards to `∇conv_filter_direct!`, one group at a time.
#
# - `out` is written through views sliced along dimension 5 (passed as the first
#   argument, `dw`, of the direct kernel).
# - `in1` is sliced along dimension 4 (passed as `x`).
# - `in2` is sliced along dimension 4 (passed as `dy`).
# - `cdims` supplies `channels_in`, `channels_out` and `groupcount`; a copy with
#   `G = 1` and per-group channel counts is handed to each per-group call.
# - `kwargs` are forwarded unchanged to `∇conv_filter_direct!`.
#
# Returns `out`.
#
# NOTE(review): the 5-element index tuples assume `N == 5` — confirm that callers
# always reshape to 5-d before reaching this method.
function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                       in2::AbstractArray{T2,N}, cdims::C;
                       kwargs...) where {yT, T1, T2, N, C <: ConvDims}
    if yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
        # The function name is spelled out literally: this method is defined outside
        # the `front_name` loop, so that loop variable is not in scope here.
        @warn string("Slow fallback implementation invoked for ∇conv_filter! ",
                     "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
    end

    # Channel counts handled by a single group.
    groups = groupcount(cdims)
    cin_per_group = channels_in(cdims) ÷ groups
    cout_per_group = channels_out(cdims) ÷ groups

    # Per-group index ranges along the sliced dimension of each array.
    dw_cs = Iterators.partition(1:size(out, 5), cout_per_group)
    dy_cs = Iterators.partition(1:size(in2, 4), cout_per_group)
    x_cs = Iterators.partition(1:size(in1, 4), cin_per_group)

    # Dims describing one group in isolation (no grouping left).
    cdims2 = basetype(C)(cdims,
                         G = 1,
                         C_in = cin_per_group,
                         C_out = cout_per_group)

    # One task per group; `@sync` waits for all of them before returning.
    Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
        x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
        dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
        # Slice `out` with its own partition (`wc`), not the `dy` range.
        dw = @view out[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
        Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
    end

    return out
end
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

# Mark the constructors of each dims descriptor type as non-differentiable,
# for any argument list.
for dims_type in (:DenseConvDims, :DepthwiseConvDims, :PoolDims)
    @eval @non_differentiable $dims_type(::Any...)
end
Expand Down