From fc437fd6372e7de2b8ac3975217899308e336949 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Tue, 19 Dec 2023 11:30:04 +0800 Subject: [PATCH 01/16] Add SIAMFANLEquations wrapper Signed-off-by: ErikQQY <2283984853@qq.com> --- Project.toml | 6 +- docs/src/api/fastlevenbergmarquardt.md | 2 +- docs/src/api/leastsquaresoptim.md | 2 +- docs/src/api/siamfanlequations.md | 17 +++ ext/NonlinearSolveSIAMFANLEquationsExt.jl | 176 ++++++++++++++++++++++ src/NonlinearSolve.jl | 2 +- src/extension_algs.jl | 32 ++++ test/runtests.jl | 1 + test/siamfanlequations.jl | 140 +++++++++++++++++ 9 files changed, 374 insertions(+), 4 deletions(-) create mode 100644 docs/src/api/siamfanlequations.md create mode 100644 ext/NonlinearSolveSIAMFANLEquationsExt.jl create mode 100644 test/siamfanlequations.jl diff --git a/Project.toml b/Project.toml index 8aacc9aff..0eb98cb77 100644 --- a/Project.toml +++ b/Project.toml @@ -35,6 +35,7 @@ FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -44,6 +45,7 @@ NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt" NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim" NonlinearSolveMINPACKExt = "MINPACK" NonlinearSolveNLsolveExt = "NLsolve" +NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations" NonlinearSolveSymbolicsExt = "Symbolics" NonlinearSolveZygoteExt = "Zygote" @@ -80,6 +82,7 @@ Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.11" SciMLOperators = "0.3.7" +SIAMFANLEquations = "1.0.1" SimpleNonlinearSolve = "1.0.2" SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.14" @@ -109,6 +112,7 @@ NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -117,4 +121,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve"] +test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "SIAMFANLEquations"] diff --git a/docs/src/api/fastlevenbergmarquardt.md b/docs/src/api/fastlevenbergmarquardt.md index 8709dc303..423c70bbc 100644 --- a/docs/src/api/fastlevenbergmarquardt.md +++ b/docs/src/api/fastlevenbergmarquardt.md @@ -1,6 +1,6 @@ # FastLevenbergMarquardt.jl -This is a extension for importing solvers from FastLevenbergMarquardt.jl into the SciML +This is an extension for importing solvers from FastLevenbergMarquardt.jl into the SciML interface. Note that these solvers do not come by default, and thus one needs to install the package before using these solvers: diff --git a/docs/src/api/leastsquaresoptim.md b/docs/src/api/leastsquaresoptim.md index 76850555b..0581ee7c2 100644 --- a/docs/src/api/leastsquaresoptim.md +++ b/docs/src/api/leastsquaresoptim.md @@ -1,6 +1,6 @@ # LeastSquaresOptim.jl -This is a extension for importing solvers from LeastSquaresOptim.jl into the SciML +This is an extension for importing solvers from LeastSquaresOptim.jl into the SciML interface. Note that these solvers do not come by default, and thus one needs to install the package before using these solvers: diff --git a/docs/src/api/siamfanlequations.md b/docs/src/api/siamfanlequations.md new file mode 100644 index 000000000..2848a18ed --- /dev/null +++ b/docs/src/api/siamfanlequations.md @@ -0,0 +1,17 @@ +# SIAMFANLEquations.jl + +This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML +interface. Note that these solvers do not come by default, and thus one needs to install +the package before using these solvers: + +```julia +using Pkg +Pkg.add("SIAMFANLEquations") +using SIAMFANLEquations, NonlinearSolve +``` + +## Solver API + +```@docs +SIAMFANLEquationsJL +``` diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl new file mode 100644 index 000000000..3ab9c1005 --- /dev/null +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -0,0 +1,176 @@ +module NonlinearSolveSIAMFANLEquationsExt + +using NonlinearSolve, SciMLBase +using SIAMFANLEquations +import ConcreteStructs: @concrete +import UnPack: @unpack +import FiniteDiff, ForwardDiff + +function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = 1e-8, + reltol = 1e-8, alias_u0::Bool = false, maxiters = 1000, kwargs...) + @unpack method, autodiff, show_trace, delta, linsolve = alg + + iip = SciMLBase.isinplace(prob) + if typeof(prob.u0) <: Number + f! = if iip + function (u) + du = similar(u) + prob.f(du, u, prob.p) + return du + end + else + u -> prob.f(u, prob.p) + end + + if method == :newton + res = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + elseif method == :pseudotransient + res = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace) + elseif method == :secant + res = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + end + + if res.errcode == 0 + retcode = ReturnCode.Success + elseif res.errcode == 10 + retcode = ReturnCode.MaxIters + elseif res.errcode == 1 + retcode = ReturnCode.Failure + @error("Line search failed") + elseif res.errcode == -1 + retcode = ReturnCode.Default + @info("Initial iterate satisfies the termination criteria") + end + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1])) + return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) + else + u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) + end + + fu = NonlinearSolve.evaluate_f(prob, u) + + if iip + f! = function (du, u) + prob.f(du, u, prob.p) + return du + end + else + f! = function (du, u) + du .= prob.f(u, prob.p) + return du + end + end + + # Allocate ahead for function and Jacobian + N = length(u) + FS = zeros(eltype(u), N) + FPS = zeros(eltype(u), N, N) + # Allocate ahead for Krylov basis + + # Jacobian free Newton Krylov + if linsolve !== nothing + JVS = linsolve == :gmres ? zeros(eltype(u), N, 3) : zeros(eltype(u), N) + # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between linear solvers + linsolve_alg = strip(repr(linsolve), ':') + + if method == :newton + res = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + elseif method == :pseudotransient + res = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + end + + if res.errcode == 0 + retcode = ReturnCode.Success + elseif res.errcode == 10 + retcode = ReturnCode.MaxIters + elseif res.errcode == 1 + retcode = ReturnCode.Failure + @error("Line search failed") + elseif res.errcode == -1 + retcode = ReturnCode.Default + @info("Initial iterate satisfies the termination criteria") + end + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1])) + return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) + end + + if prob.f.jac === nothing + use_forward_diff = if alg.autodiff === nothing + ForwardDiff.can_dual(eltype(u)) + else + alg.autodiff isa AutoForwardDiff + end + uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p) + if use_forward_diff + cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) : + ForwardDiff.JacobianConfig(uf, u) + else + cache = FiniteDiff.JacobianCache(u, fu) + end + J! = if iip + if use_forward_diff + fu_cache = similar(fu) + function (J, x, p) + uf.p = p + ForwardDiff.jacobian!(J, uf, fu_cache, x, cache) + return J + end + else + function (J, x, p) + uf.p = p + FiniteDiff.finite_difference_jacobian!(J, uf, x, cache) + return J + end + end + else + if use_forward_diff + function (J, x, p) + uf.p = p + ForwardDiff.jacobian!(J, uf, x, cache) + return J + end + else + function (J, x, p) + uf.p = p + J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache) + copyto!(J, J_) + return J + end + end + end + else + J! = prob.f.jac + end + + AJ!(J, u, x) = J!(J, x, prob.p) + + if method == :newton + res = nsol(f!, u, FS, FPS, AJ!; + sham=1, rtol = reltol, atol = abstol, maxit = maxiters, + printerr = show_trace) + elseif method == :pseudotransient + res = ptcsol(f!, u, FS, FPS, AJ!; + rtol = reltol, atol = abstol, maxit = maxiters, + delta0 = delta, printerr = show_trace) + + end + + if res.errcode == 0 + retcode = ReturnCode.Success + elseif res.errcode == 10 + retcode = ReturnCode.MaxIters + elseif res.errcode == 1 + retcode = ReturnCode.Failure + @error("Line search failed") + elseif res.errcode == -1 + retcode = ReturnCode.Default + @info("Initial iterate satisfies the termination criteria") + end + + + # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here. + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1])) + return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) +end + +end \ No newline at end of file diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index dd0a6cc33..07eb83c69 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -236,7 +236,7 @@ export RadiusUpdateSchemes export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient, Broyden, Klement, LimitedMemoryBroyden -export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL +export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, SIAMFANLEquationsJL export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg diff --git a/src/extension_algs.jl b/src/extension_algs.jl index b06414f0f..0f08ae1b2 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -206,3 +206,35 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = return NLsolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace) end + +""" + SIAMFANLEquationsJL(; method = :newton, autodiff = :central) + +### Keyword Arguments + + - `method`: the choice of method for solving the nonlinear system. + - `autodiff`: the choice of method for generating the Jacobian. Defaults to `:central` or + central differencing via FiniteDiff.jl. The other choices are `:forward`. + - `show_trace`: whether to show the trace. + - `delta`: initial pseudo time step, default is 1e-3. + - `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`. + +### Submethod Choice + + - `:newton`: Classical Newton method. + - `:pseudotransient`: +""" +@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm + method::Symbol + autodiff::Symbol + show_trace::Bool + delta + linsolve::Union{Symbol, Nothing} +end + +function SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing) + if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing + error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded") + end + return SIAMFANLEquationsJL(method, autodiff, show_trace, delta, linsolve) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 2e74e905c..1d27a4b81 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ end if GROUP == "All" || GROUP == "Wrappers" @time @safetestset "MINPACK" include("minpack.jl") @time @safetestset "NLsolve" include("nlsolve.jl") + @time @safetestset "SIAMFANLEquations" include("siamfanlequations.jl") end if GROUP == "All" || GROUP == "23TestProblems" diff --git a/test/siamfanlequations.jl b/test/siamfanlequations.jl new file mode 100644 index 000000000..a903d8eb6 --- /dev/null +++ b/test/siamfanlequations.jl @@ -0,0 +1,140 @@ +using NonlinearSolve, SIAMFANLEquations, LinearAlgebra, Test + +# IIP Tests +function f_iip(du, u, p, t) + du[1] = 2 - 2u[1] + du[2] = u[1] - 4u[2] +end +u0 = zeros(2) +prob_iip = SteadyStateProblem(f_iip, u0) +abstol = 1e-8 + +for alg in [SIAMFANLEquationsJL()] + sol = solve(prob_iip, alg) + @test sol.retcode == ReturnCode.Success + p = nothing + + du = zeros(2) + f_iip(du, sol.u, nothing, 0) + @test maximum(du) < 1e-6 +end + +# OOP Tests +f_oop(u, p, t) = [2 - 2u[1], u[1] - 4u[2]] +u0 = zeros(2) +prob_oop = SteadyStateProblem(f_oop, u0) + +for alg in [SIAMFANLEquationsJL()] + sol = solve(prob_oop, alg) + @test sol.retcode == ReturnCode.Success + # test the solver is doing reasonable things for linear solve + # and that the stats are working properly + @test 1 <= sol.stats.nf < 10 + + du = zeros(2) + du = f_oop(sol.u, nothing, 0) + @test maximum(du) < 1e-6 +end + +# NonlinearProblem Tests + +function f_iip(du, u, p) + du[1] = 2 - 2u[1] + du[2] = u[1] - 4u[2] +end +u0 = zeros(2) +prob_iip = NonlinearProblem{true}(f_iip, u0) +abstol = 1e-8 +for alg in [SIAMFANLEquationsJL()] + local sol + sol = solve(prob_iip, alg) + @test sol.retcode == ReturnCode.Success + p = nothing + + du = zeros(2) + f_iip(du, sol.u, nothing) + @test maximum(du) < 1e-6 +end + +# OOP Tests +f_oop(u, p) = [2 - 2u[1], u[1] - 4u[2]] +u0 = zeros(2) +prob_oop = NonlinearProblem{false}(f_oop, u0) +for alg in [SIAMFANLEquationsJL()] + local sol + sol = solve(prob_oop, alg) + @test sol.retcode == ReturnCode.Success + + du = zeros(2) + du = f_oop(sol.u, nothing) + @test maximum(du) < 1e-6 +end + +# tolerance tests +f_tol(u, p) = u^2 - 2 +prob_tol = NonlinearProblem(f_tol, 1.0) +for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-11] + sol = solve(prob_tol, SIAMFANLEquationsJL(), abstol = tol) + @test abs(sol.u[1] - sqrt(2)) < tol +end + +# Test the finite differencing technique +function f!(fvec, x, p) + fvec[1] = (x[1] + 3) * (x[2]^3 - 7) + 18 + fvec[2] = sin(x[2] * exp(x[1]) - 1) +end + +prob = NonlinearProblem{true}(f!, [0.1; 1.2]) +sol = solve(prob, SIAMFANLEquationsJL(autodiff = :central)) + +du = zeros(2) +f!(du, sol.u, nothing) +@test maximum(du) < 1e-6 + +# Test the autodiff technique +function f!(fvec, x, p) + fvec[1] = (x[1] + 3) * (x[2]^3 - 7) + 18 + fvec[2] = sin(x[2] * exp(x[1]) - 1) +end + +prob = NonlinearProblem{true}(f!, [0.1; 1.2]) +sol = solve(prob, SIAMFANLEquationsJL(autodiff = :forward)) + +du = zeros(2) +f!(du, sol.u, nothing) +@test maximum(du) < 1e-6 + +function problem(x, A) + return x .^ 2 - A +end + +function problemJacobian(x, A) + return diagm(2 .* x) +end + +function f!(F, u, p) + F[1:152] = problem(u, p) +end + +function j!(J, u, p) + J[1:152, 1:152] = problemJacobian(u, p) +end + +f = NonlinearFunction(f!) + +init = ones(152); +A = ones(152); +A[6] = 0.8 + +f = NonlinearFunction(f!, jac = j!) + +p = A + +ProbN = NonlinearProblem(f, init, p) +sol = solve(ProbN, SIAMFANLEquationsJL(), reltol = 1e-8, abstol = 1e-8) + +#= doesn't support complex numbers handling +init = ones(Complex{Float64}, 152); +ProbN = NonlinearProblem(f, init, p) +sol = solve(ProbN, SIAMFANLEquationsJL(), reltol = 1e-8, abstol = 1e-8) +=# \ No newline at end of file From 8e47b2d5b94a504dd28a41e8d81f9007910647b0 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Tue, 19 Dec 2023 11:39:15 +0800 Subject: [PATCH 02/16] Lower version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0eb98cb77..4d6a4d2d8 100644 --- a/Project.toml +++ b/Project.toml @@ -82,7 +82,7 @@ Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.11" SciMLOperators = "0.3.7" -SIAMFANLEquations = "1.0.1" +SIAMFANLEquations = "1.0.0" SimpleNonlinearSolve = "1.0.2" SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.14" From 702b76d322240629a8146129279dc73ea1a663fe Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Tue, 19 Dec 2023 11:43:57 +0800 Subject: [PATCH 03/16] bump BandedMatrices --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 4d6a4d2d8..b0bf077c7 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ NonlinearSolveZygoteExt = "Zygote" ADTypes = "0.2.5" Aqua = "0.8" ArrayInterface = "7.6" -BandedMatrices = "1.3" +BandedMatrices = "1.4" BenchmarkTools = "1" ConcreteStructs = "0.2" DiffEqBase = "6.144" @@ -82,7 +82,7 @@ Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.11" SciMLOperators = "0.3.7" -SIAMFANLEquations = "1.0.0" +SIAMFANLEquations = "1.0" SimpleNonlinearSolve = "1.0.2" SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.14" From e1eebc81bbaf97e05f50231d8e99694dca4e66e0 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Tue, 19 Dec 2023 11:49:02 +0800 Subject: [PATCH 04/16] Fix BandedMatrices compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b0bf077c7..b72cd4262 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ NonlinearSolveZygoteExt = "Zygote" ADTypes = "0.2.5" Aqua = "0.8" ArrayInterface = "7.6" -BandedMatrices = "1.4" +BandedMatrices = "1, 1.4" BenchmarkTools = "1" ConcreteStructs = "0.2" DiffEqBase = "6.144" From 7bd22e8d52d21c5135265059b6b3ead8585f4db3 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Tue, 19 Dec 2023 13:15:13 +0800 Subject: [PATCH 05/16] Lower BandedMatrices --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b72cd4262..7440d1b0f 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ NonlinearSolveZygoteExt = "Zygote" ADTypes = "0.2.5" Aqua = "0.8" ArrayInterface = "7.6" -BandedMatrices = "1, 1.4" +BandedMatrices = "1" BenchmarkTools = "1" ConcreteStructs = "0.2" DiffEqBase = "6.144" From 37732aa392cebe5acd4d22c2ed656b3dd260e043 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Tue, 19 Dec 2023 23:01:06 +0800 Subject: [PATCH 06/16] Fix docs and remove error Signed-off-by: ErikQQY <2283984853@qq.com> --- docs/pages.jl | 3 ++- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 6 ------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/docs/pages.jl b/docs/pages.jl index a8107bea2..55433367b 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -28,6 +28,7 @@ pages = ["index.md", "api/sundials.md", "api/steadystatediffeq.md", "api/leastsquaresoptim.md", - "api/fastlevenbergmarquardt.md"], + "api/fastlevenbergmarquardt.md", + "api/siamfanlequations.md"], "Release Notes" => "release_notes.md", ] diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index 3ab9c1005..b70d52be4 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -36,10 +36,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg retcode = ReturnCode.MaxIters elseif res.errcode == 1 retcode = ReturnCode.Failure - @error("Line search failed") elseif res.errcode == -1 retcode = ReturnCode.Default - @info("Initial iterate satisfies the termination criteria") end stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1])) return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) @@ -85,10 +83,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg retcode = ReturnCode.MaxIters elseif res.errcode == 1 retcode = ReturnCode.Failure - @error("Line search failed") elseif res.errcode == -1 retcode = ReturnCode.Default - @info("Initial iterate satisfies the termination criteria") end stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1])) return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) @@ -161,10 +157,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg retcode = ReturnCode.MaxIters elseif res.errcode == 1 retcode = ReturnCode.Failure - @error("Line search failed") elseif res.errcode == -1 retcode = ReturnCode.Default - @info("Initial iterate satisfies the termination criteria") end From fff42cec74f9bd9702b530fa293f28b93f54f3a7 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Wed, 20 Dec 2023 14:28:45 +0800 Subject: [PATCH 07/16] Fix issues in comments Signed-off-by: ErikQQY <2283984853@qq.com> --- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 4 ++-- src/extension_algs.jl | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index b70d52be4..6f87a0c47 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -68,8 +68,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg # Jacobian free Newton Krylov if linsolve !== nothing JVS = linsolve == :gmres ? zeros(eltype(u), N, 3) : zeros(eltype(u), N) - # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between linear solvers - linsolve_alg = strip(repr(linsolve), ':') + # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers + linsolve_alg = String(linsolve) if method == :newton res = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 0f08ae1b2..b0a0f5efe 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -208,7 +208,7 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = end """ - SIAMFANLEquationsJL(; method = :newton, autodiff = :central) + SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing) ### Keyword Arguments @@ -222,7 +222,8 @@ end ### Submethod Choice - `:newton`: Classical Newton method. - - `:pseudotransient`: + - `:pseudotransient`: Pseudo transient method. + - `:secant`: Secant method for scalar equations. """ @concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm method::Symbol From c5d603c96ea0527685470e760f337882ad687ad7 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Wed, 20 Dec 2023 19:41:13 +0800 Subject: [PATCH 08/16] bump BandedMatrices Signed-off-by: ErikQQY <2283984853@qq.com> --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7440d1b0f..b0bf077c7 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ NonlinearSolveZygoteExt = "Zygote" ADTypes = "0.2.5" Aqua = "0.8" ArrayInterface = "7.6" -BandedMatrices = "1" +BandedMatrices = "1.4" BenchmarkTools = "1" ConcreteStructs = "0.2" DiffEqBase = "6.144" From e8465576514c1372ceee7ff8631267b5d3bb943d Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sat, 23 Dec 2023 14:01:59 +0800 Subject: [PATCH 09/16] Fix compat error Signed-off-by: ErikQQY <2283984853@qq.com> --- Project.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 558b72387..6d89f5df9 100644 --- a/Project.toml +++ b/Project.toml @@ -53,8 +53,8 @@ NonlinearSolveZygoteExt = "Zygote" ADTypes = "0.2.5" Aqua = "0.8" ArrayInterface = "7.6" -BandedMatrices = "1.4" -BenchmarkTools = "1" +BandedMatrices = "0.17, 1.4" +BenchmarkTools = "1.4" ConcreteStructs = "0.2" DiffEqBase = "6.144" EnumX = "1" @@ -67,7 +67,7 @@ LazyArrays = "1.8.2" LeastSquaresOptim = "0.8.5" LineSearches = "7.2" LinearAlgebra = "<0.0.1, 1" -LinearSolve = "2.21" +LinearSolve = "2" MINPACK = "1.2" MaybeInplace = "0.1.1" NLsolve = "4.5" @@ -83,7 +83,7 @@ Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.11" SciMLOperators = "0.3.7" -SIAMFANLEquations = "1.0" +SIAMFANLEquations = "1.0.0" SimpleNonlinearSolve = "1.0.2" SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.14" From bff7b31d19c33b1894c29d09e50dd84a670788d5 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sat, 23 Dec 2023 23:43:43 +0800 Subject: [PATCH 10/16] Fix compat errors Signed-off-by: ErikQQY <2283984853@qq.com> --- Project.toml | 6 +++--- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 6d89f5df9..99b17fd10 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ NonlinearSolveZygoteExt = "Zygote" ADTypes = "0.2.5" Aqua = "0.8" ArrayInterface = "7.6" -BandedMatrices = "0.17, 1.4" +BandedMatrices = "1.4" BenchmarkTools = "1.4" ConcreteStructs = "0.2" DiffEqBase = "6.144" @@ -67,7 +67,7 @@ LazyArrays = "1.8.2" LeastSquaresOptim = "0.8.5" LineSearches = "7.2" LinearAlgebra = "<0.0.1, 1" -LinearSolve = "2" +LinearSolve = "2.21" MINPACK = "1.2" MaybeInplace = "0.1.1" NLsolve = "4.5" @@ -83,7 +83,7 @@ Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.11" SciMLOperators = "0.3.7" -SIAMFANLEquations = "1.0.0" +SIAMFANLEquations = "1.0.1" SimpleNonlinearSolve = "1.0.2" SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.14" diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index 6f87a0c47..1e7f4405a 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -7,7 +7,9 @@ import UnPack: @unpack import FiniteDiff, ForwardDiff function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = 1e-8, - reltol = 1e-8, alias_u0::Bool = false, maxiters = 1000, kwargs...) + reltol = 1e-8, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...) + @assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!" + @unpack method, autodiff, show_trace, delta, linsolve = alg iip = SciMLBase.isinplace(prob) @@ -39,7 +41,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg elseif res.errcode == -1 retcode = ReturnCode.Default end - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1])) + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1])) return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) else u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) @@ -86,7 +88,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg elseif res.errcode == -1 retcode = ReturnCode.Default end - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1])) + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1])) return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) end @@ -163,7 +165,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here. - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1])) + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1])) return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) end From ce5226079d9a14f329f937833ca729684425dfc7 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sun, 24 Dec 2023 00:43:56 +0800 Subject: [PATCH 11/16] Fix stats Signed-off-by: ErikQQY <2283984853@qq.com> --- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 53 +++++++++++------------ 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index 1e7f4405a..58d70976e 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -25,24 +25,24 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg end if method == :newton - res = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) elseif method == :pseudotransient - res = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace) + sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace) elseif method == :secant - res = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + sol = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) end - if res.errcode == 0 + if sol.errcode == 0 retcode = ReturnCode.Success - elseif res.errcode == 10 + elseif sol.errcode == 10 retcode = ReturnCode.MaxIters - elseif res.errcode == 1 + elseif sol.errcode == 1 retcode = ReturnCode.Failure - elseif res.errcode == -1 + elseif sol.errcode == -1 retcode = ReturnCode.Default end - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1])) - return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) else u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) end @@ -74,22 +74,22 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg linsolve_alg = String(linsolve) if method == :newton - res = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) elseif method == :pseudotransient - res = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) end - if res.errcode == 0 + if sol.errcode == 0 retcode = ReturnCode.Success - elseif res.errcode == 10 + elseif sol.errcode == 10 retcode = ReturnCode.MaxIters - elseif res.errcode == 1 + elseif sol.errcode == 1 retcode = ReturnCode.Failure - elseif res.errcode == -1 + elseif sol.errcode == -1 retcode = ReturnCode.Default end - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1])) - return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) end if prob.f.jac === nothing @@ -143,30 +143,29 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg AJ!(J, u, x) = J!(J, x, prob.p) if method == :newton - res = nsol(f!, u, FS, FPS, AJ!; + sol = nsol(f!, u, FS, FPS, AJ!; sham=1, rtol = reltol, atol = abstol, maxit = maxiters, printerr = show_trace) elseif method == :pseudotransient - res = ptcsol(f!, u, FS, FPS, AJ!; + sol = ptcsol(f!, u, FS, FPS, AJ!; rtol = reltol, atol = abstol, maxit = maxiters, delta0 = delta, printerr = show_trace) - end - if res.errcode == 0 + if sol.errcode == 0 retcode = ReturnCode.Success - elseif res.errcode == 10 + elseif sol.errcode == 10 retcode = ReturnCode.MaxIters - elseif res.errcode == 1 + elseif sol.errcode == 1 retcode = ReturnCode.Failure - elseif res.errcode == -1 + elseif sol.errcode == -1 retcode = ReturnCode.Default end - # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here. - stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1])) - return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats) + stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) + println(sol.stats) + return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) end end \ No newline at end of file From 8ee980b37c42fb8a855d7b58ffe41ad69f5885d2 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sun, 24 Dec 2023 11:48:58 +0800 Subject: [PATCH 12/16] Fix abstol, reltol and use the built-in Jacobian Signed-off-by: ErikQQY <2283984853@qq.com> --- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 106 ++++++++-------------- 1 file changed, 38 insertions(+), 68 deletions(-) diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index 58d70976e..403fa8d05 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -6,14 +6,19 @@ import ConcreteStructs: @concrete import UnPack: @unpack import FiniteDiff, ForwardDiff -function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = 1e-8, - reltol = 1e-8, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...) +function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing, + reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...) @assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!" @unpack method, autodiff, show_trace, delta, linsolve = alg iip = SciMLBase.isinplace(prob) - if typeof(prob.u0) <: Number + T = eltype(u0) + + atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol + + if prob.u0 isa Number f! = if iip function (u) du = similar(u) @@ -25,11 +30,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg end if method == :newton - sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) elseif method == :pseudotransient - sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace) + sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = atol, rtol=rtol, printerr = show_trace) elseif method == :secant - sol = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + sol = secant(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) end if sol.errcode == 0 @@ -61,22 +66,21 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg end end - # Allocate ahead for function and Jacobian + # Allocate ahead for function N = length(u) - FS = zeros(eltype(u), N) - FPS = zeros(eltype(u), N, N) - # Allocate ahead for Krylov basis + FS = zeros(T, N) # Jacobian free Newton Krylov if linsolve !== nothing - JVS = linsolve == :gmres ? zeros(eltype(u), N, 3) : zeros(eltype(u), N) + # Allocate ahead for Krylov basis + JVS = linsolve == :gmres ? zeros(T, N, 3) : zeros(T, N) # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers linsolve_alg = String(linsolve) if method == :newton - sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) elseif method == :pseudotransient - sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace) + sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace) end if sol.errcode == 0 @@ -92,64 +96,30 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) end + # Allocate ahead for Jacobian + FPS = zeros(T, N, N) if prob.f.jac === nothing - use_forward_diff = if alg.autodiff === nothing - ForwardDiff.can_dual(eltype(u)) - else - alg.autodiff isa AutoForwardDiff - end - uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p) - if use_forward_diff - cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) : - ForwardDiff.JacobianConfig(uf, u) - else - cache = FiniteDiff.JacobianCache(u, fu) - end - J! = if iip - if use_forward_diff - fu_cache = similar(fu) - function (J, x, p) - uf.p = p - ForwardDiff.jacobian!(J, uf, fu_cache, x, cache) - return J - end - else - function (J, x, p) - uf.p = p - FiniteDiff.finite_difference_jacobian!(J, uf, x, cache) - return J - end - end - else - if use_forward_diff - function (J, x, p) - uf.p = p - ForwardDiff.jacobian!(J, uf, x, cache) - return J - end - else - function (J, x, p) - uf.p = p - J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache) - copyto!(J, J_) - return J - end - end + # Use the built-in Jacobian machinery + if method == :newton + sol = nsol(f!, u, FS, FPS; + sham=1, atol = atol, rtol = rtol, maxit = maxiters, + printerr = show_trace) + elseif method == :pseudotransient + sol = ptcsol(f!, u, FS, FPS; + atol = atol, rtol = rtol, maxit = maxiters, + delta0 = delta, printerr = show_trace) end else - J! = prob.f.jac - end - - AJ!(J, u, x) = J!(J, x, prob.p) - - if method == :newton - sol = nsol(f!, u, FS, FPS, AJ!; - sham=1, rtol = reltol, atol = abstol, maxit = maxiters, - printerr = show_trace) - elseif method == :pseudotransient - sol = ptcsol(f!, u, FS, FPS, AJ!; - rtol = reltol, atol = abstol, maxit = maxiters, - delta0 = delta, printerr = show_trace) + AJ!(J, u, x) = prob.f.jac(J, x, prob.p) + if method == :newton + sol = nsol(f!, u, FS, FPS, AJ!; + sham=1, atol = atol, rtol = rtol, maxit = maxiters, + printerr = show_trace) + elseif method == :pseudotransient + sol = ptcsol(f!, u, FS, FPS, AJ!; + atol = atol, rtol = rtol, maxit = maxiters, + delta0 = delta, printerr = show_trace) + end end if sol.errcode == 0 From ce7a3c7ae2de122e94423d906cc756d110b3c14b Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sun, 24 Dec 2023 12:02:03 +0800 Subject: [PATCH 13/16] typo Signed-off-by: ErikQQY <2283984853@qq.com> --- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index 403fa8d05..c24aad0ac 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -13,7 +13,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg @unpack method, autodiff, show_trace, delta, linsolve = alg iip = SciMLBase.isinplace(prob) - T = eltype(u0) + T = eltype(prob.u0) atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol From 80d2a19e43b6b6a75add5b065512abb03d6d8e40 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sun, 24 Dec 2023 12:15:02 +0800 Subject: [PATCH 14/16] delete print Signed-off-by: ErikQQY <2283984853@qq.com> --- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index c24aad0ac..25fdd327e 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -134,7 +134,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here. stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm))) - println(sol.stats) return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol) end From c0ece9dd9e85b26a889c5e07a4a514d4bdc723f3 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Sun, 24 Dec 2023 12:37:36 +0800 Subject: [PATCH 15/16] Update pages.jl --- docs/pages.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/pages.jl b/docs/pages.jl index fa6e90dea..9c148bcb4 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -32,6 +32,6 @@ pages = ["index.md", "api/fastlevenbergmarquardt.md", "api/speedmapping.md", "api/fixedpointacceleration.md", - "api/siamfanlequations.md"]], + "api/siamfanlequations.md"], "Release Notes" => "release_notes.md", ] From e2de813eab07e5ba8f71ae53cc93dcc68389cc1b Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sun, 24 Dec 2023 15:45:34 +0800 Subject: [PATCH 16/16] Cover more tests Signed-off-by: ErikQQY <2283984853@qq.com> --- ext/NonlinearSolveSIAMFANLEquationsExt.jl | 5 +---- src/extension_algs.jl | 7 ++----- test/siamfanlequations.jl | 22 +++++++++++++++++----- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/ext/NonlinearSolveSIAMFANLEquationsExt.jl b/ext/NonlinearSolveSIAMFANLEquationsExt.jl index 25fdd327e..47ecc96c9 100644 --- a/ext/NonlinearSolveSIAMFANLEquationsExt.jl +++ b/ext/NonlinearSolveSIAMFANLEquationsExt.jl @@ -4,13 +4,12 @@ using NonlinearSolve, SciMLBase using SIAMFANLEquations import ConcreteStructs: @concrete import UnPack: @unpack -import FiniteDiff, ForwardDiff function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing, reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...) @assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!" - @unpack method, autodiff, show_trace, delta, linsolve = alg + @unpack method, show_trace, delta, linsolve = alg iip = SciMLBase.isinplace(prob) T = eltype(prob.u0) @@ -52,8 +51,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) end - fu = NonlinearSolve.evaluate_f(prob, u) - if iip f! = function (du, u) prob.f(du, u, prob.p) diff --git a/src/extension_algs.jl b/src/extension_algs.jl index af0914dba..ea92454b0 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -331,8 +331,6 @@ end ### Keyword Arguments - `method`: the choice of method for solving the nonlinear system. - - `autodiff`: the choice of method for generating the Jacobian. Defaults to `:central` or - central differencing via FiniteDiff.jl. The other choices are `:forward`. - `show_trace`: whether to show the trace. - `delta`: initial pseudo time step, default is 1e-3. - `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`. @@ -345,15 +343,14 @@ end """ @concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm method::Symbol - autodiff::Symbol show_trace::Bool delta linsolve::Union{Symbol, Nothing} end -function SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing) +function SIAMFANLEquationsJL(; method = :newton, show_trace = false, delta = 1e-3, linsolve = nothing) if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded") end - return SIAMFANLEquationsJL(method, autodiff, show_trace, delta, linsolve) + return SIAMFANLEquationsJL(method, show_trace, delta, linsolve) end diff --git a/test/siamfanlequations.jl b/test/siamfanlequations.jl index a903d8eb6..ed35485b2 100644 --- a/test/siamfanlequations.jl +++ b/test/siamfanlequations.jl @@ -70,11 +70,21 @@ for alg in [SIAMFANLEquationsJL()] @test maximum(du) < 1e-6 end -# tolerance tests +# tolerance tests for scalar equation solvers f_tol(u, p) = u^2 - 2 prob_tol = NonlinearProblem(f_tol, 1.0) for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-11] - sol = solve(prob_tol, SIAMFANLEquationsJL(), abstol = tol) + for method = [:newton, :pseudotransient, :secant] + sol = solve(prob_tol, SIAMFANLEquationsJL(method = method), abstol = tol) + @test abs(sol.u[1] - sqrt(2)) < tol + end +end + +# Test the JFNK technique +f_jfnk(u, p) = u^2 - 2 +prob_jfnk = NonlinearProblem(f_jfnk, 1.0) +for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-11] + sol = solve(prob_jfnk, SIAMFANLEquationsJL(linsolve = :gmres), abstol = tol) @test abs(sol.u[1] - sqrt(2)) < tol end @@ -85,7 +95,7 @@ function f!(fvec, x, p) end prob = NonlinearProblem{true}(f!, [0.1; 1.2]) -sol = solve(prob, SIAMFANLEquationsJL(autodiff = :central)) +sol = solve(prob, SIAMFANLEquationsJL()) du = zeros(2) f!(du, sol.u, nothing) @@ -98,7 +108,7 @@ function f!(fvec, x, p) end prob = NonlinearProblem{true}(f!, [0.1; 1.2]) -sol = solve(prob, SIAMFANLEquationsJL(autodiff = :forward)) +sol = solve(prob, SIAMFANLEquationsJL()) du = zeros(2) f!(du, sol.u, nothing) @@ -131,7 +141,9 @@ f = NonlinearFunction(f!, jac = j!) p = A ProbN = NonlinearProblem(f, init, p) -sol = solve(ProbN, SIAMFANLEquationsJL(), reltol = 1e-8, abstol = 1e-8) +for method = [:newton, :pseudotransient] + sol = solve(ProbN, SIAMFANLEquationsJL(method = method), reltol = 1e-8, abstol = 1e-8) +end #= doesn't support complex numbers handling init = ones(Complex{Float64}, 152);