Bugged fallback in grouped Convs #369

Closed
SimonCoste opened this issue Dec 27, 2021 · 3 comments · Fixed by #468

Comments

SimonCoste commented on Dec 27, 2021

Hi,

I defined a grouped convolution in Flux with C = Conv((1,1), 2=>2, groups=2). When I feed this layer Float64 arrays instead of Float32 ones, e.g. C(rand(10,10,2,1)), I first get a "Slow fallback implementation" warning and then an AssertionError: DimensionMismatch (see the stacktrace below).

This error should not be raised, and it is very misleading. The dimensions are fine, and the message itself reports matching counts (2 vs. 2), so this is not a dimension mismatch at all. The problem is apparently linked to the element types: consistent with the warning, the error disappears when I use C(rand(Float32, 10,10,2,1)).

Regular (non-grouped) convolutions do not show this kind of error.

```julia
julia> C = Conv((1,1), 2=>2, groups=2)
Conv((1, 1), 1 => 2)  # 4 parameters

julia> C(rand(10,10,2,1))
┌ Warning: Slow fallback implementation invoked for conv!  You probably don't want this; check your datatypes.
│   yT = Float64
│   T1 = Float64
│   T2 = Float32
└ @ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:291
ERROR: AssertionError: DimensionMismatch("Data input channel count (2 vs. 2)")
Stacktrace:
  [1] check_dims(x::NTuple{5, Int64}, w::NTuple{5, Int64}, y::NTuple{5, Int64}, cdims::DenseConvDims{3, (1, 1, 1), 2, 2, 2, (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), false})
    @ NNlib ~/.julia/packages/NNlib/P9BhZ/src/dim_helpers/DenseConvDims.jl:73
  [2] conv_direct!(y::Array{Float64, 5}, x::Array{Float64, 5}, w::Array{Float32, 5}, cdims::DenseConvDims{3, (1, 1, 1), 2, 2, 2, (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), false}; alpha::Float64, beta::Bool)
    @ NNlib ~/.julia/packages/NNlib/P9BhZ/src/impl/conv_direct.jl:51
  [3] conv_direct!
    @ ~/.julia/packages/NNlib/P9BhZ/src/impl/conv_direct.jl:51 [inlined]
  [4] conv!(y::Array{Float64, 5}, in1::Array{Float64, 5}, in2::Array{Float32, 5}, cdims::DenseConvDims{3, (1, 1, 1), 2, 2, 2, (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:293
  [5] conv!(y::Array{Float64, 5}, in1::Array{Float64, 5}, in2::Array{Float32, 5}, cdims::DenseConvDims{3, (1, 1, 1), 2, 2, 2, (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), false})
    @ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:291
  [6] conv!(y::Array{Float64, 4}, x::Array{Float64, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, (1, 1), 2, 2, 2, (1, 1), (0, 0, 0, 0), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:151
  [7] conv!
    @ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:151 [inlined]
  [8] conv(x::Array{Float64, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, (1, 1), 2, 2, 2, (1, 1), (0, 0, 0, 0), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:91
  [9] conv(x::Array{Float64, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, (1, 1), 2, 2, 2, (1, 1), (0, 0, 0, 0), (1, 1), false})
    @ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:89
 [10] (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(x::Array{Float64, 4})
    @ Flux ~/.julia/packages/Flux/ZnXxS/src/layers/conv.jl:163
 [11] top-level scope
    @ REPL[5]:1
 [12] top-level scope
    @ ~/.julia/packages/CUDA/YpW0k/src/initialization.jl:52
```
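For readers unfamiliar with grouped convolutions, here is a naive reference implementation of the 1×1 case in plain Julia. It is an illustrative sketch, not NNlib's code: grouped_conv1x1 is a hypothetical helper, and it assumes Flux's WHCN input layout and a weight of size (1, 1, Cin ÷ groups, Cout). With groups=2, each output channel reads only one input channel. That appears to be why the layer prints as Conv((1, 1), 1 => 2) with 4 parameters (2 weights plus 2 biases). The sketch also shows that mixing a Float64 input with Float32 weights is not a dimension problem in itself.

```julia
# Naive 1×1 grouped convolution, for illustration only (not NNlib's implementation).
# Assumed layouts, following Flux's conventions:
#   x :: (W, H, Cin, N)               input in WHCN order
#   w :: (1, 1, Cin ÷ groups, Cout)   each output channel sees Cin ÷ groups inputs
function grouped_conv1x1(x::AbstractArray{T,4}, w::AbstractArray{S,4}, groups::Integer) where {T,S}
    W, H, Cin, N = size(x)
    cin_g = size(w, 3)                  # input channels seen by each group
    Cout = size(w, 4)
    Cin == cin_g * groups ||
        throw(DimensionMismatch("expected $(cin_g * groups) input channels, got $Cin"))
    Cout % groups == 0 || throw(ArgumentError("Cout must be divisible by groups"))
    cout_g = Cout ÷ groups              # output channels produced by each group

    # Mixed element types are fine: accumulate in the promoted type.
    y = zeros(promote_type(T, S), W, H, Cout, N)
    for n in 1:N, o in 1:Cout
        g = (o - 1) ÷ cout_g            # 0-based group of output channel o
        for c in 1:cin_g, j in 1:H, i in 1:W
            ci = g * cin_g + c          # global index of the input channel read
            y[i, j, o, n] += w[1, 1, c, o] * x[i, j, ci, n]
        end
    end
    return y
end

# Same shapes as the layer in this issue: 2 => 2 with groups = 2.
x = rand(10, 10, 2, 1)                  # Float64 input
w = rand(Float32, 1, 1, 1, 2)           # Float32 weights, as Flux creates them
y = grouped_conv1x1(x, w, 2)
size(y)  # (10, 10, 2, 1)
```

A genuine channel mismatch (e.g. a 3-channel input for this 2-group weight) is what should raise a DimensionMismatch; an element-type mismatch should not.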

See also the Julialang discussion.
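Until the fallback is fixed, the workaround from the report is to give the input and the weights the same element type. This is a minimal sketch, assuming a Flux version in which Flux.f64 is available (it converts a model's parameters to Float64). The variable names are illustrative.

```julia
using Flux

C = Conv((1, 1), 2 => 2, groups = 2)   # Flux creates the weights as Float32
x = rand(10, 10, 2, 1)                 # rand returns Float64 by default

# Option 1: convert the input to Float32 to match the weights.
y32 = C(Float32.(x))

# Option 2: convert the layer's parameters to Float64 to match the input.
C64 = Flux.f64(C)
y64 = C64(x)
```

Either way, NNlib receives arrays of a single element type and should not need the slow generic fallback. Option 1 is usually preferable, since Float32 is Flux's default precision.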

SimonCoste changed the title "Super misleading fallback in grouped Convs" to "Bugged fallback in grouped Convs" on Dec 27, 2021.
DhairyaLGandhi (Member)

Good eye. We should catch this error in NNlib at the same level as for the non-grouped versions.

gabrielpreviato (Contributor)

I opened a PR that should fix this: #468.

ToucheSir linked a pull request on Mar 3, 2023 that will close this issue.

ToucheSir (Member)

Closing as fixed by #468.
