Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use LinearSolve.jl #86

Merged
merged 3 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ version = "0.3.22"
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand All @@ -21,9 +19,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
ArrayInterfaceCore = "0.1.1"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
IterativeSolvers = "0.9"
RecursiveArrayTools = "2"
RecursiveFactorization = "0.1, 0.2"
Reexport = "0.2, 1"
SciMLBase = "1.32"
Setfield = "0.7, 0.8, 1"
Expand Down
5 changes: 2 additions & 3 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ using StaticArrays
using RecursiveArrayTools
using LinearAlgebra
import ArrayInterfaceCore
import IterativeSolvers
import RecursiveFactorization

@reexport using SciMLBase

abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD} <: AbstractNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
AbstractNonlinearSolveAlgorithm end
abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolveAlgorithm end

include("utils.jl")
Expand Down
76 changes: 68 additions & 8 deletions src/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS, AD}
diff_type::DT
struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ} <:
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::L
precs::P
end

function NewtonRaphson(; autodiff = true, chunk_size = 12, diff_type = Val{:forward},
linsolve = DEFAULT_LINSOLVE)
NewtonRaphson{chunk_size, autodiff, typeof(diff_type), typeof(linsolve)}(diff_type,
linsolve)
function NewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS)
NewtonRaphson{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
_unwrap_val(concrete_jac)}(linsolve, precs)
end

mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
Expand All @@ -17,10 +20,64 @@ mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
jac_config::JC
end

function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
du = nothing, u = nothing, p = nothing, t = nothing,
weight = nothing, solverdata = nothing,
reltol = nothing) where {P}
A !== nothing && (linsolve = LinearSolve.set_A(linsolve, A))
b !== nothing && (linsolve = LinearSolve.set_b(linsolve, b))
linu !== nothing && (linsolve = LinearSolve.set_u(linsolve, linu))

Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
linsolve.Pl
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
linsolve.Pr

_Pl, _Pr = precs(linsolve.A, du, u, p, nothing, A !== nothing, Plprev, Prprev,
solverdata)
if (_Pl !== nothing || _Pr !== nothing)
_weight = weight === nothing ?
(linsolve.Pr isa Diagonal ? linsolve.Pr.diag : linsolve.Pr.inner.diag) :
weight
Pl, Pr = wrapprecs(_Pl, _Pr, _weight)
linsolve = LinearSolve.set_prec(linsolve, Pl, Pr)
end

linres = if reltol === nothing
solve(linsolve; reltol)
else
solve(linsolve; reltol)
end

return linres
end

function wrapprecs(_Pl, _Pr, weight)
if _Pl !== nothing
Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
_Pl)
else
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
end

if _Pr !== nothing
Pr = LinearSolve.ComposePreconditioner(Diagonal(_vec(weight)), _Pr)
else
Pr = Diagonal(_vec(weight))
end
Pl, Pr
end

function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
uf = JacobianWrapper(f, p)
linsolve = alg.linsolve(Val{:init}, f, u)
J = false .* u .* u'

linprob = LinearProblem(W, _vec(zero(u)); u0 = _vec(zero(u)))
Pl, Pr = wrapprecs(alg.precs(W, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr)

du1 = zero(u)
tmp = zero(u)
if alg_autodiff(alg)
Expand All @@ -47,9 +104,12 @@ function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{t
@unpack J, linsolve, du1 = cache
calc_J!(J, solver, cache)
# u = u - J \ fu
linsolve(du1, J, fu, true)
linsolve = dolinsolve(alg.precs, solver.linsolve, A = J, b = fu, u = du1,
p = p, reltol = solver.tol)
@set! cache.linsolve = linsolve
@. u = u - du1
f(fu, u, p)

if solver.internalnorm(solver.fu) < solver.tol
@set! solver.force_stop = true
end
Expand Down
85 changes: 0 additions & 85 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,91 +100,6 @@ function num_types_in_tuple(sig::UnionAll)
length(Base.unwrap_unionall(sig).parameters)
end

### Default Linsolve

# Try to be as smart as possible
# lu! if Matrix
# lu if sparse
# gmres if operator

mutable struct DefaultLinSolve
A::Any
iterable::Any
end
DefaultLinSolve() = DefaultLinSolve(nothing, nothing)

function (p::DefaultLinSolve)(x, A, b, update_matrix = false; tol = nothing, kwargs...)
if p.iterable isa Vector && eltype(p.iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector
F = LU{eltype(A)}(A, p.iterable, zero(LinearAlgebra.BlasInt))
ldiv!(x, F, b)
return nothing
end
if update_matrix
if typeof(A) <: Matrix
blasvendor = BLAS.vendor()
# if the user doesn't use OpenBLAS, we assume that is a better BLAS
# implementation like MKL
#
# RecursiveFactorization seems to be consistantly winning below 100
# https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213
if ArrayInterfaceCore.can_setindex(x) && (size(A, 1) <= 100 ||
((blasvendor === :openblas || blasvendor === :openblas64) &&
size(A, 1) <= 500))
p.A = RecursiveFactorization.lu!(A)
else
p.A = lu!(A)
end
elseif typeof(A) <: Tridiagonal
p.A = lu!(A)
elseif typeof(A) <: Union{SymTridiagonal}
p.A = ldlt!(A)
elseif typeof(A) <: Union{Symmetric, Hermitian}
p.A = bunchkaufman!(A)
elseif typeof(A) <: SparseMatrixCSC
p.A = lu(A)
elseif ArrayInterfaceCore.isstructured(A)
p.A = factorize(A)
elseif !(typeof(A) <: AbstractDiffEqOperator)
# Most likely QR is the one that is overloaded
# Works on things like CuArrays
p.A = qr(A)
end
end

if typeof(A) <: Union{Matrix, SymTridiagonal, Tridiagonal, Symmetric, Hermitian} # No 2-arg form for SparseArrays!
x .= b
ldiv!(p.A, x)
# Missing a little bit of efficiency in a rare case
#elseif typeof(A) <: DiffEqArrayOperator
# ldiv!(x,p.A,b)
elseif ArrayInterfaceCore.isstructured(A) || A isa SparseMatrixCSC
ldiv!(x, p.A, b)
elseif typeof(A) <: AbstractDiffEqOperator
# No good starting guess, so guess zero
if p.iterable === nothing
p.iterable = IterativeSolvers.gmres_iterable!(x, A, b; initially_zero = true,
restart = 5, maxiter = 5,
tol = 1e-16, kwargs...)
p.iterable.reltol = tol
end
x .= false
iter = p.iterable
purge_history!(iter, x, b)

for residual in iter
end
else
ldiv!(x, p.A, b)
end
return nothing
end

function (p::DefaultLinSolve)(::Type{Val{:init}}, f, u0_prototype)
DefaultLinSolve()
end

const DEFAULT_LINSOLVE = DefaultLinSolve()

@inline UNITLESS_ABS2(x) = real(abs2(x))
@inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u)
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}
Expand Down