
ADAMW not stable #1920

Closed · deahhh opened this issue Mar 24, 2022 · 9 comments

deahhh commented Mar 24, 2022

```julia
using Flux

# `dl_train`, `dl_test`, and `random_select_sect` are defined elsewhere.
m = Chain(
    Flux.Embedding(501102, 32),
    Flux.Conv((32,), 64 => 235),
    Flux.flatten,
    Flux.softmax,
) |> gpu

# Evaluation callback: accuracy over the test loader.
function evalcb()
    ps = []
    ls = []
    for (d, l) in dl_test
        p = m(Flux.batch(map(random_select_sect, d)) |> gpu) |> cpu
        p = map(x -> x[1], reshape(argmax(p, dims=1), :))
        append!(ps, p)
        append!(ls, l)
    end
    acc = sum(ps .== ls) / length(ps)
    println("accuracy: $(acc)")
end

loss(x, y) = Flux.Losses.focal_loss(
    m(reduce(hcat, map(random_select_sect, x)) |> gpu),
    Flux.onehotbatch(y, 1:235) |> gpu)

Flux.@epochs 50 Flux.Optimise.train!(loss,
                                     params(m),
                                     dl_train,
                                     Flux.Optimise.ADAMW(0.0001, (0.9, 0.999), 0.01),
                                     cb = Flux.throttle(evalcb, 5))
```

The loss becomes NaN while training on the WiLI-2018 dataset, at around epoch 40, once test-set accuracy is above 0.8.
I am sure it's a problem, because the same model trains fine in PyTorch.

By the way, hcat(map(f, x)...) is very slow!

Flux.batch(map(f, d)) is very fast, but raises a "Mutating arrays is not supported" error when taking gradients.

ToucheSir (Member) commented

There's not enough here for a MWE (specifically the W part). In this case I would recommend providing the equivalent Python code as well since you're doing a head-to-head comparison.

A couple other things to note:

  1. Hyperparameter values for ADAMW are not the same between Flux and PyTorch; see the sketch after this list. Happy to elaborate if you need more clarification.
  2. That model is awfully thin; I'm surprised that a single Conv layer can learn any meaningful representation and remain stable after 40 epochs.
  3. Re: "By the way, hcat(map(f, x)...) is very slow!"

reduce(hcat, ...) should be pretty fast, so I suspect that the function you're mapping (random_select_sect) is a bottleneck.
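
Roughly, the mismatch in point 1 looks like this (a sketch, assuming the ADAMW definition used by Flux releases of this era; check your version's optimiser source):

```julia
using Flux

# Flux's ADAMW(η, β, decay) is essentially
#   Optimiser(ADAM(η, β), WeightDecay(decay)),
# so `decay * weight` is added to the update directly and is NOT rescaled
# by the learning rate. PyTorch's AdamW instead subtracts
# lr * weight_decay * weight each step.
# To mimic torch.optim.AdamW(lr=1e-4, weight_decay=0.01), the Flux `decay`
# should therefore be about lr * weight_decay:
opt = Flux.Optimise.ADAMW(1e-4, (0.9, 0.999), 1e-4 * 0.01)  # decay = 1e-6

# The call in the original post, ADAMW(0.0001, (0.9, 0.999), 0.01), decays
# weights roughly 10_000x harder than the PyTorch run it is compared against.
```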

deahhh (Author) commented Mar 25, 2022

Thanks for your reply!

  1. New information: Flux.Optimise.Nesterov(0.001, 0.9) produces NaN too. I don't know why yet.
  2. This thin model performs well: accuracy up to 0.953 on the test dataset. And the thinner the model, the more stable the training should be, shouldn't it?
  3. You are right, reduce(hcat, ...) is pretty fast: it takes 0.4 s to predict all the test data, whereas hcat(map(f, x)...) took 40 s when I experimented at the beginning.

Thanks again!
I'll go over the code once more.

deahhh (Author) commented Mar 25, 2022

> There's not enough here for a MWE (specifically the W part). […]

I'm sorry, there is a difference between the PyTorch and Flux models: in PyTorch I pass an argument named max_norm to the embedding. I guess that's the stabilizing factor.
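
For the record, a rough way to emulate max_norm on the Flux side (a sketch, not a Flux feature; it assumes the embedding matrix has one column per token, as in Flux.Embedding's weight field, and it renormalizes the whole table per step rather than only the looked-up rows as PyTorch does):

```julia
# Rescale any embedding column whose L2 norm exceeds `max_norm`.
# Written with broadcasts so it also works on GPU arrays.
function clip_embedding_norms!(W::AbstractMatrix, max_norm = 1f0)
    norms = sqrt.(sum(abs2, W; dims = 1))    # 1×vocab column norms
    scale = min.(one(eltype(W)), max_norm ./ max.(norms, eps(eltype(W))))
    W .*= scale                              # columns over the limit shrink
    return W
end

# e.g. after each optimiser step or epoch:
# clip_embedding_norms!(m[1].weight, 1f0)
```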

deahhh closed this as completed Mar 25, 2022
deahhh reopened this Mar 25, 2022
deahhh (Author) commented Mar 25, 2022

Yes, it is the max_norm argument that acts as the stabilizing factor.
I retrained the model with a larger decay argument in the optimizer, and accuracy is now up to 0.942.

What remains is:

  1. Why is hcat(map(f, x)...) so much slower than reduce(hcat, map(f, x))?
  2. How can hcat(map(f, x)...) be so slow?
  3. How can Flux.batch cause an error in the gradient, but not in prediction?

ToucheSir (Member) commented

Great that you managed to get things working and see an improvement.

For 1. and 2., this has come up a few times in various help forums, but JuliaLang/julia#21672 (comment) is one of the more canonical answers. TL;DR: being able to preallocate is fast; splatting tons of elements is not. A micro-benchmark sketch follows below.
For 3., the currently released version of Flux uses this implementation, which is not AD-friendly (it does array mutation). The version on master is re-exported from https://github.com/JuliaML/MLUtils.jl, so you could try that for your use case as well.
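
A quick way to see the splatting cost from 1. and 2. (a sketch with made-up sizes, using BenchmarkTools):

```julia
using BenchmarkTools

xs = [rand(Float32, 128) for _ in 1:10_000]

@btime reduce(hcat, $xs);   # hits a specialized method: one preallocated output
@btime hcat($xs...);        # splats 10_000 arguments into a single call
```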

mcabbott (Member) commented

If I'm reading this correctly, you are doing reduce(hcat, map(random_select_sect, x)) |> gpu inside the gradient call, so Zygote will compute a gradient for x here. This will be slower than doing it outside, for instance by replacing dl_train with some generator which does these steps.
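
Something along these lines, say (a sketch reusing the names from the original post; the decay value follows the rescaling noted earlier in the thread):

```julia
# Batching and GPU transfer now happen in the data iterator, outside the
# differentiated code, so Zygote never sees them.
batched_train = ((reduce(hcat, map(random_select_sect, x)) |> gpu,
                  Flux.onehotbatch(y, 1:235) |> gpu)
                 for (x, y) in dl_train)

# The loss then receives ready-made GPU arrays.
loss(x, y) = Flux.Losses.focal_loss(m(x), y)

Flux.@epochs 50 Flux.Optimise.train!(loss, params(m), batched_train,
                                     Flux.Optimise.ADAMW(1e-4, (0.9, 0.999), 1e-6),
                                     cb = Flux.throttle(evalcb, 5))
```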

deahhh (Author) commented Mar 25, 2022

> If I'm reading this correctly, you are doing reduce(hcat, map(random_select_sect, x)) |> gpu inside the gradient call […]

Yes, my data is computed on the fly because of its volume, and the random selection would make it difficult to determine when to terminate.
Thanks for clearing up the confusion I had while coding: any calculation placed inside the gradient call will have a gradient computed for it.

deahhh (Author) commented Mar 25, 2022

> Great that you managed to get things working and see an improvement. […]

Thanks!

CarloLucibello (Member) commented

Closing, since "ADAMW not stable" seems not to be a real issue. You can file separate issues for the performance concerns if you want.
