Skip to content

Commit

Permalink
Fast General Klement Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 21, 2023
1 parent 1e4cfde commit 63960a8
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 12 deletions.
13 changes: 8 additions & 5 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
# The solver might have set a different `retcode`
if cache.retcode == ReturnCode.Default
if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters

Check warning on line 58 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L58

Added line #L58 was not covered by tests
else
cache.retcode = ReturnCode.Success
end
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache);
Expand Down Expand Up @@ -85,7 +88,7 @@ import PrecompileTools
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
nothing)
PseudoTransient(), GeneralBroyden(), nothing)

for alg in precompile_algs
solve(prob, alg, abstol = T(1e-2))
Expand Down
5 changes: 1 addition & 4 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.dfu = cache.fu2 .- cache.fu
if cache.resets < cache.max_resets &&
(all(x -> abs(x) 1e-12, cache.du) || all(x -> abs(x) 1e-12, cache.dfu))
J⁻¹ = similar(cache.J⁻¹)
fill!(J⁻¹, 0)
J⁻¹[diagind(J⁻¹)] .= T(1)
cache.J⁻¹ = J⁻¹
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
cache.resets += 1

Check warning on line 117 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L116-L117

Added lines #L116 - L117 were not covered by tests
else
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
Expand Down
4 changes: 1 addition & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,8 @@ end
]
else
[
# FIXME: Broyden and Klement are type unstable
# (upstream SimpleNonlinearSolve.jl issue)
!iip ? :(Klement()) : nothing, # Klement not yet implemented for IIP
:(GeneralBroyden()),
:(GeneralKlement()),
:(NewtonRaphson(; linsolve, precs, adkwargs...)),
:(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)),
:(TrustRegion(; linsolve, precs, adkwargs...)),
Expand Down
190 changes: 190 additions & 0 deletions src/klement.jl
Original file line number Diff line number Diff line change
@@ -1 +1,191 @@
@concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing}
max_resets::Int
linsolve
precs
linesearch
singular_tolerance
end

function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,

Check warning on line 9 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L9

Added line #L9 was not covered by tests
linesearch = LineSearch(), precs = DEFAULT_PRECS, singular_tolerance = nothing)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GeneralKlement(max_resets, linsolve, precs, linesearch, singular_tolerance)

Check warning on line 12 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L11-L12

Added lines #L11 - L12 were not covered by tests
end

@concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
fu
fu2
du
p
linsolve
J
J_cache
J_cache2
Jᵀ²du
Jdu
resets
singular_tolerance
force_stop
maxiters::Int
internalnorm
retcode::ReturnCode.T
abstol
prob
stats::NLStats
lscache
end

get_fu(cache::GeneralKlementCache) = cache.fu

Check warning on line 41 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L41

Added line #L41 was not covered by tests

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlement, args...;

Check warning on line 43 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L43

Added line #L43 was not covered by tests
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
fu = evaluate_f(prob, u)
J = __init_identity_jacobian(u, fu)

Check warning on line 49 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L46-L49

Added lines #L46 - L49 were not covered by tests

if u isa Number
linsolve = nothing

Check warning on line 52 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L51-L52

Added lines #L51 - L52 were not covered by tests
else
weight = similar(u)
recursivefill!(weight, true)
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,

Check warning on line 56 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L54-L56

Added lines #L54 - L56 were not covered by tests
nothing)..., weight)
linprob = LinearProblem(J, _vec(fu); u0 = _vec(fu))
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,

Check warning on line 59 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L58-L59

Added lines #L58 - L59 were not covered by tests
linsolve_kwargs...)
end

singular_tolerance = alg.singular_tolerance === nothing ? inv(sqrt(eps(eltype(u)))) :

Check warning on line 63 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L63

Added line #L63 was not covered by tests
eltype(u)(alg.singular_tolerance)

return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,

Check warning on line 66 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L66

Added line #L66 was not covered by tests
J, zero(J), zero(J), zero(fu), zero(fu), 0, singular_tolerance, false,
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
end

function perform_step!(cache::GeneralKlementCache{true})
@unpack u, fu, f, p, alg, J, linsolve, du = cache
T = eltype(J)

Check warning on line 74 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L72-L74

Added lines #L72 - L74 were not covered by tests

# FIXME: How can we do this faster?
if cond(J) > cache.singular_tolerance
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
return nothing

Check warning on line 81 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L77-L81

Added lines #L77 - L81 were not covered by tests
end
fill!(J, zero(T))
J[diagind(J)] .= T(1)
cache.resets += 1

Check warning on line 85 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L83-L85

Added lines #L83 - L85 were not covered by tests
end

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu), linu = _vec(du),

Check warning on line 89 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L89

Added line #L89 was not covered by tests
p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 91 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L91

Added line #L91 was not covered by tests

# Line Search
α = perform_linesearch!(cache.lscache, u, du)
axpy!(α, du, u)
f(cache.fu2, u, p)

Check warning on line 96 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L94-L96

Added lines #L94 - L96 were not covered by tests

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1

Check warning on line 101 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L98-L101

Added lines #L98 - L101 were not covered by tests

cache.force_stop && return nothing

Check warning on line 103 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L103

Added line #L103 was not covered by tests

