diff --git a/Project.toml b/Project.toml index 2a647bd58..6a359e6fa 100644 --- a/Project.toml +++ b/Project.toml @@ -36,6 +36,7 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4" SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -47,6 +48,7 @@ NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration" NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim" NonlinearSolveMINPACKExt = "MINPACK" NonlinearSolveNLsolveExt = "NLsolve" +NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations" NonlinearSolveSpeedMappingExt = "SpeedMapping" NonlinearSolveSymbolicsExt = "Symbolics" NonlinearSolveZygoteExt = "Zygote" @@ -55,8 +57,8 @@ NonlinearSolveZygoteExt = "Zygote" ADTypes = "0.2.5" Aqua = "0.8" ArrayInterface = "7.7" -BandedMatrices = "1.3" -BenchmarkTools = "1" +BandedMatrices = "1.4" +BenchmarkTools = "1.4" ConcreteStructs = "0.2" DiffEqBase = "6.144" EnumX = "1" @@ -86,6 +88,7 @@ Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.11" SciMLOperators = "0.3.7" +SIAMFANLEquations = "1.0.1" SimpleNonlinearSolve = "1.0.2" SparseArrays = "1.9" SparseDiffTools = "2.14" @@ -118,6 +121,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -127,4 +131,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration"] +test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations"] diff --git a/docs/pages.jl b/docs/pages.jl index c3c0e164a..9c148bcb4 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -31,6 +31,7 @@ pages = ["index.md", "api/leastsquaresoptim.md", "api/fastlevenbergmarquardt.md", "api/speedmapping.md", - "api/fixedpointacceleration.md"], + "api/fixedpointacceleration.md", + "api/siamfanlequations.md"], "Release Notes" => "release_notes.md", ] diff --git a/docs/src/api/fastlevenbergmarquardt.md b/docs/src/api/fastlevenbergmarquardt.md index 8709dc303..423c70bbc 100644 --- a/docs/src/api/fastlevenbergmarquardt.md +++ b/docs/src/api/fastlevenbergmarquardt.md @@ -1,6 +1,6 @@ # FastLevenbergMarquardt.jl -This is a extension for importing solvers from FastLevenbergMarquardt.jl into the SciML +This is an extension for importing solvers from FastLevenbergMarquardt.jl into the SciML interface. Note that these solvers do not come by default, and thus one needs to install the package before using these solvers: diff --git a/docs/src/api/leastsquaresoptim.md b/docs/src/api/leastsquaresoptim.md index 76850555b..0581ee7c2 100644 --- a/docs/src/api/leastsquaresoptim.md +++ b/docs/src/api/leastsquaresoptim.md @@ -1,6 +1,6 @@ # LeastSquaresOptim.jl -This is a extension for importing solvers from LeastSquaresOptim.jl into the SciML +This is an extension for importing solvers from LeastSquaresOptim.jl into the SciML interface. Note that these solvers do not come by default, and thus one needs to install the package before using these solvers: diff --git a/docs/src/api/siamfanlequations.md b/docs/src/api/siamfanlequations.md new file mode 100644 index 000000000..2848a18ed --- /dev/null +++ b/docs/src/api/siamfanlequations.md @@ -0,0 +1,17 @@ +# SIAMFANLEquations.jl + +This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML +interface. Note that these solvers do not come by default, and thus one needs to install +the package before using these solvers: + +```julia +using Pkg +Pkg.add("SIAMFANLEquations") +using SIAMFANLEquations, NonlinearSolve +``` + +## Solver API + +```@docs +SIAMFANLEquationsJL +``` diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl new file mode 100644 index 000000000..47ecc96c9 --- /dev/null +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -0,0 +1,137 @@ +module NonlinearSolveSIAMFANLEquationsExt + +using NonlinearSolve, SciMLBase +using SIAMFANLEquations +import ConcreteStructs: @concrete +import UnPack: @unpack + +function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing, + reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...) + @assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!" + + @unpack method, show_trace, delta, linsolve = alg + + iip = SciMLBase.isinplace(prob) + T = eltype(prob.u0) + + atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol + + if prob.u0 isa Number + f! = if iip + function (u) + du = similar(u) + prob.f(du, u, prob.p) + return du + end + else + u -> prob.f(u, prob.p) + end + + if method == :newton + sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) + elseif method == :pseudotransient + sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = atol, rtol=rtol, printerr = show_trace) + elseif method == :secant + sol = secant(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) + end + + if sol.errcode == 0 + retcode = ReturnCode.Success + elseif sol.errcode == 10 + retcode = ReturnCode.MaxIters + elseif sol.errcode == 1 + retcode = ReturnCode.Failure + elseif sol.errcode == -1 + retcode = ReturnCode.Default + end + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) + else + u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) + end + + if iip + f! = function (du, u) + prob.f(du, u, prob.p) + return du + end + else + f! = function (du, u) + du .= prob.f(u, prob.p) + return du + end + end + + # Allocate ahead for function + N = length(u) + FS = zeros(T, N) + + # Jacobian free Newton Krylov + if linsolve !== nothing + # Allocate ahead for Krylov basis + JVS = linsolve == :gmres ? zeros(T, N, 3) : zeros(T, N) + # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers + linsolve_alg = String(linsolve) + + if method == :newton + sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) + elseif method == :pseudotransient + sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) + end + + if sol.errcode == 0 + retcode = ReturnCode.Success + elseif sol.errcode == 10 + retcode = ReturnCode.MaxIters + elseif sol.errcode == 1 + retcode = ReturnCode.Failure + elseif sol.errcode == -1 + retcode = ReturnCode.Default + end + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) + end + + # Allocate ahead for Jacobian + FPS = zeros(T, N, N) + if prob.f.jac === nothing + # Use the built-in Jacobian machinery + if method == :newton + sol = nsol(f!, u, FS, FPS; + sham=1, atol = atol, rtol = rtol, maxit = maxiters, + printerr = show_trace) + elseif method == :pseudotransient + sol = ptcsol(f!, u, FS, FPS; + atol = atol, rtol = rtol, maxit = maxiters, + delta0 = delta, printerr = show_trace) + end + else + AJ!(J, u, x) = prob.f.jac(J, x, prob.p) + if method == :newton + sol = nsol(f!, u, FS, FPS, AJ!; + sham=1, atol = atol, rtol = rtol, maxit = maxiters, + printerr = show_trace) + elseif method == :pseudotransient + sol = ptcsol(f!, u, FS, FPS, AJ!; + atol = atol, rtol = rtol, maxit = maxiters, + delta0 = delta, printerr = show_trace) + end + end + + if sol.errcode == 0 + retcode = ReturnCode.Success + elseif sol.errcode == 10 + retcode = ReturnCode.MaxIters + elseif sol.errcode == 1 + retcode = ReturnCode.Failure + elseif sol.errcode == -1 + retcode = ReturnCode.Default + end + + # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here. + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) +end + +end \ No newline at end of file diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 4c602991c..bf77135c8 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -237,7 +237,7 @@ export RadiusUpdateSchemes export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient, Broyden, Klement, LimitedMemoryBroyden export LeastSquaresOptimJL, - FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL + FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 4fcad43ad..ea92454b0 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -208,6 +208,7 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = end """ + SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false, orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000) @@ -322,3 +323,34 @@ function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing, return FixedPointAccelerationJL(algorithm, extrapolation_period, replace_invalids, dampening, m, condition_number_threshold) end + +""" + + SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing) + +### Keyword Arguments + + - `method`: the choice of method for solving the nonlinear system. + - `show_trace`: whether to show the trace. + - `delta`: initial pseudo time step, default is 1e-3. + - `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`. + +### Submethod Choice + + - `:newton`: Classical Newton method. + - `:pseudotransient`: Pseudo transient method. + - `:secant`: Secant method for scalar equations. +""" +@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm + method::Symbol + show_trace::Bool + delta + linsolve::Union{Symbol, Nothing} +end + +function SIAMFANLEquationsJL(; method = :newton, show_trace = false, delta = 1e-3, linsolve = nothing) + if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing + error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded") + end + return SIAMFANLEquationsJL(method, show_trace, delta, linsolve) +end diff --git a/test/runtests.jl b/test/runtests.jl index f761b08ef..4997eb43f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ end if GROUP == "All" || GROUP == "Wrappers" @time @safetestset "MINPACK" include("minpack.jl") @time @safetestset "NLsolve" include("nlsolve.jl") + @time @safetestset "SIAMFANLEquations" include("siamfanlequations.jl") @time @safetestset "SpeedMapping" include("speedmapping.jl") @time @safetestset "FixedPointAcceleration" include("fixed_point_acceleration.jl") end diff --git a/test/siamfanlequations.jl b/test/siamfanlequations.jl new file mode 100644 index 000000000..ed35485b2 --- /dev/null +++ b/test/siamfanlequations.jl @@ -0,0 +1,152 @@ +using NonlinearSolve, SIAMFANLEquations, LinearAlgebra, Test + +# IIP Tests +function f_iip(du, u, p, t) + du[1] = 2 - 2u[1] + du[2] = u[1] - 4u[2] +end +u0 = zeros(2) +prob_iip = SteadyStateProblem(f_iip, u0) +abstol = 1e-8 + +for alg in [SIAMFANLEquationsJL()] + sol = solve(prob_iip, alg) + @test sol.retcode == ReturnCode.Success + p = nothing + + du = zeros(2) + f_iip(du, sol.u, nothing, 0) + @test maximum(du) < 1e-6 +end + +# OOP Tests +f_oop(u, p, t) = [2 - 2u[1], u[1] - 4u[2]] +u0 = zeros(2) +prob_oop = SteadyStateProblem(f_oop, u0) + +for alg in [SIAMFANLEquationsJL()] + sol = solve(prob_oop, alg) + @test sol.retcode == ReturnCode.Success + # test the solver is doing reasonable things for linear solve + # and that the stats are working properly + @test 1 <= sol.stats.nf < 10 + + du = zeros(2) + du = f_oop(sol.u, nothing, 0) + @test maximum(du) < 1e-6 +end + +# NonlinearProblem Tests + +function f_iip(du, u, p) + du[1] = 2 - 2u[1] + du[2] = u[1] - 4u[2] +end +u0 = zeros(2) +prob_iip = NonlinearProblem{true}(f_iip, u0) +abstol = 1e-8 +for alg in [SIAMFANLEquationsJL()] + local sol + sol = solve(prob_iip, alg) + @test sol.retcode == ReturnCode.Success + p = nothing + + du = zeros(2) + f_iip(du, sol.u, nothing) + @test maximum(du) < 1e-6 +end + +# OOP Tests +f_oop(u, p) = [2 - 2u[1], u[1] - 4u[2]] +u0 = zeros(2) +prob_oop = NonlinearProblem{false}(f_oop, u0) +for alg in [SIAMFANLEquationsJL()] + local sol + sol = solve(prob_oop, alg) + @test sol.retcode == ReturnCode.Success + + du = zeros(2) + du = f_oop(sol.u, nothing) + @test maximum(du) < 1e-6 +end + +# tolerance tests for scalar equation solvers +f_tol(u, p) = u^2 - 2 +prob_tol = NonlinearProblem(f_tol, 1.0) +for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-11] + for method = [:newton, :pseudotransient, :secant] + sol = solve(prob_tol, SIAMFANLEquationsJL(method = method), abstol = tol) + @test abs(sol.u[1] - sqrt(2)) < tol + end +end + +# Test the JFNK technique +f_jfnk(u, p) = u^2 - 2 +prob_jfnk = NonlinearProblem(f_jfnk, 1.0) +for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-11] + sol = solve(prob_jfnk, SIAMFANLEquationsJL(linsolve = :gmres), abstol = tol) + @test abs(sol.u[1] - sqrt(2)) < tol +end + +# Test the finite differencing technique +function f!(fvec, x, p) + fvec[1] = (x[1] + 3) * (x[2]^3 - 7) + 18 + fvec[2] = sin(x[2] * exp(x[1]) - 1) +end + +prob = NonlinearProblem{true}(f!, [0.1; 1.2]) +sol = solve(prob, SIAMFANLEquationsJL()) + +du = zeros(2) +f!(du, sol.u, nothing) +@test maximum(du) < 1e-6 + +# Test the autodiff technique +function f!(fvec, x, p) + fvec[1] = (x[1] + 3) * (x[2]^3 - 7) + 18 + fvec[2] = sin(x[2] * exp(x[1]) - 1) +end + +prob = NonlinearProblem{true}(f!, [0.1; 1.2]) +sol = solve(prob, SIAMFANLEquationsJL()) + +du = zeros(2) +f!(du, sol.u, nothing) +@test maximum(du) < 1e-6 + +function problem(x, A) + return x .^ 2 - A +end + +function problemJacobian(x, A) + return diagm(2 .* x) +end + +function f!(F, u, p) + F[1:152] = problem(u, p) +end + +function j!(J, u, p) + J[1:152, 1:152] = problemJacobian(u, p) +end + +f = NonlinearFunction(f!) + +init = ones(152); +A = ones(152); +A[6] = 0.8 + +f = NonlinearFunction(f!, jac = j!) + +p = A + +ProbN = NonlinearProblem(f, init, p) +for method = [:newton, :pseudotransient] + sol = solve(ProbN, SIAMFANLEquationsJL(method = method), reltol = 1e-8, abstol = 1e-8) +end + +#= doesn't support complex numbers handling +init = ones(Complex{Float64}, 152); +ProbN = NonlinearProblem(f, init, p) +sol = solve(ProbN, SIAMFANLEquationsJL(), reltol = 1e-8, abstol = 1e-8) +=# \ No newline at end of file