From f2edda07848ce459b1c2e9221b363424b935c4ce Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 24 Dec 2023 14:34:13 -0500 Subject: [PATCH] Standardize parts of SIAM FANL Equations --- docs/src/api/siamfanlequations.md | 3 +- docs/src/solvers/NonlinearSystemSolvers.md | 7 ++ ext/NonlinearSolveNLsolveExt.jl | 2 +- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 134 ++++++++++----------- src/NonlinearSolve.jl | 3 +- src/extension_algs.jl | 15 +-- test/siamfanlequations.jl | 6 +- 7 files changed, 83 insertions(+), 87 deletions(-) diff --git a/docs/src/api/siamfanlequations.md b/docs/src/api/siamfanlequations.md index 2848a18ed..5ee36ba12 100644 --- a/docs/src/api/siamfanlequations.md +++ b/docs/src/api/siamfanlequations.md @@ -1,6 +1,7 @@ # SIAMFANLEquations.jl -This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML +This is an extension for importing solvers from +[SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML interface. Note that these solvers do not come by default, and thus one needs to install the package before using these solvers: diff --git a/docs/src/solvers/NonlinearSystemSolvers.md b/docs/src/solvers/NonlinearSystemSolvers.md index bba941d86..c15948814 100644 --- a/docs/src/solvers/NonlinearSystemSolvers.md +++ b/docs/src/solvers/NonlinearSystemSolvers.md @@ -143,3 +143,10 @@ Newton-Krylov form. However, KINSOL is known to be less stable than some other implementations, as it has no line search or globalizer (trust region). - `KINSOL()`: The KINSOL method of the SUNDIALS C library + +### SIAMFANLEquations.jl + +SIAMFANLEquations.jl is a wrapper for the methods in the SIAMFANLEquations.jl library. + + - `SIAMFANLEquationsJL()`: A wrapper for using the methods in + [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) diff --git a/ext/NonlinearSolveNLsolveExt.jl b/ext/NonlinearSolveNLsolveExt.jl index 8ee04532e..fc1218527 100644 --- a/ext/NonlinearSolveNLsolveExt.jl +++ b/ext/NonlinearSolveNLsolveExt.jl @@ -68,7 +68,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff) end - abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u)) + abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0)) original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method, store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta, diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index 47ecc96c9..e44beea45 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -2,51 +2,58 @@ module NonlinearSolveSIAMFANLEquationsExt using NonlinearSolve, SciMLBase using SIAMFANLEquations -import ConcreteStructs: @concrete import UnPack: @unpack -function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing, - reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...) - @assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!" +@inline function __siam_fanl_equations_retcode_mapping(sol) + if sol.errcode == 0 + return ReturnCode.Success + elseif sol.errcode == 10 + return ReturnCode.MaxIters + elseif sol.errcode == 1 + return ReturnCode.Failure + elseif sol.errcode == -1 + return ReturnCode.Default + end +end + +# pseudo transient continuation has a fixed cost per iteration, iteration statistics are +# not interesting here. +@inline function __siam_fanl_equations_stats_mapping(method, sol) + method === :pseudotransient && return nothing + return SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, + sum(sol.stats.iarm)) +end - @unpack method, show_trace, delta, linsolve = alg +function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; + abstol = nothing, reltol = nothing, alias_u0::Bool = false, maxiters = 1000, + termination_condition = nothing, show_trace::Val{ShT} = Val(false), + kwargs...) where {ShT} + @assert (termination_condition === + nothing)||(termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!" + + @unpack method, delta, linsolve = alg iip = SciMLBase.isinplace(prob) - T = eltype(prob.u0) - atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol - rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol + atol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0)) + rtol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(prob.u0)) if prob.u0 isa Number - f! = if iip - function (u) - du = similar(u) - prob.f(du, u, prob.p) - return du - end - else - u -> prob.f(u, prob.p) - end + f = (u) -> prob.f(u, prob.p) if method == :newton - sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) + sol = nsolsc(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT) elseif method == :pseudotransient - sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = atol, rtol=rtol, printerr = show_trace) + sol = ptcsolsc(f, prob.u0; delta0 = delta, maxit = maxiters, atol, rtol, + printerr = ShT) elseif method == :secant - sol = secant(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) + sol = secant(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT) end - if sol.errcode == 0 - retcode = ReturnCode.Success - elseif sol.errcode == 10 - retcode = ReturnCode.MaxIters - elseif sol.errcode == 1 - retcode = ReturnCode.Failure - elseif sol.errcode == -1 - retcode = ReturnCode.Default - end - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) - return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) + retcode = __siam_fanl_equations_retcode_mapping(sol) + stats = __siam_fanl_equations_stats_mapping(method, sol) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, + stats, original = sol) else u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) end @@ -71,26 +78,22 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg if linsolve !== nothing # Allocate ahead for Krylov basis JVS = linsolve == :gmres ? zeros(T, N, 3) : zeros(T, N) - # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers + # `linsolve` as a Symbol to keep unified interface with other EXTs, + # SIAMFANLEquations directly use String to choose between different linear solvers linsolve_alg = String(linsolve) if method == :newton - sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) + sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol, + rtol, printerr = ShT) elseif method == :pseudotransient - sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) - end - - if sol.errcode == 0 - retcode = ReturnCode.Success - elseif sol.errcode == 10 - retcode = ReturnCode.MaxIters - elseif sol.errcode == 1 - retcode = ReturnCode.Failure - elseif sol.errcode == -1 - retcode = ReturnCode.Default + sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol, + rtol, printerr = ShT) end - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) - return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) + + retcode = __siam_fanl_equations_retcode_mapping(sol) + stats = __siam_fanl_equations_stats_mapping(method, sol) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, + stats, original = sol) end # Allocate ahead for Jacobian @@ -98,40 +101,27 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg if prob.f.jac === nothing # Use the built-in Jacobian machinery if method == :newton - sol = nsol(f!, u, FS, FPS; - sham=1, atol = atol, rtol = rtol, maxit = maxiters, - printerr = show_trace) + sol = nsol(f!, u, FS, FPS; sham = 1, atol, rtol, maxit = maxiters, + printerr = ShT) elseif method == :pseudotransient - sol = ptcsol(f!, u, FS, FPS; - atol = atol, rtol = rtol, maxit = maxiters, - delta0 = delta, printerr = show_trace) + sol = ptcsol(f!, u, FS, FPS; atol, rtol, maxit = maxiters, + delta0 = delta, printerr = ShT) end else AJ!(J, u, x) = prob.f.jac(J, x, prob.p) if method == :newton - sol = nsol(f!, u, FS, FPS, AJ!; - sham=1, atol = atol, rtol = rtol, maxit = maxiters, - printerr = show_trace) + sol = nsol(f!, u, FS, FPS, AJ!; sham = 1, atol, rtol, maxit = maxiters, + printerr = ShT) elseif method == :pseudotransient - sol = ptcsol(f!, u, FS, FPS, AJ!; - atol = atol, rtol = rtol, maxit = maxiters, - delta0 = delta, printerr = show_trace) + sol = ptcsol(f!, u, FS, FPS, AJ!; atol, rtol, maxit = maxiters, + delta0 = delta, printerr = ShT) end end - if sol.errcode == 0 - retcode = ReturnCode.Success - elseif sol.errcode == 10 - retcode = ReturnCode.MaxIters - elseif sol.errcode == 1 - retcode = ReturnCode.Failure - elseif sol.errcode == -1 - retcode = ReturnCode.Default - end - - # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here. - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) - return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) + retcode = __siam_fanl_equations_retcode_mapping(sol) + stats = __siam_fanl_equations_stats_mapping(method, sol) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, + original = sol) end -end \ No newline at end of file +end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index bf77135c8..646af5cd2 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -237,7 +237,8 @@ export RadiusUpdateSchemes export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient, Broyden, Klement, LimitedMemoryBroyden export LeastSquaresOptimJL, - FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL + FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL, + SIAMFANLEquationsJL export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg diff --git a/src/extension_algs.jl b/src/extension_algs.jl index d14274644..c70466131 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -247,7 +247,6 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = end """ - SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false, orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000) @@ -364,13 +363,11 @@ function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing, end """ - - SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing) + SIAMFANLEquationsJL(; method = :newton, delta = 1e-3, linsolve = nothing) ### Keyword Arguments - `method`: the choice of method for solving the nonlinear system. - - `show_trace`: whether to show the trace. - `delta`: initial pseudo time step, default is 1e-3. - `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`. @@ -380,16 +377,16 @@ end - `:pseudotransient`: Pseudo transient method. - `:secant`: Secant method for scalar equations. """ -@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm +@concrete struct SIAMFANLEquationsJL{L <: Union{Symbol, Nothing}} <: + AbstractNonlinearSolveAlgorithm method::Symbol - show_trace::Bool delta - linsolve::Union{Symbol, Nothing} + linsolve::L end -function SIAMFANLEquationsJL(; method = :newton, show_trace = false, delta = 1e-3, linsolve = nothing) +function SIAMFANLEquationsJL(; method = :newton, delta = 1e-3, linsolve = nothing) if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded") end - return SIAMFANLEquationsJL(method, show_trace, delta, linsolve) + return SIAMFANLEquationsJL(method, show_trace, delta, linsolve) end diff --git a/test/siamfanlequations.jl b/test/siamfanlequations.jl index ed35485b2..4b36dc312 100644 --- a/test/siamfanlequations.jl +++ b/test/siamfanlequations.jl @@ -74,7 +74,7 @@ end f_tol(u, p) = u^2 - 2 prob_tol = NonlinearProblem(f_tol, 1.0) for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-11] - for method = [:newton, :pseudotransient, :secant] + for method in [:newton, :pseudotransient, :secant] sol = solve(prob_tol, SIAMFANLEquationsJL(method = method), abstol = tol) @test abs(sol.u[1] - sqrt(2)) < tol end @@ -141,7 +141,7 @@ f = NonlinearFunction(f!, jac = j!) p = A ProbN = NonlinearProblem(f, init, p) -for method = [:newton, :pseudotransient] +for method in [:newton, :pseudotransient] sol = solve(ProbN, SIAMFANLEquationsJL(method = method), reltol = 1e-8, abstol = 1e-8) end @@ -149,4 +149,4 @@ end init = ones(Complex{Float64}, 152); ProbN = NonlinearProblem(f, init, p) sol = solve(ProbN, SIAMFANLEquationsJL(), reltol = 1e-8, abstol = 1e-8) -=# \ No newline at end of file +=#