
Add struct metadata to cache while de/restructuring #1353

Open
wants to merge 1 commit into master

Conversation

DhairyaLGandhi
Member

Needs Tests

Ref SciML/DiffEqFlux.jl#432 (comment)

Currently, using destructure overwrites the state of recurrent layers, which may need to hold on to it. This PR introduces a cache that behaves exactly like the current system but keeps the references needed to actually restore the state.
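
For context, a rough sketch (not the actual diff) of what the destructure half could look like with such a cache; destructure_cached is a made-up name, and the restructure half is quoted further down in this thread:

using Flux

function destructure_cached(m)                 # hypothetical helper, not Flux API
    xs = Float32[]                             # assumes Float32 params, as in Flux's defaults
    cache = IdDict()
    Flux.fmap(m) do x
        if x isa AbstractArray
            append!(xs, vec(x))                # flatten the arrays into one vector
        else
            cache[x] = x                       # remember non-array leaves by reference
        end
        x
    end
    return xs, cache                           # the cache would then be handed to _restructure
end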

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • Final review from @dhairyagandhi96 (for API changes).

@DhairyaLGandhi
Member Author

cc @ChrisRackauckas @avik-pal

@avik-pal if this works well enough, can you suggest a good test? As I understand it, it's very similar to what we have going on already. I'd like to understand the case this is solving before merging.

@avik-pal
Member

SciML/DiffEqFlux.jl#391 (comment) might be a good test. Essentially, have two copies of an RNN. Run the first one for t timesteps and record the outputs. For the second one, destructure/restructure it every time before calling it with the same inputs. The two sets of outputs must match.
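
A rough sketch of that test (assuming this PR makes destructure/restructure round-trip the hidden state; outputs_with_restructure is just an illustrative name):

using Flux, Test

function outputs_with_restructure(rnn, xs)
    map(xs) do x
        p, re = Flux.destructure(rnn)   # flatten, then rebuild before every call
        rnn = re(p)
        rnn(x)
    end
end

rnn1 = RNN(2, 3)
rnn2 = deepcopy(rnn1)
xs = [rand(Float32, 2) for _ in 1:5]

ys1 = [rnn1(x) for x in xs]               # plain stateful evaluation
ys2 = outputs_with_restructure(rnn2, xs)  # de/restructure before every step
@test all(ys1 .≈ ys2)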

@gabrevaya

Hi! I tried to test this with something similar to SciML/DiffEqFlux.jl#391 (comment) and found that there are some issues.

using Flux
using Random

Random.seed!(1)
ANN  = RNN(1, 1);
par, func = Flux.destructure(ANN);
ANN.state == func(par).state # true
ANN([1f0])
ANN.state == func(par).state # Raises the following error
ERROR: LoadError: BoundsError: attempt to access 4-element Vector{Float32} at index [5:5]

which points to the line

 x = reshape(xs[i.+(1:length(x))], size(x))

from the modified _restructure method.

Adding some @shows to _restructure we get the following:

import Flux._restructure
function _restructure(m, xs; cache = IdDict())
    i = 0
    m̄ = fmap(m) do x
      @show x
      x isa AbstractArray || @show cache[x]
      x isa AbstractArray || return cache[x]
      x = reshape(xs[i.+(1:length(x))], size(x))
      i += length(x)
      return x
    end
    length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
    return m̄
end

Random.seed!(1)
ANN  = RNN(1, 1);
par, func = Flux.destructure(ANN);
ANN.state == func(par).state # true
ANN([1f0])
ANN.state # 1×1 Matrix{Float32}: 0.16474228
ANN.state == func(par).state # Raises the following error

Output:

x = tanh
cache[x] = tanh
x = Float32[0.1662574]
x = Float32[-0.23294435]
x = Float32[0.0]
x = Float32[0.0]
x = tanh
cache[x] = tanh
x = Float32[0.1662574]
x = Float32[-0.23294435]
x = Float32[0.0]
x = Float32[0.0]
x = Float32[0.16474228]
ERROR: LoadError: BoundsError: attempt to access 4-element Vector{Float32} at index [5:5]

I don't quite understand how the IdDict cache works, but I notice that the state is being added to the model as an AbstractArray, and _restructure treats it just like the other parameters.
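
As a side note on the IdDict part: it keys entries by object identity (===) rather than by value, which is how the cache can hand back the exact array reference it recorded. For example:

d = IdDict()
a = Float32[0.0]
d[a] = a
haskey(d, a)        # true: same object
haskey(d, copy(a))  # false: equal values, but a different object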

@gabrevaya

In any case, I'm also trying to understand the motivation for this change.
In general, don't you reset the state of a recurrent model after updating its parameters, or are there cases where you would want to change the parameters and keep the state?

For example, having something like

m = Chain(RNN(10, 5), Dense(5, 3))
θ, re = Flux.destructure(m);
x = [rand(Float32, 10, 100) for i in 1:10]   # 10 timesteps of 10×100 inputs

function loss(p, x, y)
    m = re(p)
    ŷ = [m(xᵢ) for xᵢ in x]
    # ...
end

would be equivalent to doing a Flux.reset! inside the loss function for a regular model (without the destructure/restructure), right?
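
For comparison, the regular-model version of that loss, as I understand the question, would look something like this (loss_regular is just an illustrative name):

using Flux

function loss_regular(m, x, y)
    Flux.reset!(m)              # start from the initial state on every call
    ŷ = [m(xᵢ) for xᵢ in x]
    # ...
end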

@gabrevaya

gabrevaya commented Sep 2, 2021

However, when using the latest version of Flux from master, GalacticOptim still throws an error when trying to use a recurrent model (even in cases where resetting the state is fine). This happens only with recurrent models, not with other models such as Dense layers, so there might be something wrong with how the states are managed beyond the resetting.

using GalacticOptim
using Flux
using Random

Random.seed!(3)
m = Chain(RNN(10, 5), Dense(5, 3))
θ, re = Flux.destructure(m);

X = rand(Float32, 10, 5, 1000)
Y = rand(Float32, 3, 1000)
data = Flux.DataLoader((X, Y), batchsize=10)

function loss(p, x, y)
    x_unstacked = Flux.unstack(x, 2)
    m = re(p)
    ŷ = [m(x_i) for x_i in x_unstacked][end]
    sum((y .- ŷ).^2)
end

cb = function (p,l)
    display(l)
    return false
end

opt = ADAM()   # optimiser; the stacktrace below shows an ADAM was passed here
optfun = OptimizationFunction((θ, p, x, y) -> loss(θ, x, y), GalacticOptim.AutoZygote())
optprob = OptimizationProblem(optfun, θ)
res = GalacticOptim.solve(optprob, opt, data, cb = cb)

Error:

┌ Warning: Expected 103 params, got 18
└ @ Flux ~/.julia/packages/Flux/Zz9RI/src/utils.jl:623
ERROR: LoadError: DimensionMismatch("array could not be broadcast to match destination")
Stacktrace:
  [1] check_broadcast_shape
    @ ./broadcast.jl:520 [inlined]
  [2] check_broadcast_axes
    @ ./broadcast.jl:523 [inlined]
  [3] instantiate
    @ ./broadcast.jl:269 [inlined]
  [4] materialize!
    @ ./broadcast.jl:894 [inlined]
  [5] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(identity), Tuple{Vector{Float32}}})
    @ Base.Broadcast ./broadcast.jl:891
  [6] (::GalacticOptim.var"#135#145"{GalacticOptim.var"#134#144"{OptimizationFunction{true, GalacticOptim.AutoZygote, var"#7#8", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::Vector{Float32}, ::Vector{Float32}, ::Array{Float32, 3}, ::Vararg{Any, N} where N)
    @ GalacticOptim ~/.julia/packages/GalacticOptim/bEh06/src/function/zygote.jl:8
  [7] macro expansion
    @ ~/.julia/packages/GalacticOptim/bEh06/src/solve/flux.jl:43 [inlined]
  [8] macro expansion
    @ ~/.julia/packages/GalacticOptim/bEh06/src/solve/solve.jl:35 [inlined]
  [9] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, GalacticOptim.AutoZygote, var"#7#8", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::ADAM, data::Flux.Data.DataLoader{Tuple{Array{Float32, 3}, Matrix{Float32}}, Random._GLOBAL_RNG}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/bEh06/src/solve/flux.jl:41
 [10] #solve#474
    @ ~/.julia/packages/SciMLBase/UIp7W/src/solve.jl:3 [inlined]
 [11] top-level scope
    @ ~/Documents/julia_various_repos/testing_destructure_for_RNNs/training_Flux_layers_with_GalacticOptim.jl:52
in expression starting at /Users/ger/Documents/julia_various_repos/testing_destructure_for_RNNs/training_Flux_layers_with_GalacticOptim.jl:52
(testing_destructure_for_RNNs) pkg> st
      Status `~/Documents/julia_various_repos/testing_destructure_for_RNNs/Project.toml`
  [587475ba] Flux v0.12.6
  [a75be94c] GalacticOptim v2.0.3
  [9a3f8284] Random

The error points to the adjoint of _restructure, so I don't think it is related to GalacticOptim, but just in case, @ChrisRackauckas, am I setting up the GalacticOptim problem correctly?

Should I keep all of this here or post it as an issue in Flux or GalacticOptim?

@ToucheSir
Member

@gabrevaya the error on master/stable is expected; that's why this PR exists. You'll notice the warning says "Expected 103 params, got 18", where 18 = 3×5 (Dense weight) + 3 (Dense bias). The purpose of this PR is to fix the tracking of Recur in destructure so that the other 85 params from the RNN layer are picked up.
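
For concreteness, my rough accounting of where those numbers come from (my own breakdown, not quoted from the warning):

# Chain(RNN(10, 5), Dense(5, 3))
# Dense(5, 3): 3×5 weight + 3 bias                        = 18   (what master currently restores)
# RNN(10, 5):  5×10 Wi + 5×5 Wh + 5 bias + 5 hidden state = 85   (what this PR should pick up)
# Total expected by destructure                           = 103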

Also, I'm not really sure what you mean by:

Hi! I tried to test this with something similar to SciML/DiffEqFlux.jl#391 (comment) and found that there are some issues.

using Flux
using Random

Random.seed!(1)
ANN  = RNN(1, 1);
par, func = Flux.destructure(ANN);
ANN.state == func(par).state # true
ANN([1f0])
ANN.state == func(par).state # Raises the following error
ERROR: LoadError: BoundsError: attempt to access 4-element Vector{Float32} at index [5:5]

which points to the line

 x = reshape(xs[i.+(1:length(x))], size(x))

from the modified _restructure method.

Because this PR doesn't modify that line in _restructure at all.

@gabrevaya

Ohh OK, I'm so sorry. I didn't realize that this issue also means that none of the reconstructed RNN's parameters are tracked.

About the line

x = reshape(xs[i.+(1:length(x))], size(x))

I didn't say that it was modified; I meant that it is contained in the _restructure method that this PR modifies. I was just trying to help identify the issue, which is currently above my Julia level. I just noticed that in the PR version of _restructure, in

m̄ = fmap(m) do x
    x isa AbstractArray || return cache[x]
    x = reshape(xs[i.+(1:length(x))], size(x))
    i += length(x)
    return x
end

the state of the Recur appears during the fmap and breaks the indexing, since xs includes all the parameters except the state.
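
One possible shape for a fix, just a guess on my part rather than the PR's actual diff, would be to check the cache before slicing xs, so that cached array references such as the state are restored verbatim:

using Flux

function _restructure_guarded(m, xs; cache = IdDict())   # hypothetical name
    i = 0
    m̄ = Flux.fmap(m) do x
        haskey(cache, x) && return cache[x]              # cached reference (e.g. the Recur state)
        x isa AbstractArray || return x
        y = reshape(xs[i .+ (1:length(x))], size(x))
        i += length(y)
        return y
    end
    length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
    return m̄
end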

Anyway, as I said, I don't fully understand how all this works. I was just intending to run some tests and help to identify possible bugs.

@ToucheSir
Member

I just wanted to clarify whether you were testing Flux from this PR's branch, because the changes go beyond just that one function. There may well be other bugs that prevent RNN reconstruction from working as expected, but to troubleshoot them we need to know exactly what code you're running :)

@gabrevaya

Hi! I wonder if anyone is still working on this. It would be great to be able to train Flux's recurrent models with GalacticOptim :)

@DhairyaLGandhi
Member Author

Yup, this is currently mimicking the behaviour on master, and needs a test case. We can bring this back up again.
