Add struct metadata to cache while de/restructuring #1353
base: master
Conversation
@avik-pal if this works well enough, can you suggest a good test? As I understand it, it's very similar to what we have going on already. I'd like to understand the case this is solving before merging. |
SciML/DiffEqFlux.jl#391 (comment) might be a good test. Essentially have 2 copies of an RNN. Run the first one for t timesteps, record the outputs. For the 2nd one, re/destructure every time before calling it with the same inputs. The 2 outputs must match. |
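A minimal sketch of what that test could look like, assuming the Flux 0.12-era API; the sizes, seed and helper name here are illustrative, not taken from the PR:

using Flux, Test, Random

function run_with_restructure(m, xs)
    ys = Vector{Any}(undef, length(xs))
    for (i, x) in enumerate(xs)
        θ, re = Flux.destructure(m)   # flatten the model, including its current state
        m = re(θ)                     # reconstruct it before every call
        ys[i] = m(x)
    end
    return ys
end

Random.seed!(0)
rnn1 = RNN(2, 3)
rnn2 = deepcopy(rnn1)

xs  = [rand(Float32, 2) for _ in 1:5]   # t = 5 timesteps
ys1 = [rnn1(x) for x in xs]             # first copy, run normally
ys2 = run_with_restructure(rnn2, xs)    # second copy, de/restructured between calls

@test all(ys1 .≈ ys2)                   # the two output sequences must match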
Hi! I tried to test this with something similar to SciML/DiffEqFlux.jl#391 (comment) and found that there are some issues.
using Flux
using Random
Random.seed!(1)
ANN = RNN(1, 1);
par, func = Flux.destructure(ANN);
ANN.state == func(par).state # true
ANN([1f0])
ANN.state == func(par).state # Raises the following error
ERROR: LoadError: BoundsError: attempt to access 4-element Vector{Float32} at index [5:5]
which points to the line x = reshape(xs[i.+(1:length(x))], size(x)) from the modified _restructure. Adding some @show statements to it:
import Flux._restructure
function _restructure(m, xs; cache = IdDict())
i = 0
m̄ = fmap(m) do x
@show x
x isa AbstractArray || @show cache[x]
x isa AbstractArray || return cache[x]
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
return m̄
end
Random.seed!(1)
ANN = RNN(1, 1);
par, func = Flux.destructure(ANN);
ANN.state == func(par).state # true
ANN([1f0])
ANN.state # 1×1 Matrix{Float32}: 0.16474228
ANN.state == func(par).state # Raises the following error
Output:
x = tanh
cache[x] = tanh
x = Float32[0.1662574]
x = Float32[-0.23294435]
x = Float32[0.0]
x = Float32[0.0]
x = tanh
cache[x] = tanh
x = Float32[0.1662574]
x = Float32[-0.23294435]
x = Float32[0.0]
x = Float32[0.0]
x = Float32[0.16474228]
ERROR: LoadError: BoundsError: attempt to access 4-element Vector{Float32} at index [5:5]
I don't quite understand how the cache is supposed to work here. |
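For what it's worth, a small check of where that [5:5] could come from, assuming the Flux 0.12 RNN layout (Wi, Wh, b plus a state that initially aliases the cell's state0); this is one reading, not something established in the thread:

using Flux
rnn = RNN(1, 1)
θ, re = Flux.destructure(rnn)
length(θ)     # 4: Wi, Wh, b and the state, which initially aliases the cell's state0
rnn([1f0])    # the forward pass swaps rnn.state for a fresh array
re(θ)         # fmap now visits 5 distinct arrays while θ has only 4 entries → index [5:5]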
In any case, I'm also trying to understand the motivation of this modification. For example, having something like
m = Chain(RNN(10, 5), Dense(5, 3))
θ, re = Flux.destructure(m);
x = [rand(Float32, 10, 100) for i in 1:10]   # one 10×100 batch per timestep, matching RNN(10, 5)
function loss(p, x, y)
m = re(p)
ŷ = [m(xᵢ) for xᵢ in x]
....
end
would be equivalent to doing a Flux.reset! inside the loss function for a regular model (without the destructure/restructure), right? |
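For comparison, a sketch of the reset!-based loss that question alludes to (the body is only indicated, as above):

function loss_reset(m, x, y)
    Flux.reset!(m)              # clear the hidden state instead of rebuilding the model
    ŷ = [m(xᵢ) for xᵢ in x]
    # ....
end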
However, when using the latest version of Flux from master, GalacticOptim still throws an error when trying to use a recurrent model (even in cases where resetting the state is fine). This happens only with recurrent models, not with other models like Dense layers, so there might be something wrong with the handling of the state beyond the resetting.
using GalacticOptim
using Flux
using Random
Random.seed!(3)
m = Chain(RNN(10, 5), Dense(5, 3))
θ, re = Flux.destructure(m);
X = rand(Float32, 10, 5, 1000)
Y = rand(Float32, 3, 1000)
data = Flux.DataLoader((X, Y), batchsize=10)
function loss(p, x, y)
x_unstacked = Flux.unstack(x, 2)
m = re(p)
ŷ = [m(x_i) for x_i in x_unstacked][end]
sum((y .- ŷ).^2)
end
cb = function (p,l)
display(l)
return false
end
optfun = OptimizationFunction((θ, p, x, y) -> loss(θ, x, y), GalacticOptim.AutoZygote())
optprob = OptimizationProblem(optfun, θ)
opt = ADAM()  # optimiser passed to solve below (ADAM per the stack trace)
res = GalacticOptim.solve(optprob, opt, data, cb = cb)
Error:
┌ Warning: Expected 103 params, got 18
└ @ Flux ~/.julia/packages/Flux/Zz9RI/src/utils.jl:623
ERROR: LoadError: DimensionMismatch("array could not be broadcast to match destination")
Stacktrace:
[1] check_broadcast_shape
@ ./broadcast.jl:520 [inlined]
[2] check_broadcast_axes
@ ./broadcast.jl:523 [inlined]
[3] instantiate
@ ./broadcast.jl:269 [inlined]
[4] materialize!
@ ./broadcast.jl:894 [inlined]
[5] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(identity), Tuple{Vector{Float32}}})
@ Base.Broadcast ./broadcast.jl:891
[6] (::GalacticOptim.var"#135#145"{GalacticOptim.var"#134#144"{OptimizationFunction{true, GalacticOptim.AutoZygote, var"#7#8", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::Vector{Float32}, ::Vector{Float32}, ::Array{Float32, 3}, ::Vararg{Any, N} where N)
@ GalacticOptim ~/.julia/packages/GalacticOptim/bEh06/src/function/zygote.jl:8
[7] macro expansion
@ ~/.julia/packages/GalacticOptim/bEh06/src/solve/flux.jl:43 [inlined]
[8] macro expansion
@ ~/.julia/packages/GalacticOptim/bEh06/src/solve/solve.jl:35 [inlined]
[9] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, GalacticOptim.AutoZygote, var"#7#8", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::ADAM, data::Flux.Data.DataLoader{Tuple{Array{Float32, 3}, Matrix{Float32}}, Random._GLOBAL_RNG}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ GalacticOptim ~/.julia/packages/GalacticOptim/bEh06/src/solve/flux.jl:41
[10] #solve#474
@ ~/.julia/packages/SciMLBase/UIp7W/src/solve.jl:3 [inlined]
[11] top-level scope
@ ~/Documents/julia_various_repos/testing_destructure_for_RNNs/training_Flux_layers_with_GalacticOptim.jl:52
in expression starting at /Users/ger/Documents/julia_various_repos/testing_destructure_for_RNNs/training_Flux_layers_with_GalacticOptim.jl:52
(testing_destructure_for_RNNs) pkg> st
Status `~/Documents/julia_various_repos/testing_destructure_for_RNNs/Project.toml`
[587475ba] Flux v0.12.6
[a75be94c] GalacticOptim v2.0.3
[9a3f8284] Random
The error points to the adjoint of _restructure. Should I keep all of this here or post it as an issue in Flux or GalacticOptim? |
@gabrevaya the error on master/stable is expected, hence why this PR exists. You'll notice the warning says "Expected 103 params, got 18", where 18 = 3x5 (the Dense layer's weights) + 3 (its biases), i.e. only the Dense(5, 3) parameters are reaching the gradient.
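For reference, a rough tally of that 103 / 18 split, assuming the Flux 0.12 parameter layout:

# Dense(5, 3):  W 3×5 = 15, b = 3                             → 18
# RNN(10, 5):   Wi 5×10 = 50, Wh 5×5 = 25, b = 5, state = 5   → 85
# 18 + 85 = 103, so the gradient that comes back only covers the Dense layer.
length(Flux.destructure(Chain(RNN(10, 5), Dense(5, 3)))[1])   # 103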
Also, I'm not really sure what you mean by "which points to the line x = reshape(xs[i.+(1:length(x))], size(x)) from the modified _restructure", because this PR doesn't modify that line. |
Ohh OK, I'm so sorry. I didn't realize that this issue also means that no parameters of the reconstructed RNN are tracked. About the line x = reshape(xs[i.+(1:length(x))], size(x)): I didn't say that it was modified, I meant that it is contained in the method of _restructure from this PR:
m̄ = fmap(m) do x
x isa AbstractArray || return cache[x]
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
where the error is raised when fmap reaches the state of the RNN, which is not covered by the flat parameter vector. Anyway, as I said, I don't fully understand how all this works. I was just intending to run some tests and help to identify possible bugs. |
I just wanted to clarify if you were testing Flux from this PR's branch, because the changes go beyond just that one function. There may well be other bugs that prevent RNN reconstruction from working as expected, but to troubleshoot them we need to know exactly what code you're running :) |
Hi! I wonder if anyone is still working on this. It would be great to be able to train Flux's recurrent models with GalacticOptim :) |
Yup, this is currently mimicking the behaviour on master, and needs a test case. We can bring this back up again.
Needs Tests
Ref SciML/DiffEqFlux.jl#432 (comment)
Currently, using destructure overwrites state for recurrent layers, which may want to hold on to it. This introduces a cache which behaves exactly like the current system but adds the references to actually restore the state.
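A small sketch of the behaviour the description is after, as I read it (sizes are illustrative):

using Flux
rnn = RNN(1, 1)
rnn([1f0])                    # the layer now carries a non-initial hidden state
θ, re = Flux.destructure(rnn)
re(θ).state == rnn.state      # the cached references let the rebuilt layer keep that state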
PR Checklist
Final review from @dhairyagandhi96 (for API changes).