Skip to content

Commit

Permalink
Standardize parts of SIAM FANL Equations
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 24, 2023
1 parent 73005ca commit f2edda0
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 87 deletions.
3 changes: 2 additions & 1 deletion docs/src/api/siamfanlequations.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SIAMFANLEquations.jl

This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
This is an extension for importing solvers from
[SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

Expand Down
7 changes: 7 additions & 0 deletions docs/src/solvers/NonlinearSystemSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,10 @@ Newton-Krylov form. However, KINSOL is known to be less stable than some other
implementations, as it has no line search or globalizer (trust region).

- `KINSOL()`: The KINSOL method of the SUNDIALS C library

### SIAMFANLEquations.jl

SIAMFANLEquations.jl is a wrapper for the methods in the SIAMFANLEquations.jl library.

- `SIAMFANLEquationsJL()`: A wrapper for using the methods in
[SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl)
2 changes: 1 addition & 1 deletion ext/NonlinearSolveNLsolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
end

abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))

Check warning on line 71 in ext/NonlinearSolveNLsolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveNLsolveExt.jl#L71

Added line #L71 was not covered by tests

original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method,
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,
Expand Down
134 changes: 62 additions & 72 deletions ext/NonlinearSolveSIAMFANLEquationsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,58 @@ module NonlinearSolveSIAMFANLEquationsExt

using NonlinearSolve, SciMLBase
using SIAMFANLEquations
import ConcreteStructs: @concrete
import UnPack: @unpack

function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing,
reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...)
@assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!"
@inline function __siam_fanl_equations_retcode_mapping(sol)
if sol.errcode == 0
return ReturnCode.Success
elseif sol.errcode == 10
return ReturnCode.MaxIters
elseif sol.errcode == 1
return ReturnCode.Failure
elseif sol.errcode == -1
return ReturnCode.Default

Check warning on line 15 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L7-L15

Added lines #L7 - L15 were not covered by tests
end
end

# pseudo transient continuation has a fixed cost per iteration, iteration statistics are
# not interesting here.
@inline function __siam_fanl_equations_stats_mapping(method, sol)
method === :pseudotransient && return nothing
return SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0,

Check warning on line 23 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L21-L23

Added lines #L21 - L23 were not covered by tests
sum(sol.stats.iarm))
end

@unpack method, show_trace, delta, linsolve = alg
function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...;

Check warning on line 27 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L27

Added line #L27 was not covered by tests
abstol = nothing, reltol = nothing, alias_u0::Bool = false, maxiters = 1000,
termination_condition = nothing, show_trace::Val{ShT} = Val(false),
kwargs...) where {ShT}
@assert (termination_condition ===

Check warning on line 31 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L31

Added line #L31 was not covered by tests
nothing)||(termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!"

@unpack method, delta, linsolve = alg

Check warning on line 34 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L34

Added line #L34 was not covered by tests

iip = SciMLBase.isinplace(prob)
T = eltype(prob.u0)

atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol
atol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0))
rtol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(prob.u0))

Check warning on line 39 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L38-L39

Added lines #L38 - L39 were not covered by tests

if prob.u0 isa Number
f! = if iip
function (u)
du = similar(u)
prob.f(du, u, prob.p)
return du
end
else
u -> prob.f(u, prob.p)
end
f = (u) -> prob.f(u, prob.p)

Check warning on line 42 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L42

Added line #L42 was not covered by tests

if method == :newton
sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
sol = nsolsc(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT)

Check warning on line 45 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L45

Added line #L45 was not covered by tests
elseif method == :pseudotransient
sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = atol, rtol=rtol, printerr = show_trace)
sol = ptcsolsc(f, prob.u0; delta0 = delta, maxit = maxiters, atol, rtol,

Check warning on line 47 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L47

Added line #L47 was not covered by tests
printerr = ShT)
elseif method == :secant
sol = secant(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
sol = secant(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT)

Check warning on line 50 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L50

Added line #L50 was not covered by tests
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default
end
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
retcode = __siam_fanl_equations_retcode_mapping(sol)
stats = __siam_fanl_equations_stats_mapping(method, sol)
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode,

Check warning on line 55 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L53-L55

Added lines #L53 - L55 were not covered by tests
stats, original = sol)
else
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
end
Expand All @@ -71,67 +78,50 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
if linsolve !== nothing
# Allocate ahead for Krylov basis
JVS = linsolve == :gmres ? zeros(T, N, 3) : zeros(T, N)
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
# `linsolve` as a Symbol to keep unified interface with other EXTs,
# SIAMFANLEquations directly use String to choose between different linear solvers
linsolve_alg = String(linsolve)

if method == :newton
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,

Check warning on line 86 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L86

Added line #L86 was not covered by tests
rtol, printerr = ShT)
elseif method == :pseudotransient
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,

Check warning on line 89 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L89

Added line #L89 was not covered by tests
rtol, printerr = ShT)
end
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)

retcode = __siam_fanl_equations_retcode_mapping(sol)
stats = __siam_fanl_equations_stats_mapping(method, sol)
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode,

