Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start using termination conditions from DiffEqBase #208

Merged
merged 14 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "DiffEqBase"]
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD, TC} <: AbstractNonlinearSolveAlgorithm end

abstract type AbstractNonlinearSolveCache{iip} end

Expand Down
62 changes: 49 additions & 13 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,34 @@ for large-scale and numerically-difficult nonlinear least squares problems.
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
construction. This will be fixed in the near future.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
@concrete struct GaussNewton{CJ, AD, TC} <: AbstractNewtonAlgorithm{CJ, AD, TC}
ad::AD
linsolve
precs
termination_condition::TC
end

function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
end

function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
precs = DEFAULT_PRECS,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm;
abstol = nothing,
reltol = nothing), adkwargs...)
utkarsh530 marked this conversation as resolved.
Show resolved Hide resolved
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
return GaussNewton{_unwrap_val(concrete_jac)}(ad,
linsolve,
precs,
termination_condition)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
u_prev
fu1
fu2
fu_new
Expand All @@ -72,12 +80,15 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
prob
stats::NLStats
tc_storage
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
Expand All @@ -91,27 +102,46 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_with_JᵀJ = Val(true))

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
prob, NLStats(1, 0, 0, 0, 0))
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

atol = _get_tolerance(abstol, tc.abstol, eltype(u))
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol,
prob, NLStats(1, 0, 0, 0, 0), storage)
end

function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
jacobian!!(J, cache)
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
linu = _vec(du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(termination_condition(cache.fu_new .- cache.fu1,
cache.u,
u_prev,
cache.abstol,
cache.reltol) ||
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) &&
(cache.force_stop = true)

@. u_prev = u
cache.fu1 .= cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
Expand All @@ -121,7 +151,10 @@ function perform_step!(cache::GaussNewtonCache{true})
end

function perform_step!(cache::GaussNewtonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

cache.J = jacobian!!(cache.J, cache)

Expand All @@ -138,7 +171,10 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

(cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true)
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

cache.u_prev = @. cache.u
cache.fu1 = cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
Expand Down
51 changes: 41 additions & 10 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ numerically-difficult nonlinear systems.
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
"""
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
@concrete struct LevenbergMarquardt{CJ, AD, T, TC <: NLSolveTerminationCondition} <:
AbstractNewtonAlgorithm{CJ, AD, TC}
ad::AD
linsolve
precs
Expand All @@ -85,6 +86,7 @@ numerically-difficult nonlinear systems.
α_geodesic::T
b_uphill::T
min_damping_D::T
termination_condition::TC
end

function set_ad(alg::LevenbergMarquardt{CJ}, ad) where {CJ}
Expand All @@ -97,17 +99,22 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0,
damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm;
abstol = nothing,
reltol = nothing),
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
damping_initial, damping_increase_factor, damping_decrease_factor,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D,
termination_condition)
end

@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
u_prev
fu1
fu2
du
Expand All @@ -121,6 +128,7 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
prob
DᵀD
JᵀJ
Expand All @@ -145,11 +153,13 @@ end
Jv
mat_tmp
stats::NLStats
tc_storage
end

function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
Expand Down Expand Up @@ -184,15 +194,29 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
fu_tmp = zero(fu1)
mat_tmp = zero(JᵀJ)

return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

atol = _get_tolerance(abstol, tc.abstol, eltype(u))
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
utkarsh530 marked this conversation as resolved.
Show resolved Hide resolved
nothing

return LevenbergMarquardtCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve,
J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, DᵀD,
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), storage)
end

function perform_step!(cache::LevenbergMarquardtCache{true})
@unpack fu1, f, make_new_J = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

if iszero(fu1)
cache.force_stop = true
return nothing
Expand All @@ -205,7 +229,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache

# Usual Levenberg-Marquardt step ("velocity").
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
Expand Down Expand Up @@ -258,13 +282,18 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
cache.make_new_J = true
end
end
@. u_prev = u
cache.λ *= cache.λ_factor
cache.λ_factor = cache.damping_increase_factor
return nothing
end

function perform_step!(cache::LevenbergMarquardtCache{false})
@unpack fu1, f, make_new_J = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

if iszero(fu1)
cache.force_stop = true
return nothing
Expand All @@ -281,7 +310,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache

@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache

cache.mat_tmp = JᵀJ + λ * DᵀD
# Usual Levenberg-Marquardt step ("velocity").
Expand Down Expand Up @@ -322,7 +352,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
if (1 - β)^b_uphill * loss ≤ loss_old
# Accept step.
cache.u += δ
if loss < cache.abstol
if termination_condition(fu_new, cache.u, u_prev, cache.abstol, cache.reltol)
cache.force_stop = true
return nothing
end
Expand All @@ -334,6 +364,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.make_new_J = true
end
end
cache.u_prev = @. cache.u
cache.λ *= cache.λ_factor
cache.λ_factor = cache.damping_increase_factor
return nothing
Expand Down
Loading
Loading