
RFC: Restrict train! to AbstractOptimiser #1902

Merged · 10 commits into FluxML:master from the train branch · Mar 20, 2022

Conversation

mcabbott (Member) commented Mar 8, 2022:

All of Flux's optimisers (except one, apparently by mistake) share a supertype, so I think we can restrict train! and friends to demand it.

The reason to do so is that it would then be possible to later add a train! method accepting the rules from Optimisers.jl (which have no supertype), and have the two coexist within one version of Flux.

I guess this counts as breaking, but are there many other optimisers without the supertype in the wild? Will it break anything else?

@@ -101,7 +93,7 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
- function train!(loss, ps, data, opt; cb = () -> ())
+ function train!(loss, ps, data, opt::AbstractOptimiser; cb = () -> ())
mcabbott (Member Author):

Further restricting to ps::Params gives errors:

Training Loop: Error During Test at /Users/me/.julia/dev/Flux/test/optimise.jl:48
  Got exception outside of a @test
  MethodError: no method matching train!(::var"#57#63", ::Tuple{}, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::Descent)
  Closest candidates are:
    train!(::Any, ::Params, ::Any, ::Flux.Optimise.AbstractOptimiser; cb) at ~/.julia/dev/Flux/src/optimise/train.jl:105
  Stacktrace:
...
DataLoader: Error During Test at /Users/me/.julia/dev/Flux/test/data.jl:3
  Got exception outside of a @test
  MethodError: no method matching train!(::var"#loss#93", ::Vector{Vector{Float64}}, ::IterTools.NCycle{DataLoader{Matrix{Float64}, Random._GLOBAL_RNG}}, ::Descent)
  Closest candidates are:
    train!(::Any, ::Params, ::Any, ::Flux.Optimise.AbstractOptimiser; cb) at ~/.julia/dev/Flux/src/optimise/train.jl:105
  Stacktrace:

Member:

Looking at the blame it's hard to say how intentional the use of () in https://github.com/FluxML/Flux.jl/blob/master/test/optimise.jl#L51-L74 was, but I'd be fine with updating the tests to match.

mcabbott (Member Author):

Ok. Thinking aloud a bit, the idea is that all of these can then co-exist:

  • train!(loss, ps::Params, data, opt::AbstractOptimiser) as before
  • train!(loss, ps::Params, data, opt) would use implicit parameters + Optimisers.jl optimisers. I haven't written it but I think it wouldn't be so hard to store their states in an IdDict. That would let us delete Flux's optimisers completely, and many use cases would not need to change anything. Still calls gradient(() -> loss(x, y), ps).
  • train!(loss, model, data, opt) would call gradient(m -> loss(m, x, y), model) and be fully structural. Only works with new optimisers of course.

Maybe that's too many steps, and it would be clearer to just do the first and last. But the more restrictive we are in the initial 0.13.0, the more room we have to decide.
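
Very roughly, and only as a sketch (none of this is written, and the IdDict bookkeeping is just one way it might work), the three could dispatch like this:

function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
  # 1. the method restricted in this PR: implicit Params + Flux's own optimisers, unchanged
end

function train!(loss, ps::Params, data, rule; cb = () -> ())
  # 2. implicit Params + an Optimisers.jl rule, with per-array state kept in an IdDict
  state = IdDict(p => Optimisers.init(rule, p) for p in ps)
  for d in data
    gs = gradient(() -> loss(batchmemaybe(d)...), ps)
    for p in ps
      state[p], dp = Optimisers.apply!(rule, state[p], p, gs[p])
      p .-= dp
    end
    cb()
  end
end

function train!(loss, model, data, rule; cb = () -> ())
  # 3. fully structural: explicit model + Optimisers.jl rule
  state = Optimisers.setup(rule, model)
  for d in data
    grad = gradient(m -> loss(m, batchmemaybe(d)...), model)
    state, model = Optimisers.update!(state, model, grad...)
    cb()
  end
  model
end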

mcabbott (Member Author), Mar 9, 2022:

The Optimisers.jl method could look like this:

function train!(loss, model, data, opt; cb = [])
  opt isa AbstractOptimiser && error("old-style AbstractOptimiser can only be used with Params")
  model isa Params && error("implicit Params can only be used with old-style AbstractOptimiser")
  cb = runall(cb)
  state = Optimisers.setup(opt, model)
  @withprogress for (i, d) in enumerate(data)
    try
      grad = gradient(model) do m
        loss(m, batchmemaybe(d)...)
      end
      state, model = update!(state, model, grad...)
      cb()
    catch ex
      # ...
    end
    @logprogress i / length_or_zero(data)
  end
  model
end

The downside of this is that the state is thrown away, so @epochs 100 train!(...) will reset it each time, whereas the present method mutates opt and stores the state there.

Member:

Hmm, then we'd lose momentum info etc. across epochs. What about making the interface train!(loss, model, data, opt, state = Optimisers.setup(opt, model))?
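
In use, that could look something like this (hypothetical sketch, assuming train! also hands back the updated state and model):

state = Optimisers.setup(opt, model)
for epoch in 1:100
  state, model = train!(loss, model, data, opt, state)  # momentum etc. carried across epochs
end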

Member:

What would be an example of an immutable rule? I'm trying to work out why the wrapper would be necessary.

mcabbott (Member Author):

The rules from Optimisers.jl are all immutable.

Member:

The rules are, but the state they're manipulating shouldn't be? Unless you were considering not returning the optimizer state from train! and keeping it in the wrapper instead.

mcabbott (Member Author), Mar 10, 2022:

Yes, exactly. The fluxoptwrap idea is to let you use train! while mutating both the model and the optimiser state, with the Optimisers.jl machinery hidden away. Changing existing code to use this would be a much smaller change than using Optimisers.jl directly.

I'm not certain that's a good idea, and it isn't in this PR. But it's one path we could take.
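
For the record, the sort of thing I have in mind is roughly this (entirely hypothetical, names made up):

import Optimisers

mutable struct FluxOptWrap{R}
  rule::R
  state::Any   # Optimisers.jl state tree, set up lazily on first use
end
FluxOptWrap(rule) = FluxOptWrap(rule, nothing)

function update!(w::FluxOptWrap, model, grad)
  w.state === nothing && (w.state = Optimisers.setup(w.rule, model))
  w.state, model = Optimisers.update!(w.state, model, grad)
  return model
end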

but the state they're manipulating shouldn't be?

By this I think you mean that the momentum arrays etc. in Optimisers.jl's state tree are in fact mutable. But Optimisers.jl certainly expects that other immutable things can be passed to the next iteration through its state, e.g. ADAM:
https://github.com/FluxML/Optimisers.jl/blob/master/src/rules.jl#L132
So even for a completely mutable model, you really have to pass the state tree around somehow.

train! could return it, of course, but then it starts to be a very different function, see variants earlier in this thread.

Member:

Let's not worry about the Optimisers.jl integration for now then.

Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
x .-= x̄
mcabbott (Member Author), Mar 8, 2022:

As noted by #1860, this function seems never to be used.

Maybe someone else uses it, and it deserves a depwarn, or something? But it's trivial, most people would write their own.

Member:

It is documented, so a depwarn unless we're considering this a breaking change for 0.13.

mcabbott (Member Author):

I didn't see that, thanks.

This whole PR is certainly breaking. Do you think it's worth bothering with a deprecation for this?

Member:

I think it would be a reasonably small change and consistent with the other 0.13 deprecations we've added.
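
Roughly something like this in src/deprecations.jl (a sketch only, not necessarily the exact message):

function update!(x::AbstractArray, x̄)
  Base.depwarn("`Flux.Optimise.update!(x, x̄)` on plain arrays is deprecated; write `x .-= x̄` directly.", :update!)
  x .-= x̄
end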

Member:

If you wouldn't mind adding this, everything else LGTM.

mcabbott (Member Author):

Sure, done.

@@ -521,7 +521,7 @@ opt = AdaBelief()
opt = AdaBelief(0.001, (0.9, 0.8))
```
"""
- mutable struct AdaBelief
+ mutable struct AdaBelief <: AbstractOptimiser
mcabbott (Member Author):

I presume this was an oversight; all the others have the supertype.

mcabbott added this to the v0.13 milestone, Mar 8, 2022
mcabbott force-pushed the train branch 2 times, most recently from 8e2dcab to e38c155, March 8, 2022 21:29
codecov-commenter commented Mar 9, 2022:

Codecov Report

Merging #1902 (0249386) into master (b6dbefb) will decrease coverage by 0.06%.
The diff coverage is 57.14%.

@@            Coverage Diff             @@
##           master    #1902      +/-   ##
==========================================
- Coverage   86.64%   86.57%   -0.07%     
==========================================
  Files          18       18              
  Lines        1445     1445              
==========================================
- Hits         1252     1251       -1     
- Misses        193      194       +1     
Impacted Files Coverage Δ
src/deprecations.jl 34.61% <0.00%> (-4.52%) ⬇️
src/optimise/optimisers.jl 93.64% <100.00%> (ø)
src/optimise/train.jl 93.75% <100.00%> (+5.17%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update b6dbefb...0249386.

- Optimiser(o...) = Optimiser(Any[o...])
+ Optimiser(opts::AbstractOptimiser...) = Optimiser(Any[opts...])
mcabbott (Member Author):

This chain thing calls apply!(opt, x, Δ) on each element, which only accepts the various AbstractOptimiser types. So it should be safe to restrict, and doing so may avoid some mistakes with Optimisers.OptimiserChain.
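
To illustrate the kind of mistake this catches (hypothetical example; only the construction behaviour changes):

Optimiser(ClipValue(1e-3), ADAM(1e-3))          # fine, both are <: AbstractOptimiser
Optimiser(ADAM(1e-3), Optimisers.Descent(0.1))  # now a MethodError at construction, rather than failing later inside apply!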

Member:

Is the splat still necessary then?

mcabbott (Member Author):

I didn't think about it. The type stores a Vector{Any}, and it forwards Base.push! and Base.setindex!. I could try removing those and storing a tuple, to see if anything breaks?

Member:

In that case I don't think it's worth the hassle. With any luck we'll be removing it soon.

Comment on lines -102 to +111
- Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
+ Multiple callbacks can be passed to `cb` as array.
mcabbott (Member Author):

There seem to be no tests of this "Multiple optimisers ...opt ... as array". Presumably you should really use Optimiser(o...) for this?
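
i.e. something like this (illustrative; the hyperparameters are arbitrary):

opt = Optimiser(WeightDecay(1e-4), ADAM(1e-3))   # compose rules explicitly
Flux.train!(loss, ps, data, opt)                 # rather than passing [WeightDecay(1e-4), ADAM(1e-3)]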

Member:

Indeed, the backing code has long since disappeared.

mcabbott (Member Author):

Maybe this is ready to go?

mcabbott merged commit ed78e8a into FluxML:master, Mar 20, 2022
mcabbott deleted the train branch, March 20, 2022 18:48