From 16652770dce3b9a960c3c3daf60d1c234e63f031 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 15 Sep 2022 07:02:49 -0400 Subject: [PATCH 1/3] use LinearSolve.jl --- Project.toml | 4 -- src/NonlinearSolve.jl | 2 +- src/raphson.jl | 77 ++++++++++++++++++++++++++++++++++----- src/utils.jl | 85 ------------------------------------------- 4 files changed, 69 insertions(+), 99 deletions(-) diff --git a/Project.toml b/Project.toml index 44bde3503..d6b3c7d57 100644 --- a/Project.toml +++ b/Project.toml @@ -7,10 +7,8 @@ version = "0.3.22" ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -21,9 +19,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" ArrayInterfaceCore = "0.1.1" FiniteDiff = "2" ForwardDiff = "0.10.3" -IterativeSolvers = "0.9" RecursiveArrayTools = "2" -RecursiveFactorization = "0.1, 0.2" Reexport = "0.2, 1" SciMLBase = "1.32" Setfield = "0.7, 0.8, 1" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index b31f112bb..327a1eaec 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -16,7 +16,7 @@ import RecursiveFactorization abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end -abstract type AbstractNewtonAlgorithm{CS, AD} <: AbstractNonlinearSolveAlgorithm end +abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <: AbstractNonlinearSolveAlgorithm end abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolveAlgorithm end include("utils.jl") diff --git a/src/raphson.jl b/src/raphson.jl index d4c768f16..dc7b4403f 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -1,12 +1,14 @@ -struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS, AD} - diff_type::DT +struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ} <: AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} linsolve::L + precs::P end -function NewtonRaphson(; autodiff = true, chunk_size = 12, diff_type = Val{:forward}, - linsolve = DEFAULT_LINSOLVE) - NewtonRaphson{chunk_size, autodiff, typeof(diff_type), typeof(linsolve)}(diff_type, - linsolve) +function NewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS) + NewtonRaphson{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type, + typeof(linsolve), typeof(precs), _unwrap_val(standardtag), + _unwrap_val(concrete_jac)}(linsolve, precs) end mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC} @@ -17,10 +19,64 @@ mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC} jac_config::JC end +function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing, + du = nothing, u = nothing, p = nothing, t = nothing, + weight = nothing, solverdata = nothing, + reltol = nothing) where P + A !== nothing && (linsolve = LinearSolve.set_A(linsolve, A)) + b !== nothing && (linsolve = LinearSolve.set_b(linsolve, b)) + linu !== nothing && (linsolve = LinearSolve.set_u(linsolve, linu)) + + Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer : + linsolve.Pl + Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer : + linsolve.Pr + + _Pl, _Pr = precs(linsolve.A, du, u, p, nothing, A !== nothing, Plprev, Prprev, + solverdata) + if (_Pl !== nothing || _Pr !== nothing) + _weight = weight === nothing ? + (linsolve.Pr isa Diagonal ? linsolve.Pr.diag : linsolve.Pr.inner.diag) : + weight + Pl, Pr = wrapprecs(_Pl, _Pr, _weight) + linsolve = LinearSolve.set_prec(linsolve, Pl, Pr) + end + + linres = if reltol === nothing + solve(linsolve; reltol) + else + solve(linsolve; reltol) + end + + return linres +end + +function wrapprecs(_Pl, _Pr, weight) + if _Pl !== nothing + Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(_vec(weight))), + _Pl) + else + Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))) + end + + if _Pr !== nothing + Pr = LinearSolve.ComposePreconditioner(Diagonal(_vec(weight)), _Pr) + else + Pr = Diagonal(_vec(weight)) + end + Pl, Pr +end + function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true}) - uf = JacobianWrapper(f, p) - linsolve = alg.linsolve(Val{:init}, f, u) + uf = JacobianWrapper(f,p) J = false .* u .* u' + + linprob = LinearProblem(W, _vec(zero(u)); u0 = _vec(zero(u))) + Pl, Pr = wrapprecs(alg.precs(W, nothing, u, p, nothing, nothing, nothing, nothing, + nothing)..., weight) + linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, + Pl = Pl, Pr = Pr) + du1 = zero(u) tmp = zero(u) if alg_autodiff(alg) @@ -47,9 +103,12 @@ function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{t @unpack J, linsolve, du1 = cache calc_J!(J, solver, cache) # u = u - J \ fu - linsolve(du1, J, fu, true) + linsolve = dolinsolve(alg.precs, solver.linsolve, A = J, b = fu, u = du1, + p = p, reltol = solver.tol) + @set! cache.linsolve = linsolve @. u = u - du1 f(fu, u, p) + if solver.internalnorm(solver.fu) < solver.tol @set! solver.force_stop = true end diff --git a/src/utils.jl b/src/utils.jl index 29411bfa4..cf61b9afc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -100,91 +100,6 @@ function num_types_in_tuple(sig::UnionAll) length(Base.unwrap_unionall(sig).parameters) end -### Default Linsolve - -# Try to be as smart as possible -# lu! if Matrix -# lu if sparse -# gmres if operator - -mutable struct DefaultLinSolve - A::Any - iterable::Any -end -DefaultLinSolve() = DefaultLinSolve(nothing, nothing) - -function (p::DefaultLinSolve)(x, A, b, update_matrix = false; tol = nothing, kwargs...) - if p.iterable isa Vector && eltype(p.iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector - F = LU{eltype(A)}(A, p.iterable, zero(LinearAlgebra.BlasInt)) - ldiv!(x, F, b) - return nothing - end - if update_matrix - if typeof(A) <: Matrix - blasvendor = BLAS.vendor() - # if the user doesn't use OpenBLAS, we assume that is a better BLAS - # implementation like MKL - # - # RecursiveFactorization seems to be consistantly winning below 100 - # https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213 - if ArrayInterfaceCore.can_setindex(x) && (size(A, 1) <= 100 || - ((blasvendor === :openblas || blasvendor === :openblas64) && - size(A, 1) <= 500)) - p.A = RecursiveFactorization.lu!(A) - else - p.A = lu!(A) - end - elseif typeof(A) <: Tridiagonal - p.A = lu!(A) - elseif typeof(A) <: Union{SymTridiagonal} - p.A = ldlt!(A) - elseif typeof(A) <: Union{Symmetric, Hermitian} - p.A = bunchkaufman!(A) - elseif typeof(A) <: SparseMatrixCSC - p.A = lu(A) - elseif ArrayInterfaceCore.isstructured(A) - p.A = factorize(A) - elseif !(typeof(A) <: AbstractDiffEqOperator) - # Most likely QR is the one that is overloaded - # Works on things like CuArrays - p.A = qr(A) - end - end - - if typeof(A) <: Union{Matrix, SymTridiagonal, Tridiagonal, Symmetric, Hermitian} # No 2-arg form for SparseArrays! - x .= b - ldiv!(p.A, x) - # Missing a little bit of efficiency in a rare case - #elseif typeof(A) <: DiffEqArrayOperator - # ldiv!(x,p.A,b) - elseif ArrayInterfaceCore.isstructured(A) || A isa SparseMatrixCSC - ldiv!(x, p.A, b) - elseif typeof(A) <: AbstractDiffEqOperator - # No good starting guess, so guess zero - if p.iterable === nothing - p.iterable = IterativeSolvers.gmres_iterable!(x, A, b; initially_zero = true, - restart = 5, maxiter = 5, - tol = 1e-16, kwargs...) - p.iterable.reltol = tol - end - x .= false - iter = p.iterable - purge_history!(iter, x, b) - - for residual in iter - end - else - ldiv!(x, p.A, b) - end - return nothing -end - -function (p::DefaultLinSolve)(::Type{Val{:init}}, f, u0_prototype) - DefaultLinSolve() -end - -const DEFAULT_LINSOLVE = DefaultLinSolve() - @inline UNITLESS_ABS2(x) = real(abs2(x)) @inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u) @inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}} From 4f933bc0d0403e71bf47126350a3110ffa0b455c Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 15 Sep 2022 07:18:29 -0400 Subject: [PATCH 2/3] remove linear solver packages --- src/NonlinearSolve.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 327a1eaec..c026a1ba2 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -9,8 +9,6 @@ using StaticArrays using RecursiveArrayTools using LinearAlgebra import ArrayInterfaceCore -import IterativeSolvers -import RecursiveFactorization @reexport using SciMLBase From fc05995fb1afcf729ab0ab21b025dfb73dc32b16 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 15 Sep 2022 07:26:14 -0400 Subject: [PATCH 3/3] format --- src/NonlinearSolve.jl | 3 ++- src/raphson.jl | 23 ++++++++++++----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index c026a1ba2..a1fbd9b4e 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -14,7 +14,8 @@ import ArrayInterfaceCore abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end -abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <: AbstractNonlinearSolveAlgorithm end +abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <: + AbstractNonlinearSolveAlgorithm end abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolveAlgorithm end include("utils.jl") diff --git a/src/raphson.jl b/src/raphson.jl index dc7b4403f..58e71add1 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -1,14 +1,15 @@ -struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ} <: AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} +struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ} <: + AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} linsolve::L precs::P end function NewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(), - standardtag = Val{true}(), concrete_jac = nothing, - diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS) + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS) NewtonRaphson{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type, - typeof(linsolve), typeof(precs), _unwrap_val(standardtag), - _unwrap_val(concrete_jac)}(linsolve, precs) + typeof(linsolve), typeof(precs), _unwrap_val(standardtag), + _unwrap_val(concrete_jac)}(linsolve, precs) end mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC} @@ -22,7 +23,7 @@ end function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing, du = nothing, u = nothing, p = nothing, t = nothing, weight = nothing, solverdata = nothing, - reltol = nothing) where P + reltol = nothing) where {P} A !== nothing && (linsolve = LinearSolve.set_A(linsolve, A)) b !== nothing && (linsolve = LinearSolve.set_b(linsolve, b)) linu !== nothing && (linsolve = LinearSolve.set_u(linsolve, linu)) @@ -33,7 +34,7 @@ function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing linsolve.Pr _Pl, _Pr = precs(linsolve.A, du, u, p, nothing, A !== nothing, Plprev, Prprev, - solverdata) + solverdata) if (_Pl !== nothing || _Pr !== nothing) _weight = weight === nothing ? (linsolve.Pr isa Diagonal ? linsolve.Pr.diag : linsolve.Pr.inner.diag) : @@ -68,7 +69,7 @@ function wrapprecs(_Pl, _Pr, weight) end function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true}) - uf = JacobianWrapper(f,p) + uf = JacobianWrapper(f, p) J = false .* u .* u' linprob = LinearProblem(W, _vec(zero(u)); u0 = _vec(zero(u))) @@ -103,12 +104,12 @@ function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{t @unpack J, linsolve, du1 = cache calc_J!(J, solver, cache) # u = u - J \ fu - linsolve = dolinsolve(alg.precs, solver.linsolve, A = J, b = fu, u = du1, + linsolve = dolinsolve(alg.precs, solver.linsolve, A = J, b = fu, u = du1, p = p, reltol = solver.tol) - @set! cache.linsolve = linsolve + @set! cache.linsolve = linsolve @. u = u - du1 f(fu, u, p) - + if solver.internalnorm(solver.fu) < solver.tol @set! solver.force_stop = true end