
Basic example from Migrating from Flux to Lux is broken || normalization issue #94

Closed
gabrevaya opened this issue Jul 20, 2022 · 2 comments
Labels
documentation Improvements or additions to documentation

Comments

@gabrevaya
Contributor

The Lux version of the first example in Migrating from Flux to Lux in the docs is broken. The reason is that the input x is a Matrix{Float64}, but when setting up the model with ps, st = Lux.setup(rng, model) the parameters are Matrix{Float32} by default, and the normalization function requires the eltypes of all its arguments to be the same. A quick fix to make the example in the docs work is simply to initialize x with Float32s: x = randn(rng, Float32, 2, 4). However, I think it would be good to fix the normalization issue eventually. It would also be nice to keep that line of code the same between Flux and Lux; alternatively, we could use x = randn(rng, Float32, 2, 4) for both Lux and Flux.

For fixing the normalization issue, some options would be to add a method with a differently parametrized type for the first argument (which is x), or to promote/convert the types when they are not the same. What do you think would be the best way to handle this?

Code and error message
using Lux, Random, NNlib, Zygote

model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()
x = randn(rng, 2, 4)
ps, st = Lux.setup(rng, model)
model(x, ps, st)
ERROR: MethodError: no method matching normalization(::Matrix{Float64}, ::Vector{Float32}, ::Vector{Float32}, ::Vector{Float32}, ::Vector{Float32}, ::typeof(relu), ::Vector{Int64}, ::Val{true}, ::Float32, ::Float32)
Closest candidates are:
  normalization(::AbstractArray{T, N}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Any, ::Any, ::Val) where {T, N} at ~/.julia/packages/Lux/lEqCI/src/nnlib.jl:31
  normalization(::AbstractArray{T, N}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Any, ::Any, ::Val, ::T) where {T, N} at ~/.julia/packages/Lux/lEqCI/src/nnlib.jl:31
  normalization(::AbstractArray{T, N}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Any, ::Any, ::Val, ::T, ::T) where {T, N} at ~/.julia/packages/Lux/lEqCI/src/nnlib.jl:31
Stacktrace:
 [1] (::BatchNorm{true, true, typeof(relu), typeof(Lux.zeros32), typeof(Lux.ones32), Float32})(x::Matrix{Float64}, ps::NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, st::NamedTuple{(:running_mean, :running_var, :training), Tuple{Vector{Float32}, Vector{Float32}, Val{true}}})
   @ Lux ~/.julia/packages/Lux/lEqCI/src/layers/normalize.jl:120
 [2] macro expansion
   @ ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:0 [inlined]
 [3] applychain(layers::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(relu), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}, x::Matrix{Float64}, ps::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(:running_mean, :running_var, :training), Tuple{Vector{Float32}, Vector{Float32}, Val{true}}}, NamedTuple{(), Tuple{}}}})
   @ Lux ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:507
 [4] (::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(relu), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}})(x::Matrix{Float64}, ps::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(:running_mean, :running_var, :training), Tuple{Vector{Float32}, Vector{Float32}, Val{true}}}, NamedTuple{(), Tuple{}}}})
   @ Lux ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:504
 [5] top-level scope
   @ REPL[15]:1
 [6] top-level scope
   @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
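For reference, applying the quick fix described above (initializing x with Float32 so its eltype matches the default Float32 parameters), the same example runs without the MethodError:

```julia
using Lux, Random, NNlib, Zygote

model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()

# Create the input as Float32 so its eltype matches the Float32
# parameters and states produced by Lux.setup by default.
x = randn(rng, Float32, 2, 4)
ps, st = Lux.setup(rng, model)

# Returns the output and the (possibly updated) states.
y, st_new = model(x, ps, st)
```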
(examples) pkg> st
Status `~/.julia/packages/Lux/lEqCI/examples/Project.toml`
  [c29ec348] AbstractDifferentiation v0.4.3
  [c7e460c6] ArgParse v1.1.4
  [02898b10] Augmentor v0.6.6
  [052768ef] CUDA v3.12.0
⌅ [b0b7db55] ComponentArrays v0.11.17
  [2e981812] DataLoaders v0.1.3
  [41bf760c] DiffEqSensitivity v6.79.0
  [587475ba] Flux v0.13.4
⌅ [acf642fa] FluxMPI v0.5.3
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.30
⌅ [d9f16b24] Functors v0.2.8
  [6218d12a] ImageMagick v1.2.2
⌃ [916415d5] Images v0.24.1
  [b835a17e] JpegTurbo v0.1.1
  [b2108857] Lux v0.4.9
  [cc2ba9b6] MLDataUtils v0.5.4
  [eb30cadb] MLDatasets v0.7.4
  [f1d291b0] MLUtils v0.2.9
  [dbeba491] Metalhead v0.7.3
  [872c559c] NNlib v0.8.8
  [3bd65402] Optimisers v0.2.7
  [1dea7af3] OrdinaryDiffEq v6.18.2
  [d7d3b36b] ParameterSchedulers v0.3.3
  [91a5bcdd] Plots v1.31.3
  [37e2e3b7] ReverseDiff v1.14.1
⌅ [efcf1570] Setfield v0.8.2
  [fce5fe82] Turing v0.21.9
  [e88e6eb3] Zygote v0.6.41
  [de0858da] Printf
  [9a3f8284] Random
  [10745b16] Statistics
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ cannot be upgraded. To see why use `status --outdated`
julia> VERSION
v"1.8.0-rc3"
@avik-pal
Member

It should be Float32. I think relaxing the type constraints should just work, but someone needs to put in the time and check it (and add some type stability tests).

@avik-pal avik-pal added the documentation Improvements or additions to documentation label Jul 21, 2022
@avik-pal
Copy link
Member

Fixed in #97. Tracking the normalization issue in #98.
