Skip to content

Commit

Permalink
rebase and add update heuristic.. need to fix hyperparameter
Browse files Browse the repository at this point in the history
  • Loading branch information
frankschae committed Oct 19, 2023
1 parent a00ec0b commit 2c4009a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 13 deletions.
76 changes: 64 additions & 12 deletions src/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
precs = DEFAULT_PRECS, reuse = true, reusetol = 1e-6, adkwargs...)
An advanced NewtonRaphson implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
Expand Down Expand Up @@ -29,29 +29,48 @@ for large-scale and numerically-difficult nonlinear systems.
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
- `reuse`: Determines if the Jacobian is reused between (quasi-)Newton steps. Defaults to
`true`. If `true` we check how far we stepped with the same Jacobian, and automatically
take a new Jacobian if we stepped more than `reusetol` or if convergence slows or starts
to diverge. If `false`, the Jacobian is updated in each step.
"""
@concrete struct NewtonRaphson{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
linesearch
reusetol
reuse::Bool
end

function set_ad(alg::NewtonRaphson{CJ}, ad) where {CJ}
return NewtonRaphson{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
return NewtonRaphson{CJ}(ad,
alg.linsolve,
alg.precs,
alg.linesearch,
alg.reusetol,
alg.reuse)
end

function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS, reuse = true, reusetol = 1e-6,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad,
linsolve,
precs,
linesearch,
reusetol,
reuse)
end

@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
uprev
Δu
fu1
fu2
du
Expand All @@ -76,18 +95,36 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
uprev = deepcopy(u0)
Δu = zero(u0)

fu1 = evaluate_f(prob, u)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)

return NewtonRaphsonCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
return NewtonRaphsonCache{iip}(f, alg, u, uprev, Δu, fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)))
end

function perform_step!(cache::NewtonRaphsonCache{true})
@unpack u, fu1, f, p, alg, J, linsolve, du = cache
jacobian!!(J, cache)
@unpack u, uprev, Δu, fu1, f, p, alg, J, linsolve, du = cache
@unpack reuse = alg

if reuse
# check how far we stepped
@. Δu += u - uprev
update = cache.internalnorm(Δu) > alg.reusetol
if update || cache.stats.njacs == 0
jacobian!!(J, cache)
cache.stats.njacs += 1
Δu .*= false
end
else
jacobian!!(J, cache)
cache.stats.njacs += 1

Check warning on line 125 in src/raphson.jl

View check run for this annotation

Codecov / codecov/patch

src/raphson.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
end
cache.uprev .= u

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
Expand All @@ -101,16 +138,32 @@ function perform_step!(cache::NewtonRaphsonCache{true})

cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function perform_step!(cache::NewtonRaphsonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache
@unpack u, uprev, Δu, fu1, f, p, alg, linsolve = cache
@unpack reuse = alg

if reuse
# check how far we stepped
cache.Δu += u - uprev
update = cache.internalnorm(Δu) > alg.reusetol
if update || cache.stats.njacs == 0
cache.J = jacobian!!(cache.J, cache)
cache.stats.njacs += 1
cache.Δu *= false
end
else
cache.J = jacobian!!(cache.J, cache)

Check warning on line 160 in src/raphson.jl

View check run for this annotation

Codecov / codecov/patch

src/raphson.jl#L160

Added line #L160 was not covered by tests
# cache.Δu *= false
cache.stats.njacs += 1

Check warning on line 162 in src/raphson.jl

View check run for this annotation

Codecov / codecov/patch

src/raphson.jl#L162

Added line #L162 was not covered by tests
end

cache.uprev = u

cache.J = jacobian!!(cache.J, cache)
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
Expand All @@ -127,7 +180,6 @@ function perform_step!(cache::NewtonRaphsonCache{false})

cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
Expand Down
4 changes: 3 additions & 1 deletion test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ end

# NewtonRaphson
@testset "NewtonRaphson test problem library" begin
alg_ops = (NewtonRaphson(),)
alg_ops = (NewtonRaphson(; reuse = false),
NewtonRaphson(; reuse = true, reusetol = 1e-6))

# dictionary with indices of test problems where method does not converge to small residual
broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 6]
broken_tests[alg_ops[2]] = [1, 6]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
Expand Down

0 comments on commit 2c4009a

Please sign in to comment.