Skip to content

Commit

Permalink
Standardize the Extension Algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 24, 2023
1 parent fa16796 commit 73005ca
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.3.0"
version = "3.3.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/basics/solve.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ solve(prob::SciMLBase.NonlinearProblem, args...; kwargs...)
## Iteration Controls

- `maxiters::Int`: The maximum number of iterations to perform. Defaults to `1000`.
- `abstol::Number`: The absolute tolerance.
- `reltol::Number`: The relative tolerance.
- `abstol::Number`: The absolute tolerance. Defaults to `real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
- `reltol::Number`: The relative tolerance. Defaults to `real(oneunit(T)) * (eps(real(one(T))))^(2 // 5)`.
- `termination_condition`: Termination Condition from DiffEqBase. Defaults to
`AbsSafeBestTerminationMode()` for `NonlinearSolve.jl` and `AbsTerminateMode()` for
`SimpleNonlinearSolve.jl`.
Expand Down
15 changes: 9 additions & 6 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ConcreteStructs: @concrete
import FastLevenbergMarquardt as FastLM
import FiniteDiff, ForwardDiff

function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
@inline function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
if linsolve === :cholesky
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
elseif linsolve === :qr
Expand Down Expand Up @@ -33,14 +33,17 @@ end
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = 1e-8,
reltol = 1e-8, maxiters = 1000, kwargs...)
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing,
reltol = nothing, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
fu = NonlinearSolve.evaluate_f(prob, u)

f! = InplaceFunction{iip}(prob.f)

abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))

if prob.f.jac === nothing
use_forward_diff = if alg.autodiff === nothing
ForwardDiff.can_dual(eltype(u))
Expand Down Expand Up @@ -95,9 +98,9 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
LM = FastLM.LMWorkspace(u, fu, J)

return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
alg.maxfactor, kwargs...))
(; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, alg.factor,
alg.factoraccept, alg.factorreject, alg.minscale, alg.maxscale,
alg.factorupdate, alg.minfactor, alg.maxfactor))
end

function SciMLBase.solve!(cache::FastLevenbergMarquardtJLCache)
Expand Down
2 changes: 1 addition & 1 deletion ext/NonlinearSolveFixedPointAccelerationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL
f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u)
end

tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))

sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm,
ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m,
Expand Down
19 changes: 11 additions & 8 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import LeastSquaresOptim as LSO

function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
@inline function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
ls = linsolve === :qr ? LSO.QR() :
(linsolve === :cholesky ? LSO.Cholesky() :
(linsolve === :lsmr ? LSO.LSMR() : nothing))
Expand Down Expand Up @@ -33,25 +33,28 @@ end
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresOptimJL,
args...; alias_u0 = false, abstol = 1e-8, reltol = 1e-8, verbose = false,
maxiters = 1000, kwargs...)
args...; alias_u0 = false, abstol = nothing, show_trace::Val{ShT} = Val(false),
trace_level = TraceMinimal(), store_trace::Val{StT} = Val(false), maxiters = 1000,
reltol = nothing, kwargs...) where {ShT, StT}
iip = SciMLBase.isinplace(prob)
u = alias_u0 ? prob.u0 : deepcopy(prob.u0)
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)

abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))

f! = FunctionWrapper{iip}(prob.f, prob.p)
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(u, prob.p) : zeros(u)) :
prob.f.resid_prototype
(!iip ? prob.f(u, prob.p) : zeros(u)) : prob.f.resid_prototype

lsoprob = LSO.LeastSquaresProblem(; x = u, f!, y = resid_prototype, g!,
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))

return LeastSquaresOptimJLCache(prob, alg, allocated_prob,
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
kwargs...))
(; x_tol = reltol, f_tol = abstol, g_tol = abstol, iterations = maxiters,
show_trace = ShT, store_trace = StT, show_every = trace_level.print_frequency))
end

function SciMLBase.solve!(cache::LeastSquaresOptimJLCache)
Expand Down
16 changes: 8 additions & 8 deletions ext/NonlinearSolveMINPACKExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ using MINPACK

function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...;
abstol = nothing, maxiters = 100000, alias_u0::Bool = false,
termination_condition = nothing, kwargs...) where {uType, iip}
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
show_trace::Val{ShT} = Val(false), store_trace::Val{StT} = Val(false),
termination_condition = nothing, kwargs...) where {uType, iip, ShT, StT}
@assert (termination_condition ===
nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!"

Expand All @@ -16,13 +17,12 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
end

T = eltype(u0)
sizeu = size(prob.u0)
p = prob.p

# unwrapping alg params
show_trace = alg.show_trace
tracing = alg.tracing
show_trace = alg.show_trace || ShT
tracing = alg.tracing || StT

if !iip && prob.u0 isa Number
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
Expand All @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
method = ifelse(alg.method === :auto,
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method)

abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))

if SciMLBase.has_jac(prob.f)
if !iip && prob.u0 isa Number
Expand All @@ -62,10 +62,10 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
end
end
original = MINPACK.fsolve(f!, g!, vec(u0), m; tol = abstol, show_trace, tracing,
method, iterations = maxiters, kwargs...)
method, iterations = maxiters)
else
original = MINPACK.fsolve(f!, vec(u0), m; tol = abstol, show_trace, tracing,
method, iterations = maxiters, kwargs...)
method, iterations = maxiters)
end

u = reshape(original.x, size(u))
Expand Down
4 changes: 1 addition & 3 deletions ext/NonlinearSolveNLsolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
end

T = eltype(u0)
iip = isinplace(prob)

sizeu = size(prob.u0)
p = prob.p

Expand Down Expand Up @@ -70,7 +68,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
end

abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))

original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method,
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,
Expand Down
3 changes: 1 addition & 2 deletions ext/NonlinearSolveSpeedMappingExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
end

T = eltype(u0)
iip = isinplace(prob)
p = prob.p

Expand All @@ -27,7 +26,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
m! = (du, u) -> (prob.f(du, u, p); du .+= u)
end

tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))

sol = speedmapping(u0; m!, tol, Lp = Inf, maps_limit = maxiters, alg.orders,
alg.check_obj, store_info, alg.σ_min, alg.stabilize)
Expand Down
67 changes: 53 additions & 14 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,10 @@ function FastLevenbergMarquardtJL(linsolve::Symbol = :cholesky; factor = 1e-6,
end

"""
CMINPACK(; show_trace::Bool=false, tracing::Bool=false, method::Symbol=:auto)
CMINPACK(; method::Symbol = :auto)
### Keyword Arguments
- `show_trace`: whether to show the trace.
- `tracing`: who the hell knows what this does. If you find out, please open an issue/PR.
- `method`: the choice of method for the solver.
### Method Choices
Expand Down Expand Up @@ -134,27 +132,42 @@ struct CMINPACK <: AbstractNonlinearSolveAlgorithm
method::Symbol
end

function CMINPACK(; show_trace::Bool = false, tracing::Bool = false, method::Symbol = :auto)
function CMINPACK(; show_trace = missing, tracing = missing, method::Symbol = :auto)
if Base.get_extension(@__MODULE__, :NonlinearSolveMINPACKExt) === nothing
error("CMINPACK requires MINPACK.jl to be loaded")
end

if show_trace !== missing
Base.depwarn("`show_trace` for CMINPACK has been deprecated and will be removed \
in v4. Use the `show_trace` keyword argument via the logging API \
https://docs.sciml.ai/NonlinearSolve/stable/basics/Logging/ \
instead.")
else
show_trace = false
end

if tracing !== missing
Base.depwarn("`tracing` for CMINPACK has been deprecated and will be removed \
in v4. Use the `store_trace` keyword argument via the logging API \
https://docs.sciml.ai/NonlinearSolve/stable/basics/Logging/ \
instead.")
else
tracing = false
end

return CMINPACK(show_trace, tracing, method)
end

"""
NLsolveJL(; method=:trust_region, autodiff=:central, store_trace=false,
extended_trace=false, linesearch=LineSearches.Static(),
linsolve=(x, A, b) -> copyto!(x, A\\b), factor = one(Float64), autoscale=true,
m=10, beta=one(Float64), show_trace=false)
NLsolveJL(; method = :trust_region, autodiff = :central, linesearch = Static(),
linsolve = (x, A, b) -> copyto!(x, A\\b), factor = one(Float64), autoscale = true,
m = 10, beta = one(Float64))
### Keyword Arguments
- `method`: the choice of method for solving the nonlinear system.
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `:central` or
central differencing via FiniteDiff.jl. The other choices are `:forward`
- `show_trace`: should a trace of the optimization algorithm's state be shown on `STDOUT`?
- `extended_trace`: should additional algorithm internals be added to the state trace?
- `linesearch`: the line search method to be used within the solver method. The choices
are line search types from
[LineSearches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl).
Expand All @@ -168,7 +181,6 @@ end
constants are close to 1. If convergence fails, though, you may consider lowering it.
- `beta`: It is also known as DIIS or Pulay mixing, this method is based on the acceleration
of the fixed-point iteration xₙ₊₁ = xₙ + beta*f(xₙ), where by default beta = 1.
- `store_trace``: should a trace of the optimization algorithm's state be stored?
### Submethod Choice
Expand All @@ -195,14 +207,41 @@ Choices for methods in `NLsolveJL`:
show_trace::Bool
end

function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = false,
extended_trace = false, linesearch = LineSearches.Static(),
function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = missing,
extended_trace = missing, linesearch = LineSearches.Static(),
linsolve = (x, A, b) -> copyto!(x, A \ b), factor = 1.0, autoscale = true, m = 10,
beta = one(Float64), show_trace = false)
beta = one(Float64), show_trace = missing)
if Base.get_extension(@__MODULE__, :NonlinearSolveNLsolveExt) === nothing
error("NLsolveJL requires NLsolve.jl to be loaded")
end

if show_trace !== missing
Base.depwarn("`show_trace` for NLsolveJL has been deprecated and will be removed \
in v4. Use the `show_trace` keyword argument via the logging API \
https://docs.sciml.ai/NonlinearSolve/stable/basics/Logging/ \
instead.")
else
show_trace = false
end

if store_trace !== missing
Base.depwarn("`store_trace` for NLsolveJL has been deprecated and will be removed \
in v4. Use the `store_trace` keyword argument via the logging API \
https://docs.sciml.ai/NonlinearSolve/stable/basics/Logging/ \
instead.")
else
store_trace = false
end

if extended_trace !== missing
Base.depwarn("`extended_trace` for NLsolveJL has been deprecated and will be \
removed in v4. Use the `trace_level = TraceAll()` keyword argument \
via the logging API \
https://docs.sciml.ai/NonlinearSolve/stable/basics/Logging/ instead.")
else
extended_trace = false
end

return NLsolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve,
factor, autoscale, m, beta, show_trace)
end
Expand Down
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
const DEFAULT_NORM = DiffEqBase.NONLINEARSOLVE_DEFAULT_NORM

@inline DEFAULT_TOLERANCE(args...) = DiffEqBase._get_tolerance(args...)

@concrete mutable struct FakeLinearSolveJLCache
A
b
Expand Down

0 comments on commit 73005ca

Please sign in to comment.