Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SIAMFANLEquations wrapper #333

Merged
merged 19 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -44,14 +45,15 @@ NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLsolveExt = "NLsolve"
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
NonlinearSolveSymbolicsExt = "Symbolics"
NonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.5"
Aqua = "0.8"
ArrayInterface = "7.6"
BandedMatrices = "1.3"
BandedMatrices = "1"
BenchmarkTools = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.144"
Expand Down Expand Up @@ -80,6 +82,7 @@ Reexport = "1.2"
SafeTestsets = "0.1"
SciMLBase = "2.11"
SciMLOperators = "0.3.7"
SIAMFANLEquations = "1.0"
SimpleNonlinearSolve = "1.0.2"
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.14"
Expand Down Expand Up @@ -109,6 +112,7 @@ NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -117,4 +121,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve"]
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "SIAMFANLEquations"]
3 changes: 2 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pages = ["index.md",
"api/sundials.md",
"api/steadystatediffeq.md",
"api/leastsquaresoptim.md",
"api/fastlevenbergmarquardt.md"],
"api/fastlevenbergmarquardt.md",
"api/siamfanlequations.md"],
"Release Notes" => "release_notes.md",
]
2 changes: 1 addition & 1 deletion docs/src/api/fastlevenbergmarquardt.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# FastLevenbergMarquardt.jl

This is a extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
This is an extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

Expand Down
2 changes: 1 addition & 1 deletion docs/src/api/leastsquaresoptim.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# LeastSquaresOptim.jl

This is a extension for importing solvers from LeastSquaresOptim.jl into the SciML
This is an extension for importing solvers from LeastSquaresOptim.jl into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

Expand Down
17 changes: 17 additions & 0 deletions docs/src/api/siamfanlequations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SIAMFANLEquations.jl

This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

```julia
using Pkg
Pkg.add("SIAMFANLEquations")
using SIAMFANLEquations, NonlinearSolve
```

## Solver API

```@docs
SIAMFANLEquationsJL
```
170 changes: 170 additions & 0 deletions ext/NonlinearSolveSIAMFANLEquationsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
module NonlinearSolveSIAMFANLEquationsExt

using NonlinearSolve, SciMLBase
using SIAMFANLEquations
import ConcreteStructs: @concrete
import UnPack: @unpack
import FiniteDiff, ForwardDiff

function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = 1e-8,

Check warning on line 9 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L9

Added line #L9 was not covered by tests
reltol = 1e-8, alias_u0::Bool = false, maxiters = 1000, kwargs...)
@unpack method, autodiff, show_trace, delta, linsolve = alg

Check warning on line 11 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L11

Added line #L11 was not covered by tests

iip = SciMLBase.isinplace(prob)
if typeof(prob.u0) <: Number
f! = if iip
function (u)
du = similar(u)
prob.f(du, u, prob.p)
return du

Check warning on line 19 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L13-L19

Added lines #L13 - L19 were not covered by tests
end
else
u -> prob.f(u, prob.p)

Check warning on line 22 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L22

Added line #L22 was not covered by tests
end

if method == :newton
res = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
elseif method == :pseudotransient
res = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace)
elseif method == :secant
res = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)

Check warning on line 30 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L25-L30

Added lines #L25 - L30 were not covered by tests
end

if res.errcode == 0
retcode = ReturnCode.Success
elseif res.errcode == 10
retcode = ReturnCode.MaxIters
elseif res.errcode == 1
retcode = ReturnCode.Failure
elseif res.errcode == -1
retcode = ReturnCode.Default

Check warning on line 40 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L33-L40

Added lines #L33 - L40 were not covered by tests
end
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1]))
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)

Check warning on line 43 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L42-L43

Added lines #L42 - L43 were not covered by tests
else
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)

Check warning on line 45 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L45

Added line #L45 was not covered by tests
end

fu = NonlinearSolve.evaluate_f(prob, u)

Check warning on line 48 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L48

Added line #L48 was not covered by tests

if iip
f! = function (du, u)
prob.f(du, u, prob.p)
return du

Check warning on line 53 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L50-L53

Added lines #L50 - L53 were not covered by tests
end
else
f! = function (du, u)
du .= prob.f(u, prob.p)
return du

Check warning on line 58 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L56-L58

Added lines #L56 - L58 were not covered by tests
end
end

# Allocate ahead for function and Jacobian
N = length(u)
FS = zeros(eltype(u), N)
FPS = zeros(eltype(u), N, N)

Check warning on line 65 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L63-L65

Added lines #L63 - L65 were not covered by tests
# Allocate ahead for Krylov basis

# Jacobian free Newton Krylov
if linsolve !== nothing
JVS = linsolve == :gmres ? zeros(eltype(u), N, 3) : zeros(eltype(u), N)

Check warning on line 70 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between linear solvers
linsolve_alg = strip(repr(linsolve), ':')

Check warning on line 72 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L72

Added line #L72 was not covered by tests
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

if method == :newton
res = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
elseif method == :pseudotransient
res = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)

Check warning on line 77 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L74-L77

