Skip to content

Commit

Permalink
avoid re-computing matrix factorization
Browse files Browse the repository at this point in the history
  • Loading branch information
frankschae committed Nov 2, 2023
1 parent 7018113 commit b061a52
Showing 1 changed file with 44 additions and 23 deletions.
67 changes: 44 additions & 23 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ function set_ad(alg::NewtonRaphson{CJ}, ad) where {CJ}
end

function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, reuse = true, reusetol = 1e-6,
adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS, reuse = true, reusetol = 1e-6,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad,
Expand All @@ -70,7 +70,7 @@ end
f
alg
u
uprev
u_prev
Δu
fu1
fu2
Expand Down Expand Up @@ -99,7 +99,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
uprev = deepcopy(u0)
u_prev = deepcopy(u0)
Δu = zero(u0)

fu1 = evaluate_f(prob, u)
Expand All @@ -109,34 +109,40 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu1, u,
termination_condition)

return NewtonRaphsonCache{iip}(f, alg, u, uprev, Δu, fu1, fu2, du, p, uf, linsolve, J,
return NewtonRaphsonCache{iip}(f, alg, u, u_prev, Δu, fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), tc_cache)
end

function perform_step!(cache::NewtonRaphsonCache{true})
@unpack u, uprev, Δu, fu1, f, p, alg, J, linsolve, du = cache
@unpack u, u_prev, Δu, fu1, f, p, alg, J, linsolve, du = cache
@unpack reuse = alg

if reuse
# check how far we stepped
@. Δu += u - uprev
@. Δu += u - u_prev
update = cache.internalnorm(Δu) > alg.reusetol
if update || cache.stats.njacs == 0
jacobian!!(J, cache)
cache.stats.njacs += 1
Δu .*= false
# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
p, reltol = cache.abstol)
else
# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; b = _vec(fu1), linu = _vec(du),
p, reltol = cache.abstol)
end
else
jacobian!!(J, cache)
cache.stats.njacs += 1
# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
p, reltol = cache.abstol)
end
cache.uprev .= u

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
p, reltol = cache.abstol)
cache.linsolve = linres.cache

# Line Search
Expand All @@ -154,33 +160,48 @@ function perform_step!(cache::NewtonRaphsonCache{true})
end

function perform_step!(cache::NewtonRaphsonCache{false})
@unpack u, uprev, Δu, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, Δu, fu1, f, p, alg, linsolve = cache
@unpack reuse = alg

if reuse
# check how far we stepped
cache.Δu += u - uprev
cache.Δu += u - u_prev
update = cache.internalnorm(Δu) > alg.reusetol
if update || cache.stats.njacs == 0
cache.J = jacobian!!(cache.J, cache)
cache.stats.njacs += 1
cache.Δu *= false
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
else
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; b = _vec(fu1),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
end
else
cache.J = jacobian!!(cache.J, cache)
# cache.Δu *= false
cache.stats.njacs += 1
end

cache.uprev = u

# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
end

# Line Search
Expand Down

0 comments on commit b061a52

Please sign in to comment.