RFC: Restrict `train!` to AbstractOptimiser (#1902)
Conversation
src/optimise/train.jl (outdated diff)
@@ -101,7 +93,7 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop.

 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(loss, ps, data, opt; cb = () -> ())
+function train!(loss, ps, data, opt::AbstractOptimiser; cb = () -> ())
Further restricting to `ps::Params` gives errors:
Training Loop: Error During Test at /Users/me/.julia/dev/Flux/test/optimise.jl:48
Got exception outside of a @test
MethodError: no method matching train!(::var"#57#63", ::Tuple{}, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::Descent)
Closest candidates are:
train!(::Any, ::Params, ::Any, ::Flux.Optimise.AbstractOptimiser; cb) at ~/.julia/dev/Flux/src/optimise/train.jl:105
Stacktrace:
...
DataLoader: Error During Test at /Users/me/.julia/dev/Flux/test/data.jl:3
Got exception outside of a @test
MethodError: no method matching train!(::var"#loss#93", ::Vector{Vector{Float64}}, ::IterTools.NCycle{DataLoader{Matrix{Float64}, Random._GLOBAL_RNG}}, ::Descent)
Closest candidates are:
train!(::Any, ::Params, ::Any, ::Flux.Optimise.AbstractOptimiser; cb) at ~/.julia/dev/Flux/src/optimise/train.jl:105
Stacktrace:
Looking at the blame it's hard to say how intentional the use of `()` in https://github.com/FluxML/Flux.jl/blob/master/test/optimise.jl#L51-L74 was, but I'd be fine with updating the tests to match.
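For illustration, one hypothetical way those test calls could be adapted if a `ps::Params` restriction were added (this edit is not part of the PR; `loss` stands in for the tests' own closures):

```julia
using Flux

loss() = 0f0   # stand-in for the test's loss closure

# before: the test passed a bare tuple as "parameters"
# Flux.train!(loss, (), Iterators.repeated((), 10), Descent())

# after: wrap the (here empty) parameters explicitly, so a ps::Params method applies
Flux.train!(loss, Flux.params(), Iterators.repeated((), 10), Descent())
```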
Ok. Thinking aloud a bit, the idea is that all of these can then co-exist:

- `train!(loss, ps::Params, data, opt::AbstractOptimiser)` as before.
- `train!(loss, ps::Params, data, opt)` would use implicit parameters + Optimisers.jl optimisers. I haven't written it but I think it wouldn't be so hard to store their states in an IdDict. That would let us delete Flux's optimisers completely, and many use cases would not need to change anything. Still calls `gradient(() -> loss(x, y), ps)`. (A rough sketch of this variant follows below.)
- `train!(loss, model, data, opt)` would call `gradient(m -> loss(m, x, y), model)` and be fully structural. Only works with new optimisers of course.

Maybe that's too many steps and clearer to just do the first and last. But the more restrictive we are in the initial 0.13.0, the more room we have to decide.
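A very rough sketch of that middle, implicit-parameters variant, keeping Optimisers.jl state per array in an `IdDict` (purely illustrative; `implicit_update!` and the global `_state` cache are hypothetical names, not anything in Flux or this PR):

```julia
using Optimisers, Zygote

const _state = IdDict{AbstractArray,Any}()   # hypothetical per-array state cache

function implicit_update!(rule, ps::Zygote.Params, gs::Zygote.Grads)
    for p in ps
        gs[p] === nothing && continue
        st = get!(() -> Optimisers.init(rule, p), _state, p)
        st, dp = Optimisers.apply!(rule, st, p, gs[p])  # returns (new state, update)
        _state[p] = st
        p .-= dp                                        # apply the update in place
    end
end

# usage with the old implicit style:
# gs = gradient(() -> loss(x, y), ps)
# implicit_update!(Optimisers.Adam(), ps, gs)
```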
The Optimisers.jl method could look like this:
function train!(loss, model, data, opt; cb = [])
  opt isa AbstractOptimiser && error("old-style AbstractOptimiser can only be used with Params")
  model isa Params && error("implicit Params can only be used with old-style AbstractOptimiser")
  cb = runall(cb)
  state = Optimisers.setup(opt, model)
  @withprogress for (i, d) in enumerate(data)
    try
      grad = gradient(model) do m
        loss(m, batchmemaybe(d)...)
      end
      state, model = update!(state, model, grad...)
      cb()
    catch ex
      # ...
    end
    @logprogress i / length_or_zero(data)
  end
  model
end
The downside of this is that `state` is thrown away, so `@epochs 100 train!(...)` will reset each time. Whereas the present method mutates `opt` and stores it there.
Hmm, then we'd lose momentum info, etc. over epochs. What about the interface being `train!(loss, model, data, opt, state = Optimisers.setup(opt, model))`?
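A minimal sketch of what that signature could mean in practice (illustrative only; returning `(model, state)` so a later call can resume is an assumption, not part of this PR):

```julia
using Optimisers, Zygote

function train!(loss, model, data, opt, state = Optimisers.setup(opt, model))
    for d in data
        grad, = gradient(m -> loss(m, d...), model)
        state, model = Optimisers.update!(state, model, grad)
    end
    return model, state   # hand the state back so a later epoch keeps momentum etc.
end

# first call sets up state; later calls pass it back in:
# model, state = train!(loss, model, data, Optimisers.Adam())
# model, state = train!(loss, model, data, Optimisers.Adam(), state)
```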
What would be an example of an immutable rule? I'm trying to work out why the wrapper would be necessary.
The rules from Optimisers.jl are all immutable.
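For instance (a small sketch assuming a current Optimisers.jl where the rule is spelled `Adam`), a rule struct carries only hyperparameters and is itself immutable; all per-array state lives in a separate tree:

```julia
using Optimisers

rule = Optimisers.Adam(0.001)
ismutable(rule)                           # false: the rule holds only hyperparameters
state = Optimisers.setup(rule, rand(3))   # per-array state lives in this tree instead
```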
The rules are, but the state they're manipulating shouldn't be? Unless you were considering not returning the optimizer state from `train!` and keeping it in the wrapper instead.
Yes, exactly. The `fluxoptwrap` idea is to let you use `train!` mutating both model and optimiser state, with the Optimisers.jl stuff hidden away. Changing existing code to use this would be a much smaller change than using Optimisers directly. I'm not certain that's a good idea, and it isn't in this PR. But it's one path we could take.

Regarding "but the state they're manipulating shouldn't be?": by this I think you mean that the momentum arrays etc. in Optimisers.jl's state tree are in fact mutable. But Optimisers.jl certainly expects that other immutable things can be passed to the next iteration through its state, e.g. ADAM:
https://github.com/FluxML/Optimisers.jl/blob/master/src/rules.jl#L132
So even for a completely mutable model, you really have to pass the state tree around somehow. `train!` could return it, of course, but then it starts to be a very different function, see variants earlier in this thread.
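To make that concrete, a small sketch (assuming a recent Optimisers.jl; variable names are only for illustration): Adam's per-array state holds two mutable moment arrays plus an immutable tuple of running beta powers, and that tuple can only reach the next step by threading the state tree through.

```julia
using Optimisers

x = rand(3)
state = Optimisers.setup(Optimisers.Adam(), x)
# state is a Leaf whose internal state is (mt, vt, βt):
# mt and vt are mutable arrays, but βt is an immutable tuple of β powers,
# updated each step, so the returned state tree must be kept and passed on.
state, x = Optimisers.update!(state, x, ones(3))
```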
Let's not worry about the Optimisers.jl integration for now then.
 Update the array `x` according to `x .-= x̄`.
 """
 function update!(x::AbstractArray, x̄)
   x .-= x̄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As noted by #1860, this function seems never to be used.
Maybe someone else uses it, and it deserves a depwarn, or something? But it's trivial, most people would write their own.
It is documented, so a depwarn unless we're considering this a breaking change for 0.13.
I didn't see that, thanks.
This whole PR is certainly breaking. Do you think it's worth bothering with a deprecation for this?
I think it would be a reasonably small change and consistent with the other 0.13 deprecations we've added.
If you wouldn't mind adding this, everything else LGTM.
Sure, done.
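For reference, a deprecation of this sort might look roughly like the following (a sketch only; the exact message, and what the PR actually committed, are assumptions here):

```julia
function update!(x::AbstractArray, x̄)
  Base.depwarn("`update!(x, x̄)` without an optimiser will be removed; " *
               "write `x .-= x̄` directly, or call `update!(opt, x, x̄)`.", :update!)
  x .-= x̄
end
```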
@@ -521,7 +521,7 @@ opt = AdaBelief()
 opt = AdaBelief(0.001, (0.9, 0.8))
 ```
 """
-mutable struct AdaBelief
+mutable struct AdaBelief <: AbstractOptimiser
I presume this was an oversight; all the others have a supertype.
Force-pushed from 8e2dcab to e38c155.
Codecov Report

@@            Coverage Diff             @@
##           master    #1902      +/-   ##
==========================================
- Coverage   86.64%   86.57%   -0.07%
==========================================
  Files          18       18
  Lines        1445     1445
==========================================
- Hits         1252     1251       -1
- Misses        193      194       +1

Continue to review full report at Codecov.
-Optimiser(o...) = Optimiser(Any[o...])
+Optimiser(opts::AbstractOptimiser...) = Optimiser(Any[opts...])
This chain thing calls `apply!(opt, x, Δ)` on each element, which only accepts the various AbstractOptimiser types. So it should be safe to restrict, and doing so may avoid some mistakes with `Optimisers.OptimiserChain`.
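For comparison, a hedged sketch of the two chaining styles side by side (assuming Flux 0.13-era names; `Flux.Optimise.Optimiser` was later renamed, so treat the exact spellings as illustrative):

```julia
using Flux, Optimisers

# Old-style Flux chain: every element must now be an AbstractOptimiser.
opt_old = Flux.Optimise.Optimiser(ClipValue(1f-3), Descent(0.01))

# Optimisers.jl equivalent: an immutable chain of rules, used via setup/update!.
opt_new = Optimisers.OptimiserChain(Optimisers.ClipGrad(1f-3), Optimisers.Descent(0.01))
```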
Is the splat still necessary then?
Didn't think about it. The type specifies `Vector{Any}`, and it forwards `Base.push!` and `Base.setindex!`. I could try removing these to store a tuple, and see if anything breaks?
In that case I don't think it's worth the hassle. With any luck we'll be removing it soon.
-Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
+Multiple callbacks can be passed to `cb` as array.
There seem to be no tests of this "multiple optimisers ... passed to `opt` ... as arrays" claim. Presumably you should really use `Optimiser(o...)` for this?
Indeed, the backing code has long since disappeared.
Maybe this is ready to go?
Original PR description:

Flux's optimisers all (except 1 mistake?) have a supertype. So I think we can restrict `train!` and friends to demand this.

The reason to do so is that I think it makes it possible to add a `train!` method which accepts the ones from Optimisers.jl (which have no supertype), later. And have them coexist within one version of Flux.

I guess this counts as breaking, but are there many other optimisers without the supertype in the wild? Will it break anything else?