Check warning on line 95 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L93-L95

Added lines #L93 - L95 were not covered by tests
stats, original = sol)
end

# Allocate ahead for Jacobian
FPS = zeros(T, N, N)
if prob.f.jac === nothing
# Use the built-in Jacobian machinery
if method == :newton
sol = nsol(f!, u, FS, FPS;
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
printerr = show_trace)
sol = nsol(f!, u, FS, FPS; sham = 1, atol, rtol, maxit = maxiters,

Check warning on line 104 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L104

Added line #L104 was not covered by tests
printerr = ShT)
elseif method == :pseudotransient
sol = ptcsol(f!, u, FS, FPS;
atol = atol, rtol = rtol, maxit = maxiters,
delta0 = delta, printerr = show_trace)
sol = ptcsol(f!, u, FS, FPS; atol, rtol, maxit = maxiters,

Check warning on line 107 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L107

Added line #L107 was not covered by tests
delta0 = delta, printerr = ShT)
end
else
AJ!(J, u, x) = prob.f.jac(J, x, prob.p)
if method == :newton
sol = nsol(f!, u, FS, FPS, AJ!;
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
printerr = show_trace)
sol = nsol(f!, u, FS, FPS, AJ!; sham = 1, atol, rtol, maxit = maxiters,

Check warning on line 113 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L113

Added line #L113 was not covered by tests
printerr = ShT)
elseif method == :pseudotransient
sol = ptcsol(f!, u, FS, FPS, AJ!;
atol = atol, rtol = rtol, maxit = maxiters,
delta0 = delta, printerr = show_trace)
sol = ptcsol(f!, u, FS, FPS, AJ!; atol, rtol, maxit = maxiters,

Check warning on line 116 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L116

Added line #L116 was not covered by tests
delta0 = delta, printerr = ShT)
end
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default
end

# pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
retcode = __siam_fanl_equations_retcode_mapping(sol)
stats = __siam_fanl_equations_stats_mapping(method, sol)
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats,

Check warning on line 123 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L121-L123

Added lines #L121 - L123 were not covered by tests
original = sol)
end

end
end
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ export RadiusUpdateSchemes
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
Broyden, Klement, LimitedMemoryBroyden
export LeastSquaresOptimJL,
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL,
SIAMFANLEquationsJL
export NonlinearSolvePolyAlgorithm,
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg

Expand Down
15 changes: 6 additions & 9 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace =
end

"""
SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false,
orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000)
Expand Down Expand Up @@ -364,13 +363,11 @@ function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing,
end

"""
SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)
SIAMFANLEquationsJL(; method = :newton, delta = 1e-3, linsolve = nothing)
### Keyword Arguments
- `method`: the choice of method for solving the nonlinear system.
- `show_trace`: whether to show the trace.
- `delta`: initial pseudo time step, default is 1e-3.
- `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`.
Expand All @@ -380,16 +377,16 @@ end
- `:pseudotransient`: Pseudo transient method.
- `:secant`: Secant method for scalar equations.
"""
@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm
@concrete struct SIAMFANLEquationsJL{L <: Union{Symbol, Nothing}} <:
AbstractNonlinearSolveAlgorithm
method::Symbol
show_trace::Bool
delta
linsolve::Union{Symbol, Nothing}
linsolve::L
end

function SIAMFANLEquationsJL(; method = :newton, show_trace = false, delta = 1e-3, linsolve = nothing)
function SIAMFANLEquationsJL(; method = :newton, delta = 1e-3, linsolve = nothing)

Check warning on line 387 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L387

Added line #L387 was not covered by tests
if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing
error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded")
end
return SIAMFANLEquationsJL(method, show_trace, delta, linsolve)
return SIAMFANLEquationsJL(method, show_trace, delta, linsolve)

Check warning on line 391 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L391

Added line #L391 was not covered by tests
end
6 changes: 3 additions & 3 deletions test/siamfanlequations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ end
f_tol(u, p) = u^2 - 2
prob_tol = NonlinearProblem(f_tol, 1.0)
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-11]
for method = [:newton, :pseudotransient, :secant]
for method in [:newton, :pseudotransient, :secant]
sol = solve(prob_tol, SIAMFANLEquationsJL(method = method), abstol = tol)
@test abs(sol.u[1] - sqrt(2)) < tol
end
Expand Down Expand Up @@ -141,12 +141,12 @@ f = NonlinearFunction(f!, jac = j!)
p = A

ProbN = NonlinearProblem(f, init, p)
for method = [:newton, :pseudotransient]
for method in [:newton, :pseudotransient]
sol = solve(ProbN, SIAMFANLEquationsJL(method = method), reltol = 1e-8, abstol = 1e-8)
end

#= doesn't support complex numbers handling
init = ones(Complex{Float64}, 152);
ProbN = NonlinearProblem(f, init, p)
sol = solve(ProbN, SIAMFANLEquationsJL(), reltol = 1e-8, abstol = 1e-8)
=#
=#

0 comments on commit f2edda0

Please sign in to comment.