The Lux version of the first example in Migrating from Flux to Lux in the docs is broken. The reason is that the input x is a Matrix{Float64}, but when the model is set up with ps, st = Lux.setup(rng, model), the parameters are Matrix{Float32} by default, and the normalization function requires the eltypes of all its arguments to be the same. A quick fix that makes the docs example work is to initialize x with Float32s: x = randn(rng, Float32, 2, 4). However, I think it would be good to fix the normalization issue eventually. It would also be nice to keep that line of code the same between Flux and Lux, or we could use x = randn(rng, Float32, 2, 4) for both Lux and Flux.
To fix the normalization issue, some options would be to add a method with a different parametrized type for the first argument (which is x), or to promote/convert some of the types if they are not the same. What do you think would be the best way to handle this?
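To make the promote/convert option concrete, here is a minimal sketch in plain Julia. The helper name `match_eltype` is hypothetical and is not part of Lux; it only illustrates converting the input to the element type of the parameters before they reach a function that requires matching eltypes.

```julia
# Hypothetical helper (not a Lux function): convert `x` to the
# floating-point element type `T` of the parameter array `w`.
function match_eltype(x::AbstractArray, w::AbstractArray{T}) where {T<:AbstractFloat}
    eltype(x) === T && return x   # already matching, return the input untouched
    return T.(x)                  # otherwise convert elementwise
end

x = randn(2, 4)                   # Matrix{Float64}, like the docs example
w = randn(Float32, 4, 2)          # stands in for a Float32 parameter array

x32 = match_eltype(x, w)          # Matrix{Float32}

# The alternative is to promote everything to the wider type instead:
T_common = promote_type(eltype(x), eltype(w))   # Float64
```

Converting the input down to the parameter type keeps the computation in Float32, while promoting would silently move it to Float64. Which of the two is preferable is a design decision for the maintainers.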
Code and error message
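using Lux, Random, NNlib, Zygote
model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()
x = randn(rng, 2, 4)  # Matrix{Float64}
ps, st = Lux.setup(rng, model)  # parameters are Float32 by default
model(x, ps, st)  # errors: eltypes of the input and the parameters differ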
The input should be Float32. I think relaxing the type constraints should just work, but someone needs to put in the time to check it (and add some type stability tests).
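For reference, a sketch of the example with the quick fix from the issue applied, assuming nothing else in the docs snippet changes. The only difference is that x is created as Float32 so that it matches the default parameter type.

```julia
using Lux, Random, NNlib

model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()

# Create the input as Float32 so its eltype matches the parameters.
x = randn(rng, Float32, 2, 4)

ps, st = Lux.setup(rng, model)

# A Lux model call returns the output together with the updated state.
y, st_new = model(x, ps, st)
```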