# Update the Jacobian
cache.J_cache .= cache.J' .^ 2
cache.Jdu .= _vec(du) .^ 2
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
mul!(cache.Jdu, J, _vec(du))
cache.fu .= cache.fu2 .- cache.fu
cache.fu .= (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
mul!(cache.J_cache, _vec(cache.fu), _vec(du)')
cache.J_cache .*= J
mul!(cache.J_cache2, cache.J_cache, J)
J .+= cache.J_cache2

Check warning on line 115 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L106-L115

Added lines #L106 - L115 were not covered by tests

cache.fu .= cache.fu2

Check warning on line 117 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L117

Added line #L117 was not covered by tests

return nothing

Check warning on line 119 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L119

Added line #L119 was not covered by tests
end

function perform_step!(cache::GeneralKlementCache{false})
@unpack fu, f, p, alg, J, linsolve = cache
T = eltype(J)

Check warning on line 124 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L122-L124

Added lines #L122 - L124 were not covered by tests

# FIXME: How can we do this faster?
if cond(J) > cache.singular_tolerance
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
return nothing

Check warning on line 131 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L127-L131

Added lines #L127 - L131 were not covered by tests
end
cache.J = __init_identity_jacobian(u, fu)
cache.resets += 1

Check warning on line 134 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L133-L134

Added lines #L133 - L134 were not covered by tests
end

# u = u - J \ fu
if linsolve === nothing
cache.du = -fu / cache.J

Check warning on line 139 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L138-L139

Added lines #L138 - L139 were not covered by tests
else
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu),

Check warning on line 141 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L141

Added line #L141 was not covered by tests
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 143 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L143

Added line #L143 was not covered by tests
end

# Line Search
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
cache.u = @. cache.u + α * cache.du # `u` might not support mutation
cache.fu2 = f(cache.u, p)

Check warning on line 149 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L147-L149

Added lines #L147 - L149 were not covered by tests

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1

Check warning on line 154 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L151-L154

Added lines #L151 - L154 were not covered by tests

cache.force_stop && return nothing

Check warning on line 156 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L156

Added line #L156 was not covered by tests

# Update the Jacobian
cache.J_cache = cache.J' .^ 2
cache.Jdu = _vec(cache.du) .^ 2
cache.Jᵀ²du = cache.J_cache * cache.Jdu
cache.Jdu = J * _vec(cache.du)
cache.fu = cache.fu2 .- cache.fu
cache.fu = (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J
cache.J = J .+ cache.J_cache

Check warning on line 166 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L159-L166

Added lines #L159 - L166 were not covered by tests

cache.fu = cache.fu2

Check warning on line 168 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L168

Added line #L168 was not covered by tests

return nothing

Check warning on line 170 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L170

Added line #L170 was not covered by tests
end

function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p,

Check warning on line 173 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L173

Added line #L173 was not covered by tests
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)

Check warning on line 178 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L175-L178

Added lines #L175 - L178 were not covered by tests
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)

Check warning on line 182 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L181-L182

Added lines #L181 - L182 were not covered by tests
end
cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache

Check warning on line 190 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L184-L190

Added lines #L184 - L190 were not covered by tests
end
11 changes: 11 additions & 0 deletions test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,14 @@ end

test_on_library(problems, dicts, alg_ops, broken_tests)
end

@testset "GeneralKlement 23 Test Problems" begin
alg_ops = (GeneralKlement(),
GeneralKlement(; linesearch = BackTracking()))

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 2, 3, 4, 5, 6, 7, 8, 18, 22]
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 7, 11, 12, 18, 22]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
89 changes: 89 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -754,3 +754,92 @@ end
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)
end

# --- GeneralKlement tests ---

@testset "GeneralKlement" begin
function benchmark_nlsolve_oop(f, u0, p = 2.0; linesearch = LineSearch())
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, GeneralKlement(; linesearch), abstol = 1e-9)
end

function benchmark_nlsolve_iip(f, u0, p = 2.0; linesearch = LineSearch())
prob = NonlinearProblem{true}(f, u0, p)
return solve(prob, GeneralKlement(; linesearch), abstol = 1e-9)
end

@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (Static(),
StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
ad in (AutoFiniteDiff(), AutoZygote())

linesearch = LineSearch(; method = lsmethod, autodiff = ad)
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)

@testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
sol = benchmark_nlsolve_oop(quadratic_f, u0; linesearch)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0),
GeneralKlement(; linesearch), abstol = 1e-9)
@test (@ballocated solve!($cache)) < 200
end

@testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],)
ad isa AutoZygote && continue
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linesearch)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
GeneralKlement(; linesearch), abstol = 1e-9)
@test (@ballocated solve!($cache)) 64
end
end

@testset "[OOP] [Immutable AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
end
end

@testset "[OOP] [Scalar AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p) 1 / (2 * sqrt(p))
end
end

t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
p) ForwardDiff.jacobian(t, p)

# Iterator interface
function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
cache = init(probN, GeneralKlement(); maxiters = 100, abstol = 1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p)
sol = solve!(cache)
sols[i] = iip ? sol.u[1] : sol.u
end
return sols
end
p = range(0.01, 2, length = 200)
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)
end

0 comments on commit 63960a8

Please sign in to comment.