Added lines #L74 - L77 were not covered by tests
end

if res.errcode == 0
retcode = ReturnCode.Success
elseif res.errcode == 10
retcode = ReturnCode.MaxIters
elseif res.errcode == 1
retcode = ReturnCode.Failure
elseif res.errcode == -1
retcode = ReturnCode.Default

Check warning on line 87 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L80-L87

Added lines #L80 - L87 were not covered by tests
end
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1]))
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)

Check warning on line 90 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L89-L90

Added lines #L89 - L90 were not covered by tests
end

if prob.f.jac === nothing
use_forward_diff = if alg.autodiff === nothing
ForwardDiff.can_dual(eltype(u))

Check warning on line 95 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L93-L95

Added lines #L93 - L95 were not covered by tests
else
alg.autodiff isa AutoForwardDiff

Check warning on line 97 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L97

Added line #L97 was not covered by tests
end
uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p)
if use_forward_diff
cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) :

Check warning on line 101 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L99-L101

Added lines #L99 - L101 were not covered by tests
ForwardDiff.JacobianConfig(uf, u)
else
cache = FiniteDiff.JacobianCache(u, fu)

Check warning on line 104 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L104

Added line #L104 was not covered by tests
end
J! = if iip
if use_forward_diff
fu_cache = similar(fu)
function (J, x, p)
uf.p = p
ForwardDiff.jacobian!(J, uf, fu_cache, x, cache)
return J

Check warning on line 112 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L106-L112

Added lines #L106 - L112 were not covered by tests
end
else
function (J, x, p)
uf.p = p
FiniteDiff.finite_difference_jacobian!(J, uf, x, cache)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
return J

Check warning on line 118 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L115-L118

Added lines #L115 - L118 were not covered by tests
end
end
else
if use_forward_diff
function (J, x, p)
uf.p = p
ForwardDiff.jacobian!(J, uf, x, cache)
return J

Check warning on line 126 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L122-L126

Added lines #L122 - L126 were not covered by tests
end
else
function (J, x, p)
uf.p = p
J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache)
copyto!(J, J_)
return J

Check warning on line 133 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L129-L133

Added lines #L129 - L133 were not covered by tests
end
end
end
else
J! = prob.f.jac

Check warning on line 138 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L138

Added line #L138 was not covered by tests
end

AJ!(J, u, x) = J!(J, x, prob.p)

Check warning on line 141 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L141

Added line #L141 was not covered by tests

if method == :newton
res = nsol(f!, u, FS, FPS, AJ!;

Check warning on line 144 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L143-L144

Added lines #L143 - L144 were not covered by tests
sham=1, rtol = reltol, atol = abstol, maxit = maxiters,
printerr = show_trace)
elseif method == :pseudotransient
res = ptcsol(f!, u, FS, FPS, AJ!;

Check warning on line 148 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L147-L148

Added lines #L147 - L148 were not covered by tests
rtol = reltol, atol = abstol, maxit = maxiters,
delta0 = delta, printerr = show_trace)

end

if res.errcode == 0
retcode = ReturnCode.Success
elseif res.errcode == 10
retcode = ReturnCode.MaxIters
elseif res.errcode == 1
retcode = ReturnCode.Failure
elseif res.errcode == -1
retcode = ReturnCode.Default

Check warning on line 161 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L154-L161

Added lines #L154 - L161 were not covered by tests
end


# pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1]))
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)

Check warning on line 167 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L166-L167

Added lines #L166 - L167 were not covered by tests
end

end
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
Broyden, Klement, LimitedMemoryBroyden
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, SIAMFANLEquationsJL
export NonlinearSolvePolyAlgorithm,
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg

Expand Down
32 changes: 32 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,35 @@
return NLsolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve,
factor, autoscale, m, beta, show_trace)
end

"""
SIAMFANLEquationsJL(; method = :newton, autodiff = :central)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

### Keyword Arguments

- `method`: the choice of method for solving the nonlinear system.
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `:central` or
central differencing via FiniteDiff.jl. The other choices are `:forward`.
- `show_trace`: whether to show the trace.
- `delta`: initial pseudo time step, default is 1e-3.
- `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`.

### Submethod Choice

- `:newton`: Classical Newton method.
- `:pseudotransient`:
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
"""
@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm
method::Symbol
autodiff::Symbol
show_trace::Bool
delta
linsolve::Union{Symbol, Nothing}
end

function SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)
if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing
error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded")

Check warning on line 237 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L235-L237

Added lines #L235 - L237 were not covered by tests
end
return SIAMFANLEquationsJL(method, autodiff, show_trace, delta, linsolve)

Check warning on line 239 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L239

Added line #L239 was not covered by tests
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ end
if GROUP == "All" || GROUP == "Wrappers"
@time @safetestset "MINPACK" include("minpack.jl")
@time @safetestset "NLsolve" include("nlsolve.jl")
@time @safetestset "SIAMFANLEquations" include("siamfanlequations.jl")
end

if GROUP == "All" || GROUP == "23TestProblems"
Expand Down
Loading
Loading