Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add iterator interface #745

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions src/multivariate/optimize/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,85 +53,91 @@ promote_objtype(method::ZerothOrderOptimizer, x, autodiff::Symbol, inplace::Bool
promote_objtype(method::FirstOrderOptimizer, x, autodiff::Symbol, inplace::Bool, td::TwiceDifferentiable) = td
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, td::TwiceDifferentiable) = td

for optimize in [:optimize, :optimizing]
@eval begin

# if no method or options are present
function optimize(f, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
function $optimize(f, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
method = fallback_method(f)
checked_kwargs, method = check_kwargs(kwargs, method)
d = promote_objtype(method, initial_x, autodiff, inplace, f)
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end
function optimize(f, g, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
function $optimize(f, g, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)

method = fallback_method(f, g)
checked_kwargs, method = check_kwargs(kwargs, method)
d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end
function optimize(f, g, h, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
function $optimize(f, g, h, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)

method = fallback_method(f, g, h)
checked_kwargs, method = check_kwargs(kwargs, method)
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end

# no method supplied with objective
function optimize(d::T, initial_x::AbstractArray, options::Options) where T<:AbstractObjective
optimize(d, initial_x, fallback_method(d), options)
function $optimize(d::T, initial_x::AbstractArray, options::Options) where T<:AbstractObjective
$optimize(d, initial_x, fallback_method(d), options)
end
# no method supplied with inplace and autodiff keywords because objective is not supplied
function optimize(f, initial_x::AbstractArray, options::Options; inplace = true, autodiff = :finite)
function $optimize(f, initial_x::AbstractArray, options::Options; inplace = true, autodiff = :finite)
method = fallback_method(f)
d = promote_objtype(method, initial_x, autodiff, inplace, f)
optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end
function optimize(f, g, initial_x::AbstractArray, options::Options; inplace = true, autodiff = :finite)
function $optimize(f, g, initial_x::AbstractArray, options::Options; inplace = true, autodiff = :finite)

method = fallback_method(f, g)
d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end
function optimize(f, g, h, initial_x::AbstractArray{T}, options::Options; inplace = true, autodiff = :finite) where {T}
function $optimize(f, g, h, initial_x::AbstractArray{T}, options::Options; inplace = true, autodiff = :finite) where {T}

method = fallback_method(f, g, h)
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)

optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end

# potentially everything is supplied (besides caches)
function optimize(f, initial_x::AbstractArray, method::AbstractOptimizer,
function $optimize(f, initial_x::AbstractArray, method::AbstractOptimizer,
options::Options = Options(;default_options(method)...); inplace = true, autodiff = :finite)

d = promote_objtype(method, initial_x, autodiff, inplace, f)
optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end
function optimize(f, g, initial_x::AbstractArray, method::AbstractOptimizer,
function $optimize(f, g, initial_x::AbstractArray, method::AbstractOptimizer,
options::Options = Options(;default_options(method)...); inplace = true, autodiff = :finite)

d = promote_objtype(method, initial_x, autodiff, inplace, f, g)

optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end
function optimize(f, g, h, initial_x::AbstractArray{T}, method::AbstractOptimizer,
function $optimize(f, g, h, initial_x::AbstractArray{T}, method::AbstractOptimizer,
options::Options = Options(;default_options(method)...); inplace = true, autodiff = :finite) where T

d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)

optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end

function optimize(d::D, initial_x::AbstractArray, method::SecondOrderOptimizer,
function $optimize(d::D, initial_x::AbstractArray, method::SecondOrderOptimizer,
options::Options = Options(;default_options(method)...); autodiff = :finite, inplace = true) where {D <: Union{NonDifferentiable, OnceDifferentiable}}
d = promote_objtype(method, initial_x, autodiff, inplace, d)
optimize(d, initial_x, method, options)
$optimize(d, initial_x, method, options)
end

end # eval
end # for
131 changes: 109 additions & 22 deletions src/multivariate/optimize/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,73 @@ function initial_convergence(d, state, method::AbstractOptimizer, initial_x, opt
end
initial_convergence(d, state, method::ZerothOrderOptimizer, initial_x, options) = false

function optimize(d::D, initial_x::Tx, method::M,
options::Options{T, TCallback} = Options(;default_options(method)...),
state = initial_state(method, options, d, initial_x)) where {D<:AbstractObjective, M<:AbstractOptimizer, Tx <: AbstractArray, T, TCallback}
if length(initial_x) == 1 && typeof(method) <: NelderMead
error("You cannot use NelderMead for univariate problems. Alternatively, use either interval bound univariate optimization, or another method such as BFGS or Newton.")
end
struct OptimIterator{D <: AbstractObjective, M <: AbstractOptimizer, Tx <: AbstractArray, O <: Options, S}
d::D
initial_x::Tx
method::M
options::O
state::S
end

Base.IteratorSize(::Type{<:OptimIterator}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:OptimIterator}) = Base.HasEltype()
Base.eltype(::Type{<:OptimIterator}) = IteratorState

@with_kw struct IteratorState{IT <: OptimIterator, TR <: OptimizationTrace}
# Put `OptimIterator` in iterator state so that `OptimizationResults` can
# be constructed from `IteratorState`.
iter::IT
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can drop iter field from IteratorState if we change the API to:

let istate
    iter = optimizing(args...; kwargs...)
    for istate′ in iter
        istate = istate′
    end
    OptimizationResults(iter, istate)  # need to pass `iter` here
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accessor functions like iteration_limit_reached(istate) and (f|g|h)_calls(istate) need .iter, too.


t0::Float64
tr::TR
tracing::Bool
stopped::Bool
stopped_by_callback::Bool
stopped_by_time_limit::Bool
f_limit_reached::Bool
g_limit_reached::Bool
h_limit_reached::Bool
x_converged::Bool
f_converged::Bool
f_increased::Bool
counter_f_tol::Int
g_converged::Bool
converged::Bool
iteration::Int
ls_success::Bool
end

function Base.iterate(iter::OptimIterator, istate = nothing)
@unpack d, initial_x, method, options, state = iter
if istate === nothing
t0 = time() # Initial time stamp used to control early stopping by options.time_limit

tr = OptimizationTrace{typeof(value(d)), typeof(method)}()
tracing = options.store_trace || options.show_trace || options.extended_trace || options.callback != nothing
stopped, stopped_by_callback, stopped_by_time_limit = false, false, false
f_limit_reached, g_limit_reached, h_limit_reached = false, false, false
x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0

t0 = time() # Initial time stamp used to control early stopping by options.time_limit
g_converged = initial_convergence(d, state, method, initial_x, options)
converged = g_converged

tr = OptimizationTrace{typeof(value(d)), typeof(method)}()
tracing = options.store_trace || options.show_trace || options.extended_trace || options.callback != nothing
stopped, stopped_by_callback, stopped_by_time_limit = false, false, false
f_limit_reached, g_limit_reached, h_limit_reached = false, false, false
x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0
# prepare iteration counter (used to make "initial state" trace entry)
iteration = 0

g_converged = initial_convergence(d, state, method, initial_x, options)
converged = g_converged
options.show_trace && print_header(method)
trace!(tr, d, state, iteration, method, options, time()-t0)
ls_success::Bool = true
else
@unpack_IteratorState istate

# prepare iteration counter (used to make "initial state" trace entry)
iteration = 0
!converged && !stopped && iteration < options.iterations || return nothing

options.show_trace && print_header(method)
trace!(tr, d, state, iteration, method, options, time()-t0)
ls_success::Bool = true
while !converged && !stopped && iteration < options.iterations
iteration += 1

ls_failed = update_state!(d, state, method)
if !ls_success
break # it returns true if it's forced by something in update! to stop (eg dx_dg == 0.0 in BFGS, or linesearch errors)
# it returns true if it's forced by something in update! to stop (eg dx_dg == 0.0 in BFGS, or linesearch errors)
return nothing
end
update_g!(d, state, method) # TODO: Should this be `update_fg!`?

Expand Down Expand Up @@ -85,7 +122,35 @@ function optimize(d::D, initial_x::Tx, method::M,
stopped_by_time_limit || f_limit_reached || g_limit_reached || h_limit_reached
stopped = true
end
end # while
end

new_istate = IteratorState(
iter,
t0,
tr,
tracing,
stopped,
stopped_by_callback,
stopped_by_time_limit,
f_limit_reached,
g_limit_reached,
h_limit_reached,
x_converged,
f_converged,
f_increased,
counter_f_tol,
g_converged,
converged,
iteration,
ls_success,
)

return new_istate, new_istate
end

function OptimizationResults(istate::IteratorState)
@unpack_IteratorState istate
@unpack d, initial_x, method, options, state = iter

after_while!(d, state, method, options)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing a mutating function `after_while!` inside the non-bang function `OptimizationResults` is not ideal. But it looks like `after_while!` is a no-op in most cases except for NelderMead, so maybe it's OK? From a quick look, `after_while!` for NelderMead seems to be idempotent (so that, e.g., calling `OptimizationResults(istate)` inside a loop multiple times is OK). If that's the case, `OptimizationResults(istate)` practically has no side effect?

But using a function like result! sounds good to me as well.


Expand All @@ -94,6 +159,9 @@ function optimize(d::D, initial_x::Tx, method::M,
Tf = typeof(value(d))
f_incr_pick = f_increased && !options.allow_f_increases

T = (_tmp(::Options{T}) where T = T)(options)
Tx = typeof(initial_x)

return MultivariateOptimizationResults{typeof(method),T,Tx,typeof(x_abschange(state)),Tf,typeof(tr), Bool}(method,
initial_x,
pick_best_x(f_incr_pick, state),
Expand All @@ -120,3 +188,22 @@ function optimize(d::D, initial_x::Tx, method::M,
h_calls(d),
!ls_success)
end

function optimizing(d::D, initial_x::Tx, method::M,
options::Options = Options(;default_options(method)...),
state = initial_state(method, options, d, initial_x)) where {D<:AbstractObjective, M<:AbstractOptimizer, Tx <: AbstractArray}
if length(initial_x) == 1 && typeof(method) <: NelderMead
error("You cannot use NelderMead for univariate problems. Alternatively, use either interval bound univariate optimization, or another method such as BFGS or Newton.")
end
return OptimIterator(d, initial_x, method, options, state)
end

function optimize(d::D, initial_x::Tx, method::M,
options::Options = Options(;default_options(method)...),
state = initial_state(method, options, d, initial_x)) where {D<:AbstractObjective, M<:AbstractOptimizer, Tx <: AbstractArray}
local istate
for istate′ in optimizing(d, initial_x, method, options, state)
istate = istate′
end
return OptimizationResults(istate)
end
6 changes: 6 additions & 0 deletions test/general/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@
res_extended_nm = Optim.optimize(f, g!, initial_x, NelderMead(), options_extended_nm)
@test haskey(Optim.trace(res_extended_nm)[1].metadata,"centroid")
@test haskey(Optim.trace(res_extended_nm)[1].metadata,"step_type")

local istate
for istate′ in Optim.optimizing(f, initial_x, BFGS())
istate = istate′
end
@test Optim.OptimizationResults(istate) isa Optim.MultivariateOptimizationResults
end

# Test univariate API
Expand Down